diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_awaits/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_awaits/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d0c2d176189273f3d6ade871cbaffa3355f9c6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_awaits/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1deda4b991fec436d50dd5943d7640fa28b5ec2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59533c9256b8de7a7b5f05db6b866d636403ba88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16de02d295da7c573198e3e86f30b65e7ca2dd13 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cab6e2e62d744d63b29a43732afe68afe685ab0f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696c1c1d5e2c37a2f6b951f6d97f0326f7b73029 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dab64df00f789e9321207098666f3b0638a9f79 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74d3393056d3c88ca7828d59b9c5a98a0922b51c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d61218acb127cca92738cbae578ba50cfec57919 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9619af0c7eac563221dc4b3775d0236c63a31914 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b49dca7ec8dcab21d75a7f33c7cfcf33ee15f646 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b524d62b2c3b64b67f38db7964719323a8f6677 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf9aae4f82f9f9e3b25baf68cf8d6ea654f5fa2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1cc5c061c0dd1234186cfaade006c315a34c1e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/case.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/case.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71255af872b3ff1c73e49df4288cf557132886de Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/case.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/gen_example.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/gen_example.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1d88fc6adc9f79ecdc201bb78c0c7dd644d345e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/gen_example.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/logging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd7fc6be72a7821e9d2096663cd5757576c08a6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/__pycache__/logging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/case.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/case.py new file mode 100644 index 0000000000000000000000000000000000000000..048a71cd6c16a205bbe9d7f845369b93f6a02f2e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/case.py @@ -0,0 +1,175 @@ +# mypy: allow-untyped-defs +import inspect +import re +import string +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional +from types import ModuleType + +import torch + +_TAGS: dict[str, dict[str, Any]] = { + "torch": { + "cond": {}, + "dynamic-shape": {}, + "escape-hatch": {}, + "map": {}, + "dynamic-value": {}, + "operator": {}, + "mutation": {}, + }, + "python": { + "assert": {}, + "builtin": {}, + "closure": {}, + "context-manager": {}, + "control-flow": {}, + "data-structure": {}, + "standard-library": {}, + "object-model": {}, + }, +} + + +class SupportLevel(Enum): + """ + Indicates at what stage the feature + used in the example is handled in export. + """ + + SUPPORTED = 1 + NOT_SUPPORTED_YET = 0 + + +ArgsType = tuple[Any, ...] + + +def check_inputs_type(args, kwargs): + if not isinstance(args, tuple): + raise ValueError( + f"Expecting args type to be a tuple, got: {type(args)}" + ) + if not isinstance(kwargs, dict): + raise ValueError( + f"Expecting kwargs type to be a dict, got: {type(kwargs)}" + ) + for key in kwargs: + if not isinstance(key, str): + raise ValueError( + f"Expecting kwargs keys to be a string, got: {type(key)}" + ) + +def _validate_tag(tag: str): + parts = tag.split(".") + t = _TAGS + for part in parts: + assert set(part) <= set( + string.ascii_lowercase + "-" + ), f"Tag contains invalid characters: {part}" + if part in t: + t = t[part] + else: + raise ValueError(f"Tag {tag} is not found in registered tags.") + + +@dataclass(frozen=True) +class ExportCase: + example_args: ArgsType + description: str # A description of the use case. + model: torch.nn.Module + name: str + example_kwargs: dict[str, Any] = field(default_factory=dict) + extra_args: Optional[ArgsType] = None # For testing graph generalization. + # Tags associated with the use case. (e.g dynamic-shape, escape-hatch) + tags: set[str] = field(default_factory=set) + support_level: SupportLevel = SupportLevel.SUPPORTED + dynamic_shapes: Optional[dict[str, Any]] = None + + def __post_init__(self): + check_inputs_type(self.example_args, self.example_kwargs) + if self.extra_args is not None: + check_inputs_type(self.extra_args, {}) + + for tag in self.tags: + _validate_tag(tag) + + if not isinstance(self.description, str) or len(self.description) == 0: + raise ValueError(f'Invalid description: "{self.description}"') + + +_EXAMPLE_CASES: dict[str, ExportCase] = {} +_MODULES: set[ModuleType] = set() +_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {} +_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {} + + +def register_db_case(case: ExportCase) -> None: + """ + Registers a user provided ExportCase into example bank. + """ + if case.name in _EXAMPLE_CASES: + if case.name not in _EXAMPLE_CONFLICT_CASES: + _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]] + _EXAMPLE_CONFLICT_CASES[case.name].append(case) + return + + _EXAMPLE_CASES[case.name] = case + + +def to_snake_case(name): + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +def _make_export_case(m, name, configs): + if not isinstance(m, torch.nn.Module): + raise TypeError("Export case class should be a torch.nn.Module.") + + if "description" not in configs: + # Fallback to docstring if description is missing. + assert ( + m.__doc__ is not None + ), f"Could not find description or docstring for export case: {m}" + configs = {**configs, "description": m.__doc__} + # pyrefly: ignore [bad-argument-type] + return ExportCase(**{**configs, "model": m, "name": name}) + + +def export_case(**kwargs): + """ + Decorator for registering a user provided case into example bank. + """ + + def wrapper(m): + configs = kwargs + module = inspect.getmodule(m) + if module in _MODULES: + raise RuntimeError("export_case should only be used once per example file.") + + assert module is not None + _MODULES.add(module) + module_name = module.__name__.split(".")[-1] + case = _make_export_case(m, module_name, configs) + register_db_case(case) + return case + + return wrapper + + +def export_rewrite_case(**kwargs): + def wrapper(m): + configs = kwargs + + parent = configs.pop("parent") + assert isinstance(parent, ExportCase) + key = parent.name + if key not in _EXAMPLE_REWRITE_CASES: + _EXAMPLE_REWRITE_CASES[key] = [] + + configs["example_args"] = parent.example_args + case = _make_export_case(m, to_snake_case(m.__name__), configs) + _EXAMPLE_REWRITE_CASES[key].append(case) + return case + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..834dbce32f10bfb339fd2182a2455b43914441c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__init__.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import dataclasses +import glob +import inspect +from os.path import basename, dirname, isfile, join + +import torch +from torch._export.db.case import ( + _EXAMPLE_CASES, + _EXAMPLE_CONFLICT_CASES, + _EXAMPLE_REWRITE_CASES, + SupportLevel, + export_case, + ExportCase, +) + + +def _collect_examples(): + case_names = glob.glob(join(dirname(__file__), "*.py")) + case_names = [ + basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py") + ] + + case_fields = {f.name for f in dataclasses.fields(ExportCase)} + for case_name in case_names: + case = __import__(case_name, globals(), locals(), [], 1) + variables = [name for name in dir(case) if name in case_fields] + export_case(**{v: getattr(case, v) for v in variables})(case.model) + +_collect_examples() + +def all_examples(): + return _EXAMPLE_CASES + + +if len(_EXAMPLE_CONFLICT_CASES) > 0: + + def get_name(case): + model = case.model + if isinstance(model, torch.nn.Module): + model = type(model) + return model.__name__ + + msg = "Error on conflict export case name.\n" + for case_name, cases in _EXAMPLE_CONFLICT_CASES.items(): + msg += f"Case name {case_name} is associated with multiple cases:\n " + msg += f"[{','.join(map(get_name, cases))}]\n" + + raise RuntimeError(msg) + + +def filter_examples_by_support_level(support_level: SupportLevel): + return { + key: val + for key, val in all_examples().items() + if val.support_level == support_level + } + + +def get_rewrite_cases(case): + return _EXAMPLE_REWRITE_CASES.get(case.name, []) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07a6fdd56a4efa92a93af614bf3481cac92ea728 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654731331c67de888953db1dfaa44b23b912f0be Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eaa24c94a8f7386f447b74966a411d18e9f177a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e9f497144da72884464353a2f3f9284541a125d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d949eb1d94aa03db3b762ac904353d312c30524a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8938855b41f95e5bc00f313c7f12739ffe0af9d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..354dd0a4b252a70fd2a6e355c91b0076b9b5b58c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d5f14bc24f9d90550a302c744f524a234ccbc5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1ac07341d4354dbba8ceb90c5ef904f25a08f81 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6184f88421528d8d2f6520496225636923786e21 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5ffae2864fdb40a03337b1ee36b0a653f66826 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d927aa0f9a2b7fef662348217b69c725eee7329 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1690ccffdc1abe7cbdb68f4845e8630549c9900a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d581f2d11319cc54259ff2e84dcce96bf1e1255 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c396a17dc3504f7a2a2dff46c5861bad1d4c7ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df540f06682f1f04396494f685532d42224f959 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e7496dcda8cfa057c50890f7d0ee91afac752e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46d9c1ed3afba1b66d223ae7697a86af775071e3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..935d71d81e6b0781f7bf1396c1b60ac5c3a1d770 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f63fbd3d9edc16bd3b86356236ca6a74c1e7314 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a17df12a89b821b5ee069eb0df36681b75245d67 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbf916f2d5148889575937f0cf0d6a361734ff98 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..670c9646ec266cbb77888e4e03de5724744ce56c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbd77f0bce0157bf4dfed98968ec601c3280d9bd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4097d021db0bb21e7aa8b1ffbe88c7d5586ea4fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bbc5b3dab5f42afe5b10bbb45de4058069d52a3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b659f9c8c694bfb1ba2853e54967cf1451c5a9f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..029471a4876ee4df97f7cabf3d4db0ea45139c62 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..645941da6070a97465571c89ce956783dcc8d731 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77c54b062bdb47a796af96db8b99aaf6fb23cbb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30677bde44858d3acbeef740d94e84665dacdd52 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..240839e4a94bb08fc175ee4a92a1922c16db4b7d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79acf22b5d92476eee67df30c89c9dd052ad0dbb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ce3e5f990d8d789cbef655287a6dc796c12bac1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18185f58291b839dfbdc718e34a63539a462b9e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de28d5390e20c53a7e660adf370dc4924609a08c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cd15beba923261bfea7ff4d892ecbee90723ce4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/assume_constant_result.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/assume_constant_result.py new file mode 100644 index 0000000000000000000000000000000000000000..931ce7f7a50fc5a175101ac57c424c88cf31b54c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/assume_constant_result.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +import torch._dynamo as torchdynamo + + +class AssumeConstantResult(torch.nn.Module): + """ + Applying `assume_constant_result` decorator to burn make non-tracable code as constant. + """ + + @torchdynamo.assume_constant_result + def get_item(self, y): + return y.int().item() + + def forward(self, x, y): + return x[: self.get_item(y)] + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"torch.escape-hatch"} +model = AssumeConstantResult() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/autograd_function.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..efd645d13a7d5a13dc69d9ab3593772520b726c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/autograd_function.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + +class MyAutogradFunction(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, x): + return x.clone() + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return grad_output + 1 + +class AutogradFunction(torch.nn.Module): + """ + TorchDynamo does not keep track of backward() on autograd functions. We recommend to + use `allow_in_graph` to mitigate this problem. + """ + + def forward(self, x): + return MyAutogradFunction.apply(x) + +example_args = (torch.randn(3, 2),) +model = AutogradFunction() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/class_method.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..f701f54d4f4ea1cb5816292cd60bb4df3d03c5e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/class_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class ClassMethod(torch.nn.Module): + """ + Class methods are inlined during tracing. + """ + + @classmethod + def method(cls, x): + return x + 1 + + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 2) + + def forward(self, x): + x = self.linear(x) + return self.method(x) * self.__class__.method(x) * type(self).method(x) + +example_args = (torch.randn(3, 4),) +model = ClassMethod() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_class_method.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..22600cc504348d1d261b0ea2b9ed2d57af76b0a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_class_method.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + +class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self) -> None: + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchClassMethod() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nested_function.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b28ceeddc7956d136a8cf786c283344731d3e7ac --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nested_function.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNestedFunction(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using nested function in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + def true_fn(x): + def inner_true_fn(y): + return x + y + + return inner_true_fn(x) + + def false_fn(x): + def inner_false_fn(y): + return x - y + + return inner_false_fn(x) + + return cond(x.shape[0] < 10, true_fn, false_fn, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNestedFunction() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..50d0ec87a690d063cb0e841fc057a6ae95c369fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNonlocalVariables(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. + + The code below will not work because capturing closure variables is not supported. + ``` + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y + my_tensor_var + my_primitive_var + + def false_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y - my_tensor_var - my_primitive_var + + return cond(x.shape[0] > 5, true_fn, false_fn, [x]) + ``` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(x, y, z): + return x + y + z + + def false_fn(x, y, z): + return x - y - z + + return cond( + x.shape[0] > 5, + true_fn, + false_fn, + [x, my_tensor_var, torch.tensor(my_primitive_var)], + ) + +example_args = (torch.randn(6),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNonlocalVariables() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_closed_over_variable.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_closed_over_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..183180ab4fc825385170fea2bec6af184374a67e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_closed_over_variable.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondClosedOverVariable(torch.nn.Module): + """ + torch.cond() supports branches closed over arbitrary variables. + """ + + def forward(self, pred, x): + def true_fn(val): + return x * 2 + + def false_fn(val): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) + +example_args = (torch.tensor(True), torch.randn(3, 2)) +tags = {"torch.cond", "python.closure"} +model = CondClosedOverVariable() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_operands.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_operands.py new file mode 100644 index 0000000000000000000000000000000000000000..60a75d24639cdac991298e99acf96b8eadbff442 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_operands.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim + +x = torch.randn(3, 2) +y = torch.randn(2) +dim0_x = Dim("dim0_x") + +class CondOperands(torch.nn.Module): + """ + The operands passed to cond() must be: + - a list of tensors + - match arguments of `true_fn` and `false_fn` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x, y): + def true_fn(x, y): + return x + y + + def false_fn(x, y): + return x - y + + return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) + +example_args = (x, y) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +extra_inputs = (torch.randn(2, 2), torch.randn(2)) +dynamic_shapes = {"x": {0: dim0_x}, "y": None} +model = CondOperands() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_predicate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_predicate.py new file mode 100644 index 0000000000000000000000000000000000000000..68bb8850bba909a0c6546c8f12a1a3fa1bdc70d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/cond_predicate.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondPredicate(torch.nn.Module): + """ + The conditional statement (aka predicate) passed to cond() must be one of the following: + - torch.Tensor with a single element + - boolean expression + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + pred = x.dim() > 2 and x.shape[2] > 10 + + return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) + +example_args = (torch.randn(6, 4, 3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondPredicate() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_size_example.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_size_example.py new file mode 100644 index 0000000000000000000000000000000000000000..934746aaf6739de7a37077d8ec3c2776586a5657 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_size_example.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsSizeExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check APIs. + """ + + def forward(self, x): + a = x.item() + torch._check(a >= 0) + torch._check(a <= 5) + return torch.zeros((a, 5)) + + +example_args = (torch.tensor(4),) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsSizeExample() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_value_example.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_value_example.py new file mode 100644 index 0000000000000000000000000000000000000000..22f791a3e80474257c27d927bad56cf4c2fbce78 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_value_example.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsValueExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check API. + """ + + def forward(self, x, y): + a = x.item() + torch._check(a >= 0) + torch._check(a <= 5) + + if a < 6: + return y.sin() + return y.cos() + + +example_args = (torch.tensor(4), torch.randn(5, 5)) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsValueExample() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/decorator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..7d24cc681a6b62adf40bfd9a2e5283afb3515131 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/decorator.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import functools + +import torch + +def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + + return wrapper + +class Decorator(torch.nn.Module): + """ + Decorators calls are inlined into the exported function during tracing. + """ + + @test_decorator + def forward(self, x, y): + return x + y + +example_args = (torch.randn(3, 2), torch.randn(3, 2)) +model = Decorator() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dictionary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..49e688bc0ac1f09567e3b877aaca29a1d02b4121 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dictionary.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class Dictionary(torch.nn.Module): + """ + Dictionary structures are inlined and flattened along tracing. + """ + + def forward(self, x, y): + elements = {} + elements["x2"] = x * x + y = y * elements["x2"] + return {"y": y} + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"python.data-structure"} +model = Dictionary() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_assert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..cc822e5553e1ab8bd350a26966c22f1a9a1698cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_assert.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeAssert(torch.nn.Module): + """ + A basic usage of python assertion. + """ + + def forward(self, x): + # assertion with error message + assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" + # assertion without error message + assert x.shape[0] > 1 + return x + +example_args = (torch.randn(3, 2),) +tags = {"python.assert"} +model = DynamicShapeAssert() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..157e460274ad58ba71c886b35364ddc0cd4d886a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeConstructor(torch.nn.Module): + """ + Tensor constructors should be captured with dynamic shape inputs rather + than being baked in with static shape. + """ + + def forward(self, x): + return torch.zeros(x.shape[0] * 2) + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeConstructor() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..21824ef3a0f66eb25f4d8e8c1ba92f53fdd4c275 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeIfGuard(torch.nn.Module): + """ + `if` statement with backed dynamic shape predicate will be specialized into + one particular branch and generate a guard. However, export will fail if the + the dimension is marked as dynamic shape from higher level API. + """ + + def forward(self, x): + if x.shape[0] == 3: + return x.cos() + + return x.sin() + +example_args = (torch.randn(3, 2, 2),) +tags = {"torch.dynamic-shape", "python.control-flow"} +model = DynamicShapeIfGuard() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_map.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_map.py new file mode 100644 index 0000000000000000000000000000000000000000..f8066aed556b9ee588b9744d17ba16c35d8fed6c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_map.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import map + +class DynamicShapeMap(torch.nn.Module): + """ + functorch map() maps a function over the first tensor dimension. + """ + + def forward(self, xs, y): + def body(x, y): + return x + y + + return map(body, xs, y) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"torch.dynamic-shape", "torch.map"} +model = DynamicShapeMap() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_round.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_round.py new file mode 100644 index 0000000000000000000000000000000000000000..decbf036553cb76544a19e531e5aee98d792ae0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_round.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import torch + +from torch._export.db.case import SupportLevel +from torch.export import Dim + +class DynamicShapeRound(torch.nn.Module): + """ + Calling round on dynamic shapes is not supported. + """ + + def forward(self, x): + return x[: round(x.shape[0] / 2)] + +x = torch.randn(3, 2) +dim0_x = Dim("dim0_x") +example_args = (x,) +tags = {"torch.dynamic-shape", "python.builtin"} +support_level = SupportLevel.NOT_SUPPORTED_YET +dynamic_shapes = {"x": {0: dim0_x}} +model = DynamicShapeRound() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..360fe15f6f98d9d735366bfa53371d79e0b00209 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeSlicing(torch.nn.Module): + """ + Slices with dynamic shape arguments should be captured into the graph + rather than being baked in. + """ + + def forward(self, x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeSlicing() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_view.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_view.py new file mode 100644 index 0000000000000000000000000000000000000000..c45d4aeebb0282a0f56c58a587b4bfe1655f50e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_view.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeView(torch.nn.Module): + """ + Dynamic shapes should be propagated to view arguments instead of being + baked into the exported graph. + """ + + def forward(self, x): + new_x_shape = x.size()[:-1] + (2, 5) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1) + +example_args = (torch.randn(10, 10),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeView() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/fn_with_kwargs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/fn_with_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..46b2637b398c21bf9399d0a3fa2a964354beea3e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/fn_with_kwargs.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + +class FnWithKwargs(torch.nn.Module): + """ + Keyword arguments are not supported at the moment. + """ + + def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): + out = pos0 + for arg in tuple0: + out = out * arg + for arg in myargs: + out = out * arg + out = out * mykw0 + out = out * mykwargs["input0"] * mykwargs["input1"] + return out + +example_args = ( + torch.randn(4), + (torch.randn(4), torch.randn(4)), + *[torch.randn(4), torch.randn(4)] +) +example_kwargs = { + "mykw0": torch.randn(4), + "input0": torch.randn(4), + "input1": torch.randn(4), +} +tags = {"python.data-structure"} +model = FnWithKwargs() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/list_contains.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/list_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..35a140f4ee2e5d6f42c3509984333db896f1c081 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/list_contains.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class ListContains(torch.nn.Module): + """ + List containment relation can be checked on a dynamic shape or constants. + """ + + def forward(self, x): + assert x.size(-1) in [6, 2] + assert x.size(0) not in [4, 5, 6] + assert "monkey" not in ["cow", "pig"] + return x + x + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"} +model = ListContains() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/list_unpack.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/list_unpack.py new file mode 100644 index 0000000000000000000000000000000000000000..98533cfab5498934a61fbe693bb2497d5dbc9738 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/list_unpack.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs + +import torch + +class ListUnpack(torch.nn.Module): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + + def forward(self, args: list[torch.Tensor]): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + x, *y = args + return x + y[0] + +example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],) +tags = {"python.control-flow", "python.data-structure"} +model = ListUnpack() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/model_attr_mutation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/model_attr_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..122b0ddfc3429fb31415a146e8e1dcfddb2fe031 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/model_attr_mutation.py @@ -0,0 +1,24 @@ +# mypy: allow-untyped-defs +import torch + + +class ModelAttrMutation(torch.nn.Module): + """ + Attribute mutation raises a warning. Covered in the test_export.py test_detect_leak_strict test. + """ + + def __init__(self) -> None: + super().__init__() + self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)] + + def recreate_list(self): + return [torch.zeros(3, 2), torch.zeros(3, 2)] + + def forward(self, x): + self.attr_list = self.recreate_list() + return x.sum() + self.attr_list[0].sum() + + +example_args = (torch.randn(3, 2),) +tags = {"python.object-model"} +model = ModelAttrMutation() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/nested_function.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..e4076ac14dada40b4d78812666a9ec6b5e67045b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/nested_function.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +class NestedFunction(torch.nn.Module): + """ + Nested functions are traced through. Side effects on global captures + are not supported though. + """ + + def forward(self, a, b): + x = a + b + z = a - b + + def closure(y): + nonlocal x + x += 1 + return x * y + z + + return closure(x) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"python.closure"} +model = NestedFunction() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/null_context_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/null_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..80d09f68097edbe676077be183711dabe5cbc664 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/null_context_manager.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + +class NullContextManager(torch.nn.Module): + """ + Null context manager in Python will be traced out. + """ + + def forward(self, x): + """ + Null context manager in Python will be traced out. + """ + ctx = contextlib.nullcontext() + with ctx: + return x.sin() + x.cos() + +example_args = (torch.randn(3, 2),) +tags = {"python.context-manager"} +model = NullContextManager() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/optional_input.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/optional_input.py new file mode 100644 index 0000000000000000000000000000000000000000..41e66a7c977a83bf59116864c7f443387277f06e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/optional_input.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class OptionalInput(torch.nn.Module): + """ + Tracing through optional input is not supported yet + """ + + def forward(self, x, y=torch.randn(2, 3)): + if y is not None: + return x + y + return x + + +example_args = (torch.randn(2, 3),) +tags = {"python.object-model"} +support_level = SupportLevel.SUPPORTED +model = OptionalInput() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/pytree_flatten.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/pytree_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..804e73c5a6d58ad5b5be179bf67a5d5bc38c2e2b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/pytree_flatten.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +from torch.utils import _pytree as pytree + +class PytreeFlatten(torch.nn.Module): + """ + Pytree from PyTorch can be captured by TorchDynamo. + """ + + def forward(self, x): + y, _spec = pytree.tree_flatten(x) + return y[0] + 1 + +example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), +model = PytreeFlatten() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/scalar_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/scalar_output.py new file mode 100644 index 0000000000000000000000000000000000000000..86d3b4645330c47c3625736b695d635f4ab58c70 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/scalar_output.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim + +x = torch.randn(3, 2) +dim1_x = Dim("dim1_x") + +class ScalarOutput(torch.nn.Module): + """ + Returning scalar values from the graph is supported, in addition to Tensor + outputs. Symbolic shapes are captured and rank is specialized. + """ + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x.shape[1] + 1 + +example_args = (x,) +tags = {"torch.dynamic-shape"} +dynamic_shapes = {"x": {1: dim1_x}} +model = ScalarOutput() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/specialized_attribute.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/specialized_attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..f17092f9afc681b91a982a8a2479ac1dde4f455d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/specialized_attribute.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +from enum import Enum + +import torch + +class Animal(Enum): + COW = "moo" + +class SpecializedAttribute(torch.nn.Module): + """ + Model attributes are specialized. + """ + + def __init__(self) -> None: + super().__init__() + self.a = "moo" + self.b = 4 + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + self.b + else: + raise ValueError("bad") + +example_args = (torch.randn(3, 2),) +model = SpecializedAttribute() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/static_for_loop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/static_for_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..aa62b86d16d9b6a1c539976a891f58bd732ae30d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/static_for_loop.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +class StaticForLoop(torch.nn.Module): + """ + A for loop with constant number of iterations should be unrolled in the exported graph. + """ + + def forward(self, x): + # constant + ret = [i + x for i in range(10)] + return ret + +example_args = (torch.randn(3, 2),) +tags = {"python.control-flow"} +model = StaticForLoop() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/static_if.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/static_if.py new file mode 100644 index 0000000000000000000000000000000000000000..f169380159a45489142ce5ae3523b2e4504c6721 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/static_if.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class StaticIf(torch.nn.Module): + """ + `if` statement with static predicate value should be traced through with the + taken branch. + """ + + def forward(self, x): + if len(x.shape) == 3: + return x + torch.ones(1, 1, 1) + + return x + +example_args = (torch.randn(3, 2, 2),) +tags = {"python.control-flow"} +model = StaticIf() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/tensor_setattr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/tensor_setattr.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbc263e7ff2240a3cf8618c56f152e744aa40e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/tensor_setattr.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + + +class TensorSetattr(torch.nn.Module): + """ + setattr() call onto tensors is not supported. + """ + def forward(self, x, attr): + setattr(x, attr, torch.randn(3, 2)) + return x + 4 + +example_args = (torch.randn(3, 2), "attr") +tags = {"python.builtin"} +model = TensorSetattr() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/type_reflection_method.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/type_reflection_method.py new file mode 100644 index 0000000000000000000000000000000000000000..99ad42a153c512d65aaae1bcac2377ee1e124f25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/type_reflection_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class A: + @classmethod + def func(cls, x): + return 1 + x + +class TypeReflectionMethod(torch.nn.Module): + """ + type() calls on custom objects followed by attribute accesses are not allowed + due to its overly dynamic nature. + """ + + def forward(self, x): + a = A() + return type(a).func(x) + + +example_args = (torch.randn(3, 4),) +tags = {"python.builtin"} +model = TypeReflectionMethod() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/unsupported_operator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/unsupported_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a52d80b895b3b2c2d85b878ca4efea511e73ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/unsupported_operator.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class TorchSymMin(torch.nn.Module): + """ + torch.sym_min operator is not supported in export. + """ + + def forward(self, x): + return x.sum() + torch.sym_min(x.size(0), 100) + + +example_args = (torch.randn(3, 2),) +tags = {"torch.operator"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = TorchSymMin() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/user_input_mutation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/user_input_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..3156b3a1bf2ec6f6361395de3dacb098ecf20c3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/examples/user_input_mutation.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + + +class UserInputMutation(torch.nn.Module): + """ + Directly mutate user input in forward + """ + + def forward(self, x): + x.mul_(2) + return x.cos() + + +example_args = (torch.randn(3, 2),) +tags = {"torch.mutation"} +model = UserInputMutation() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/gen_example.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/gen_example.py new file mode 100644 index 0000000000000000000000000000000000000000..8e44cade322bdde858c5dd05ac116cef47202a33 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/gen_example.py @@ -0,0 +1,21 @@ +import os +import sys + +import torch._export.db.examples as examples + +TEMPLATE = '''import torch + +def {case_name}(x): + """ + """ + + return +''' + +if __name__ == "__main__": + assert len(sys.argv) == 2 + root_dir = examples.__name__.replace(".", "/") + assert os.path.exists(root_dir) + with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f: + print("Writing to", f.name, "...") + f.write(TEMPLATE.format(case_name=sys.argv[1])) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/logging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..9d18a5c0ea08e86095a44240657034ffff3135d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/db/logging.py @@ -0,0 +1,47 @@ +from typing import Optional + +def exportdb_error_message(case_name: str) -> str: + from .examples import all_examples + from torch._utils_internal import log_export_usage + + ALL_EXAMPLES = all_examples() + # Detect whether case_name is really registered in exportdb. + if case_name in ALL_EXAMPLES: + url_case_name = case_name.replace("_", "-") + return f"See {case_name} in exportdb for unsupported case. \ + https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}" + else: + log_export_usage( + event="export.error.casenotregistered", + message=case_name, + ) + return f"{case_name} is unsupported." + + +def get_class_if_classified_error(e: Exception) -> Optional[str]: + """ + Returns a string case name if the export error e is classified. + Returns None otherwise. + """ + + from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError + + ALWAYS_CLASSIFIED = "always_classified" + DEFAULT_CLASS_SIGIL = "case_name" + + # add error types that should be classified, along with any attribute name + # whose presence acts like a sigil to further distinguish which errors of + # that type should be classified. If the attribute name is None, then the + # error type is always classified. + _ALLOW_LIST = { + Unsupported: DEFAULT_CLASS_SIGIL, + UserError: DEFAULT_CLASS_SIGIL, + TorchRuntimeError: None, + } + if type(e) in _ALLOW_LIST: + # pyrefly: ignore [index-error] + attr_name = _ALLOW_LIST[type(e)] + if attr_name is None: + return ALWAYS_CLASSIFIED + return getattr(e, attr_name, None) + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c20fde7baa76c3db59f20fab008be7406518f41d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fdd8dfec2de9c3bdd7a201cd15859a2ee4af38c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..260f051fc8ebc5c5752ccf2c45f1fc43043853f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..9874dc1520fdbd6f4adc061dd7bccee031710797 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py @@ -0,0 +1,32 @@ +from typing import Any + + +NodeMetadataValue = Any + + +PROTECTED_KEYS: set[str] = { + "val", + "stack_trace", + "nn_module_stack", + "debug_handle", + "tensor_meta", +} + + +class NodeMetadata: + def __init__(self, data: dict[str, Any]) -> None: + self.data: dict[str, Any] = data.copy() + + def __getitem__(self, key: str) -> NodeMetadataValue: + return self.data[key] + + def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue: + if key in PROTECTED_KEYS: + raise RuntimeError(f"Could not override node key: {key}") + self.data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self.data + + def copy(self) -> "NodeMetadata": + return NodeMetadata(self.data.copy()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py new file mode 100644 index 0000000000000000000000000000000000000000..40613c1283228bb5500a93c5b4ca80d6a448ce6d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py @@ -0,0 +1,45 @@ +# pyre-strict +from collections.abc import Iterable, Iterator +from typing import Generic, TypeVar, Union + +import torch + + +_T = TypeVar("_T") + + +class ProxyValue(Generic[_T]): + # pyre-ignore + def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]): + # pyre-ignore + self.data = data + self.proxy_or_node = proxy + + @property + def node(self) -> torch.fx.Node: + if isinstance(self.proxy_or_node, torch.fx.Node): + return self.proxy_or_node + assert isinstance(self.proxy_or_node, torch.fx.Proxy) + return self.proxy_or_node.node + + @property + def proxy(self) -> torch.fx.Proxy: + if not isinstance(self.proxy_or_node, torch.fx.Proxy): + raise RuntimeError( + f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}" + ) + return self.proxy_or_node + + def to_tensor(self) -> torch.Tensor: + assert isinstance(self.data, torch.Tensor) + return self.data + + def is_tensor(self) -> bool: + return isinstance(self.data, torch.Tensor) + + # pyre-ignore + def __iter__(self) -> Iterator[_T]: + yield from self.data + + def __bool__(self) -> bool: + return bool(self.data) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9ce2ac03c23600c86ff02e38a2a4bfeefef9e2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__init__.py @@ -0,0 +1 @@ +from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e96e5e02a3de4a181b60139e60f4e17d6590a62e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b6bd181bd7e7b03a6481d5d612c08cee478c790 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecc2d3cd7839cb87d0ad2f1b6e80bf073a1fa2ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9893435266f578558509c5b917b26cbb8413135f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1dbf21520c84f694b087c3fc9e95ea1d660e65f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..885630c7e7b48b143b4c22d9990394b8e3610663 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b7d84b7a460a12416951f30bb80e3a3dab2ddfb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdd12e95ae19b1fda00cfebeab897aa6d6513c65 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a85d70d8815022825ac2ec7292e7293cc93994a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e745d334ae623647721674376ca731ceae87e469 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa9c8a66035a8bed531c3e4d28800475cae4e85 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae50d1a73ac161e8a57f658faee5514549a6d3b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5bdbe2f8220f75514e277565895b3ef0cb69989 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3e14a9008691d5ebf630c5e6e6246db3010d2b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3547a5f73c77485f7cd63f89ecbd13ef8c642e98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.graph_module import GraphModule + + +_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" + + +def _node_metadata_hook( + node: torch.fx.Node, + metadata: Optional[dict[str, Any]] = None, + fake_mode: Optional[FakeTensorMode] = None, +) -> None: + """ + Hook for adding the appropriate metadata to nodes that are created during a + pass using graph.create_node. An example of how to use it: + + ``` + with _set_node_metadata_hook(gm, + functools.partial(_node_metadata_hook, metadata={"stack_trace": "file"}) + ): + pass(gm) + ``` + + This hook should not work for all generic cases -- specifically it assumes + that nodes being added are only call_function nodes, and copies over the + first argument node's nn_module_stack. + """ + # pyrefly: ignore [bad-assignment] + fake_mode = fake_mode or contextlib.nullcontext() + + assert node.op == "call_function" and callable(node.target), ( + f"node: {node}, target: {node.target}" + ) + + if ( + isinstance(node.target, torch._ops.OpOverload) + and len(node.target._schema.returns) == 0 + ): + node.meta["val"] = None + else: + fake_args, fake_kwargs = pytree.tree_map_only( + torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs) + ) + # pyrefly: ignore [bad-context-manager] + with fake_mode, enable_python_dispatcher(): + fake_res = node.target(*fake_args, **fake_kwargs) + node.meta["val"] = fake_res + + if metadata is not None: + for k, v in metadata.items(): + node.meta[k] = v + + # Copy over metadata from argument nodes + arg_meta = [ + arg.meta + for arg in pytree.tree_flatten((node.args, node.kwargs))[0] + if isinstance(arg, torch.fx.Node) + ] + if len(arg_meta) == 0: + return + arg_meta = arg_meta[0] + + node.meta["nn_module_stack"] = node.meta.get( + "nn_module_stack", + arg_meta.get( + "nn_module_stack", + { + _EMPTY_NN_MODULE_STACK_KEY: ( + _EMPTY_NN_MODULE_STACK_KEY, + _EMPTY_NN_MODULE_STACK_KEY, + ) + }, + ), + ) + + node.meta["torch_fn"] = node.meta.get( + "torch_fn", + ( + f"{node.target.__name__}_0", + # pyrefly: ignore [missing-attribute] + f"{node.target.__class__.__name__}.{node.target.__name__}", + ), + ) + + +@contextlib.contextmanager +def _set_node_metadata_hook(gm: torch.fx.GraphModule, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "node_metadata_hook must be a callable." + + # Add the hook to all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._register_create_node_hook(f) + try: + yield + finally: + # Restore hook for all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._unregister_create_node_hook(f) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..345401e9f76e5e82d462f3a5c56a30bb3e1f5e8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +import math +import operator +import traceback +from functools import partial +from typing import NamedTuple, TYPE_CHECKING + +import sympy + +import torch +import torch.fx +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + + +if TYPE_CHECKING: + from collections.abc import Callable + + +__all__ = ["InputDim"] + + +class InputDim(NamedTuple): + input_name: str + dim: int + + +def _convert_to_int(val): + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return math.inf + if val in (-sympy.oo, -int_oo): + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError("Export constraints cannot be non-integer expressions") + + +def _convert_range_to_int(range: ValueRanges): + assert isinstance(range, ValueRanges) + min_val = _convert_to_int(range.lower) + max_val = _convert_to_int(range.upper) + return min_val, max_val + + +class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): + def __init__( + self, + range_constraints: dict[sympy.Symbol, ValueRanges], + ): + super().__init__() + self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set() + self.counter = 0 + + def _assert_range_constraint(self, node, lower, upper, assert_msg): + last_node = node + if lower > -math.inf: + last_node = self._insert_assert_async( + last_node, operator.ge, node, lower, assert_msg + ) + + if upper < math.inf: + last_node = self._insert_assert_async( + last_node, operator.le, node, upper, assert_msg + ) + + def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. + """ + self.counter += 1 + graph = last_node.graph + with graph.inserting_after(last_node): + cmp = graph.call_function(op, (lower, upper), {}) + with graph.inserting_after(cmp): + cmp_tensor = graph.call_function( + torch.ops.aten.scalar_tensor.default, (cmp,), {} + ) + with graph.inserting_after(cmp_tensor): + assert_async = graph.call_function( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + ) + return assert_async + + def call(self, graph_module) -> PassResult: + self.existing_inline_assertions = _get_existing_inline_assertions( + graph_module, self.range_constraints + ) + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if "val" not in node.meta: + continue + + val = node.meta["val"] + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. + # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requires the proxy for ret[1], etc. + + def add_assertions(val): + call_backs: list[Callable] = [] + messages: list[str] = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node.expr + if symbol in self.existing_inline_assertions: + return call_backs, messages + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols( + symbol + ): + if symbol in self._asserts_generated_unbacked_symbols: + return call_backs, messages + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'u' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." + call_backs.append( + partial( + self._assert_range_constraint, + lower=min_val, + upper=max_val, + ) + ) + messages.append(assert_msg) + self._asserts_generated_unbacked_symbols.add(symbol) + + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + + def sym_size_cb(node, assert_msg, dim): + with node.graph.inserting_after(node): + dim_node = module.graph.call_function( + torch.ops.aten.sym_size.int, + (node, dim), + {}, + ) + cb(node=dim_node, assert_msg=assert_msg) + + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(node=node, assert_msg=f"{node}" + msg) + + module.recompile() + + # Sometimes this pass would return a wrong graph where we have mismatched + # node names in signature. Before we fix it, let's just skip it. + if ( + self.counter == 0 + and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass + ): + return PassResult(graph_module, False) + + # Populate the stack trace with dummy vals to respect IR + for node in graph_module.graph.nodes: + if not node.meta.get("stack_trace", None) and node.op not in [ + "placeholder", + "output", + ]: + node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) + return PassResult(graph_module, True) + + +def _get_existing_inline_assertions( + graph_module: torch.fx.GraphModule, + range_constraints: dict[sympy.Symbol, ValueRanges], +) -> dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {} + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + # Find all the existing inline assertions. They will look something like: + # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) + # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) + # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {}) + for node in module.graph.nodes: + if node.target != torch.ops.aten._assert_scalar.default: + continue + + compare_arg = node.args[0] + if not ( + isinstance(compare_arg, torch.fx.Node) + and compare_arg.op == "call_function" + and compare_arg.target in (operator.le, operator.ge) + and len(compare_arg.args) == 2 + ): + continue + + compare_op = compare_arg.target + lhs, rhs = compare_arg.args + + def maybe_get_symint(x): + if ( + isinstance(x, torch.fx.Node) + and "val" in x.meta + and isinstance(x.meta["val"], torch.SymInt) + ): + return x.meta["val"].node.expr + return x + + lhs = maybe_get_symint(lhs) + rhs = maybe_get_symint(rhs) + + if compare_op is operator.ge: + lhs, rhs = rhs, lhs + + if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int): + symint = lhs + scalar = rhs + elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int): + symint = rhs + scalar = lhs + else: + continue + + if symint not in range_constraints: + raise RuntimeError( + f"Unable to find symint {symint} in {range_constraints}" + ) + + previous_range = existing_inline_assertions.get( + symint, ValueRanges(-math.inf, math.inf) + ) + + if symint is lhs: + bounds = ValueRanges(-math.inf, scalar) + else: + bounds = ValueRanges(scalar, math.inf) + existing_inline_assertions[symint] = previous_range & bounds + + return existing_inline_assertions diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a82564886889deabfc758d61e32289ab7843a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py @@ -0,0 +1,146 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +from typing import TYPE_CHECKING + +import torch +from torch.export.exported_program import ConstantArgument, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +if TYPE_CHECKING: + from torch.export.exported_program import ModuleCallSignature + from torch.export.graph_signature import ExportGraphSignature + + +__all__ = ["CollectTracepointsPass"] + + +class CollectTracepointsPass(PassBase): + """ + Performs constant folding and constant propagation. + """ + + def __init__( + self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature + ) -> None: + super().__init__() + self.specs = specs + self.sig = sig + + def call(self, gm: torch.fx.GraphModule) -> PassResult | None: + def get_arg_spec(arg) -> TensorArgument | ConstantArgument: + if isinstance(arg, torch.fx.Node): + if isinstance(arg.meta.get("val"), torch.Tensor): + return TensorArgument(name=arg.name) + else: + raise AssertionError( + "Symint input is not implemented yet for submodule call signature." + ) + else: + return ConstantArgument(name="", value=arg) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + nn_module_stack = None + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target is torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_outputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_inputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + nn_module_stack = None + for node in reversed(module.graph.nodes): + if node.op != "call_function": + continue + if node.target is torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_outputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + + def copy_sig(sig) -> ModuleCallSignature: + from torch.export.exported_program import ModuleCallSignature + + return ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=sig.in_spec, + out_spec=sig.out_spec, + forward_arg_names=None, + ) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target is torch.ops.higher_order._export_tracepoint: + # There's some subtlety worth noting. Here fqn corresponds to + # the call name, whereas path corresponds to the module name. + # They are not necessarily the same! When a submodule is shared + # through different aliases, there are as many _export_tracepoint + # markers as there are aliases, since the shared submodule is + # wrapped once for each alias. + path = node.kwargs["path"] + fqn, _ = next(reversed(node.meta["nn_module_stack"].values())) + + module_key = next(reversed(node.meta["nn_module_stack"])) + if "@" in module_key: + suffix = module_key.split("@")[-1] + path = f"{path}@{suffix}" + + call_fqn = f"{fqn}@{suffix}" + if call_fqn not in self.specs: + self.specs[call_fqn] = copy_sig(self.specs[fqn]) + fqn = call_fqn + + kind = node.kwargs["kind"] + for i, arg in enumerate(node.args): + # We only update the signature of the alias used to call + # the submodule. Otherwise the signatures of all aliases + # would get conflated; the inputs/outputs of every call + # would be recorded in every other call as well. + if fqn == path: + if kind == "module_call_inputs": + self.specs[path].inputs.append(get_arg_spec(arg)) + elif kind == "module_call_outputs": + self.specs[path].outputs.append(get_arg_spec(arg)) + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + if isinstance(arg, torch.fx.Node): + for user in node.users: + assert user.op == "call_function" + assert user.target is operator.getitem + assert isinstance(user.args[1], int) + if user.args[1] == i: + user.replace_all_uses_with(arg) + self.sig.replace_all_uses(user.name, arg.name) + break + users = list(node.users) + for user in users: + assert len(user.users) == 0 + gm.graph.erase_node(user) + gm.graph.erase_node(node) + return PassResult(gm, True) + + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..58534856422c73b20fc85877c8d13ea88532aa45 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py @@ -0,0 +1,304 @@ +# mypy: allow-untyped-defs +import collections +from collections import defaultdict +from collections.abc import Callable +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. +# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + ): + super().__init__(gm) + self.node_replacements: dict[torch.fx.Node, Any] = {} + self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.Node) -> bool: + if ( + node.target is torch.ops.prims.convert_element_type.default + and node.args[0].op == "get_attr" # type: ignore[union-attr] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.pt2e_quant.dequantize_affine, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr] + + for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if node.target == "output": + continue + + def add_use(inp): + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. + if any( + type(self.unknown_value) is type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target is aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self): # type: ignore[override] + env = {} + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + env[n] = self.unknown_value + return super().run(initial_env=env) + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +): + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + # Get all attr users by looking up the graph instead from node.users, because in this case + # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. + + # opcode name target args kwargs + # ------------- ------------------- ---------------- --------------------------- -------- + # placeholder arg0_1 arg0 () {} + # get_attr _tensor_constant0 state () {} + # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} + # get_attr _tensor_constant0_1 state () {} + # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} + # output output output ([add],) {} + + get_attr_node_users = defaultdict(list) + for node in gm.graph.nodes: + if node.op == "get_attr": + get_attr_node_users[node.target].extend(node.users.keys()) + for node in gm.graph.find_nodes(op="get_attr"): + if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag(gm: torch.fx.GraphModule) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op == "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.find_nodes(op="get_attr"): + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..45dd734c72959cd23c00d88e18dbcf80b8cd3227 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -0,0 +1,99 @@ +import copy +from typing import Optional + +import torch +from torch._export.pass_base import ( + _ExportPassBaseDeprecatedDoNotUse, + Argument, + PassResult, +) +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._ops import OpOverload + + +aten = torch.ops.aten + +_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = { + aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default, + aten._assert_async.msg: aten._functional_assert_async.msg, +} + + +class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Functionalize ops with side effect in graph module by replacing the op with + functional version of it. A new dependency token (`dep_token`) will be + created and propagated through functional ops to output. + For example: + ``` + def f(x): + sym_constrain_range(x.shape[0], min=1, max=3) + return x.add(3) + ``` + Will be transformed to: + ``` + def f(x): + dep_token0 = _make_dep_token() + dep_token1 = _functional_sym_constrain_range( + x.shape[0], min=1, max=3, dep_token=dep_token0 + ) + + return x.add(3), dep_token1 + ``` + """ + + def __init__(self) -> None: + super().__init__() + self._dep_token: Optional[ProxyValue] = None + self._next_dep_token_index: Optional[int] = None + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Early return if no non-functional assertions. + if not any( + n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS + for n in graph_module.graph.nodes + ): + return PassResult(graph_module=graph_module, modified=False) + + gm = copy.deepcopy(graph_module) + self._dep_token = None + self._next_dep_token_index = None + return super().call(gm) + + def call_operator( + self, + op: OpOverload, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: + return super().call_operator(op, args, kwargs, meta) + + if self._dep_token is None: + self._dep_token = super().call_operator( + aten._make_dep_token, + args=(), + kwargs={}, + meta=self._create_dummy_node_metadata(), + ) + self._dep_token.node.name = "dep_token0" + self._next_dep_token_index = 1 + + self._dep_token = super().call_operator( + _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op], + args=args, + kwargs={**kwargs, "dep_token": self._dep_token}, + meta=meta, + ) + assert self._next_dep_token_index is not None + self._dep_token.node.name = f"dep_token{self._next_dep_token_index}" + self._next_dep_token_index += 1 + + return self._dep_token + + def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: + assert self._dep_token is not None + + return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1e5fb6a9d7fb47ed6d2a9164313b04bbab37c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py @@ -0,0 +1,80 @@ +import functools +from collections import defaultdict + +import torch +from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, +) +from torch._library.fake_profile import OpProfile, TensorMetadata + + +def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> None: + """ + This is used by draft_export to insert guards in front of calls to custom + operators which have a generated fake kernel. + """ + for node in gm.graph.nodes: + if node.op == "call_function" and str(node.target) in ops_to_guard: + with ( + _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + metadata={"stack_trace": node.meta.get("stack_trace")}, + ), + ), + gm.graph.inserting_before(node), + ): + for arg in (*node.args, *node.kwargs.values()): + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta.get("val"), torch.Tensor + ): + val = arg.meta["val"] + gm.graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(arg,), + kwargs={ + "dtype": val.dtype, + "device": val.device, + "layout": val.layout, + }, + ) + + gm.recompile() + + +def get_op_profiles( + gm: torch.fx.GraphModule, ops_to_guard: set[str] +) -> dict[str, set[OpProfile]]: + """ + This is used by draft_export to get a list of custom operator profiles so + that we can generate fake kernels. + """ + + def _get_op_profile(node: torch.fx.Node) -> OpProfile: + args_profile = tuple( + TensorMetadata.maybe_from_tensor(arg.meta.get("val")) + if isinstance(arg, torch.fx.Node) + else None + for arg in (*node.args, *node.kwargs.values()) + ) + + out_profile = None + meta = node.meta.get("val") + assert meta is not None + if isinstance(meta, torch.Tensor): + out_profile = TensorMetadata.maybe_from_tensor(meta) + elif isinstance(meta, (list, tuple)): + out_profile = tuple(TensorMetadata.maybe_from_tensor(m) for m in meta) # type: ignore[assignment] + assert out_profile is not None + + return OpProfile(args_profile, out_profile) # type: ignore[arg-type] + + op_profiles: dict[str, set[OpProfile]] = defaultdict(set) + + for node in gm.graph.nodes: + if node.op == "call_function" and str(node.target) in ops_to_guard: + op_profiles[str(node.target)].add(_get_op_profile(node)) + + return op_profiles diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..607989cd919cbb6d4cf59aab3071a9f7c5b5375f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py @@ -0,0 +1,417 @@ +# mypy: allow-untyped-defs +import collections +import logging +from typing import Any, Optional, Union + +import torch +from torch._export.verifier import SpecViolationError +from torch._guards import detect_fake_mode +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_reference_type +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.export.exported_program import ( + ArgumentSpec, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + TensorArgument, +) +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.graph_module import _get_attr + + +log = logging.getLogger(__name__) + + +class ConstantAttrMap(collections.abc.MutableMapping): + """A mapping class that understands how to use module constants (tensors, + ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, + but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to + the same underlying value (but we guarantee that they will `hash()` to the same value + if that's the case). + """ + + def __init__(self) -> None: + # Underlying dict that we use to implement this mapping. + self._constant_attrs: dict[ + Union[int, torch.Tensor, FakeScriptObject, torch.utils._pytree.TreeSpec], + list[Any], + ] = {} + # Map from the hash(ScriptObject) to the ScriptObject itself. Used for + # APIs like `__iter__` that should look like they're returning the + # original ScriptObjects. + self._script_object_map: dict[int, torch.ScriptObject] = {} + + def __getitem__(self, key: _ConstantAttributeType) -> Any: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject)) + return self._constant_attrs[real_key] + + def __setitem__(self, key: _ConstantAttributeType, value): + # we shouldn't actually call this, should go to add() instead to handle aliasing + raise NotImplementedError( + """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead. +The same key can be mapped to multiple values, for handling constant aliasing.""" + ) + + def add(self, key: _ConstantAttributeType, value: Any) -> None: + if isinstance(key, torch.ScriptObject): + if hash(key) not in self._constant_attrs: + self._constant_attrs[hash(key)] = [] + self._constant_attrs[hash(key)].append(value) + self._script_object_map[hash(key)] = key + elif isinstance(key, (torch.Tensor, FakeScriptObject)): + if key not in self._constant_attrs: + self._constant_attrs[key] = [] + self._constant_attrs[key].append(value) + else: + raise TypeError( + f"Expected key to be a tensor or ScriptObject, got {type(key)}" + ) + + def __delitem__(self, key: _ConstantAttributeType): + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + + del self._constant_attrs[real_key] + + def __iter__(self): + for key in self._constant_attrs: + if isinstance(key, int): + yield self._script_object_map[key] + else: + yield key + + def __len__(self): + return len(self._constant_attrs) + + def __contains__(self, key: object) -> bool: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + return real_key in self._constant_attrs + + +def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: + # The FQN of the constant tensor in the state dict should + # correspond to the module where the constant tensor was + # originally used. + if len(node.meta["nn_module_stack"]) == 0: + return constant_name + parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0] + if len(parent_fqn) > 0: + return f"{parent_fqn}.{constant_name}" + else: + return constant_name + + +def _get_first_fqn( + const_attrs: ConstantAttrMap, + key: _ConstantAttributeType, +) -> Any: + fqns = const_attrs.get(key) + return fqns[0] if fqns else None + + +def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]: + """ + If there is a tensor constant created while tracing, here is how the graph + looks like: + + %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0] + %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,)) + %detach_ : [num_users=?] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,)) + + To check to see if the tensor constant is being used, we want to traverse to + the detach node to see if it's actually being used. + + This function returns None if this constant is being used, otherwise it returns the + lift_fresh and detach node to be removed later. + """ # noqa: B950 + if len(node.users) > 1: + return None + + lift_fresh_node = next(iter(node.users.keys())) + if not ( + lift_fresh_node.op == "call_function" + and lift_fresh_node.target + in ( + torch.ops.aten.lift_fresh.default, + torch.ops.aten.lift_fresh_copy.default, + ) + ): + return None + + if len(lift_fresh_node.users) > 1: + return None + + # Case 1: lift node is not used anywhere + if len(lift_fresh_node.users) == 0: + return [lift_fresh_node, node] + + detach_node = next(iter(lift_fresh_node.users.keys())) + if not ( + detach_node.op == "call_function" + and detach_node.target + in ( + torch.ops.aten.detach_.default, + torch.ops.aten.detach.default, + ) + ): + return None + + if len(detach_node.users) > 0: + return None + else: + # Case 2: Lift node's child is not used anywhere + return [detach_node, lift_fresh_node, node] + + +def lift_constants_pass( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> dict[str, _ConstantAttributeType]: + """ + Takes a graph module, graph signature, and modifies them inplace to lift any + constants (tensors or custom classes) as inputs to the graph. Returns a + dictionary of names to constants. + + Arguments: + gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift. + graph_signature (ExportGraphSignature): This graph signature will be + mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs. + constant_attrs (ConstantAttr): A mapping from a constant value to its + fully-qualified path in `gm`. This is used to maintain consistent + location of constants between the original module and the exported + version. + + Returns: + A dictionary of fqn => constant value. + """ + all_constants: dict[str, _ConstantAttributeType] = {} + + input_specs = graph_signature.input_specs + num_custom_obj = sum( + input_spec.kind == InputKind.CUSTOM_OBJ for input_spec in input_specs + ) + num_tensor_constants = sum( + input_spec.kind == InputKind.CONSTANT_TENSOR for input_spec in input_specs + ) + + fake_mode = detect_fake_mode( + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") + ) + + first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes)) + used_target_names = set() + + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(input_nodes) == len(input_specs) + for i, (node, input_spec) in enumerate(zip(input_nodes, input_specs)): + used_target_names.add(input_spec.target) + if input_spec.kind == InputKind.USER_INPUT: + first_user_input = node + first_user_input_loc = i + break + + lifted_objs = ConstantAttrMap() + renamed_targets = {} + for node in list(gm.graph.nodes): + if node.op == "get_attr": + if nodes_to_remove := _unused_constant(node): + # Remove the node if it's not being used + for node_rm in nodes_to_remove: + gm.graph.erase_node(node_rm) + continue + + constant_val = _get_attr(gm, node.target) + # These are not hashable and not gonna be lifted + # so we can skip them earlier + if isinstance(constant_val, torch.fx.GraphModule): + continue + if "LoweredBackendModule" in type(constant_val).__name__: + continue + if "AOTInductorRunnerWrapper" in type(constant_val).__name__: + continue + if isinstance(constant_val, torch.utils._pytree.TreeSpec): + continue + + if constant_val in lifted_objs: + # We already lifted this constant elsewhere. Just rewrite uses + # of this get_attr to point to the already-existing placeholder + # node. + const_placeholder_node = _get_first_fqn(lifted_objs, constant_val) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + renamed_targets[node.name] = const_placeholder_node.name + continue + + # For ScriptObject, Tensor and FakeScriptObject constants: + # First check if the constant was an attribute on some module by + # consulting `constant_attrs` map. If it is, use the fqn that keeps + # its location consistent with the eager module. + # + # If it's not in the `constant_attrs` map, that means it's an inline + # constant (e.g. x + torch.tensor(0)), and thus did not have a + # specific location in the eager module. In that case, just generate + # some name and attach it to the module in which it was used. + if isinstance( + constant_val, (torch.ScriptObject, FakeScriptObject) + ) or is_opaque_reference_type(type(constant_val)): + constant_kind = InputKind.CUSTOM_OBJ + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + while constant_fqn in used_target_names: + num_custom_obj += 1 + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + num_custom_obj += 1 + elif isinstance(constant_val, torch.Tensor): + # Remove the parameterness of constant_val + if isinstance(constant_val, torch.nn.Parameter): + log.debug( + "%s created when tracing %s is a parameter. But " + "it's not registered with register_parameter(). export will treat it as a constant tensor", + str(node.target), + str(node.meta.get("stack_trace", "")), + ) + # We get the real data out of the parameter by disabling the surrounding fake mode. + with unset_fake_temporarily(): + constant_val = constant_val.data + constant_kind = InputKind.CONSTANT_TENSOR + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + while constant_fqn in used_target_names: + num_tensor_constants += 1 + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + num_tensor_constants += 1 + else: + raise SpecViolationError( + f"getattr node {node} referencing unsupported type {type(constant_val)}" + ) + + with gm.graph.inserting_before(first_user_input): + # Insert the constant node before the first user input + const_placeholder_node = gm.graph.placeholder(constant_name) + # match target name with its node name in case there is name collision + # and suffix is added to node name in fx + const_placeholder_node.target = const_placeholder_node.name + + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + + # Once the FQN has been used, remove nn_module_stack, stack_trace + const_placeholder_node.meta.pop("nn_module_stack") + const_placeholder_node.meta.pop("stack_trace", None) + + input_spec_arg: ArgumentSpec + if isinstance(constant_val, torch.Tensor): + if fake_mode is not None: + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_val, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = constant_val + else: + const_placeholder_node.meta["val"] = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) + elif isinstance(constant_val, torch._C.ScriptObject): + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, class_fqn=class_fqn + ) + elif isinstance(constant_val, FakeScriptObject): + class_fqn = constant_val.script_class_name + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn, constant_val + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, + class_fqn=class_fqn, + fake_val=constant_val, + ) + else: + raise SpecViolationError( + f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" + ) + + lifted_objs.add(constant_val, const_placeholder_node) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + + renamed_targets[node.name] = const_placeholder_node.name + + # Add the constant as a buffer to the graph signature + graph_signature.input_specs.insert( + first_user_input_loc, + InputSpec( + kind=constant_kind, + arg=input_spec_arg, + target=constant_fqn, + ), + ) + if constant_val in constant_attrs: + for fqn in constant_attrs[constant_val]: + all_constants[fqn] = constant_val + else: + all_constants[constant_fqn] = constant_val + first_user_input_loc += 1 + + for spec in graph_signature.output_specs: + if spec.arg.name in renamed_targets: + spec.arg.name = renamed_targets[spec.arg.name] + + return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> dict[str, _ConstantAttributeType]: + """When tracing, we produce a graph with FakeScriptObject in the + meta["val"]. + + For now, we rewrie meta["val"] to be a placeholder CustomObjArgument + """ + constants: dict[ + str, + _ConstantAttributeType, + ] = {} + for node in gm.graph.nodes: + if "val" not in node.meta: + continue + + old_meta = node.meta["val"] + + if isinstance(old_meta, torch.ScriptObject): + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + elif isinstance(old_meta, FakeScriptObject): + class_fqn = old_meta.script_class_name # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn, old_meta) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants + + +def _materialize_and_lift_constants( + gm: torch.fx.GraphModule, + export_graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> dict[str, _ConstantAttributeType]: + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + return constants diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py new file mode 100644 index 0000000000000000000000000000000000000000..ceed7cd23aa0e953b99586052629668cc53c4bdd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py @@ -0,0 +1,36 @@ +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class _RemoveRuntimeAssertionsPass(PassBase): + """ + Remove runtime assertions inserted by the + _AddRuntimeAssertionsForInlineConstraintsPass. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target in [ + torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten._assert_tensor_metadata.default, + ]: + assert_async_node = node + if len(assert_async_node.users) > 0: + continue + module.graph.erase_node(assert_async_node) + # the upstream scalar_tensor <- {le, ge} <- sym_size + # linear chain of nodes of nodes is removed by the + # downstream dead code elimination + modified = True + + # We don't necessarily want to run DCE here because it could affect + # nodes that are in the module_call_graph attribute of the exported + # program. We will leave it to the pass caller to call DCE. + return PassResult(graph_module, modified) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..14ab3e817ed703cbe0844198deca5c06f2e6effc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py @@ -0,0 +1,189 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch._higher_order_ops.wrap import wrap_with_autocast + +from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +if TYPE_CHECKING: + from torch.export.graph_signature import ExportGraphSignature + + +def _is_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target + in [ + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + ] + ) + + +def _is_enter_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target is torch.amp.autocast_mode._enter_autocast + ) + + +def _is_exit_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target is torch.amp.autocast_mode._exit_autocast + ) + + +def _is_autocast_sub_mod(node: torch.fx.Node) -> bool: + """ + Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`. + """ + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target is torch.amp.autocast_mode._enter_autocast + ): + # TODO: check if current auto-cast type is the same as the args of + # _enter_autocast. If so, return False, i.e. do not create a submodule. + return True + return False + + +def _check_valid_autocast_block( + enter_autocast_node: torch.fx.Node, exit_autocast_node: torch.fx.Node +) -> None: + assert _is_enter_autocast_node(enter_autocast_node) + assert _is_exit_autocast_node(exit_autocast_node) + assert exit_autocast_node.args[0] == enter_autocast_node + + +def _replace_with_hop(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node) + if len(autocast_nodes) > 0: + assert len(autocast_nodes) > 1 # need at least an enter node and an exist node + enter_autocast_node = autocast_nodes[0] + exit_autocast_node = autocast_nodes[-1] + _check_valid_autocast_block(enter_autocast_node, exit_autocast_node) + + _replace_with_hop_helper(node, enter_autocast_node, wrap_with_autocast) + sub_graph.erase_node(exit_autocast_node) + sub_graph.erase_node(enter_autocast_node) + + +def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + split_autocast creates a new graph module that splits the input graph module into multiple submodules + based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module. + + Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split + into a submodule. Nested autocast regions are not split. + `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well. + + Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph + module. Nodes marked with the same number are grouped into the same submodule. + A # 0 + enter_autocast # 1 + B # 1 + exit_autocast # 1 + C # 2 + enter_autocast # 3 + D # 3 + exit_autocast # 3 + E # 4 + """ + enter_autocast_node_stack: list[torch.fx.Node] = [] + first_node_after_outer_most_exit: bool = False + + def node_call_back(node: torch.fx.Node) -> bool: + nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit + increment_id = False + if first_node_after_outer_most_exit or ( + len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node) + ): + assert len(enter_autocast_node_stack) == 0 + first_node_after_outer_most_exit = False + increment_id = True + if _is_enter_autocast_node(node): + enter_autocast_node_stack.append(node) + elif _is_exit_autocast_node(node): + assert len(enter_autocast_node_stack) > 0 + last_enter_autocast_node = enter_autocast_node_stack.pop() + assert node.args[0] == last_enter_autocast_node + if len(enter_autocast_node_stack) == 0: + # next node should be in the next submodule since + # autocast block ends + first_node_after_outer_most_exit = True + return increment_id + + return sequential_split(gm, node_call_back) + + +def _sequential_split_and_maybe_inline_subgraphs( + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Helper function for replace_autocast_with_hop_pass(). + Split the graph module into multiple subgraphs based on the autocast nodes. + For each subgraph, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module. + Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered + as a subgraph. + """ + need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes) + if not need_replacing: + return gm, graph_signature + + # split_autocast returns a new graph module that could have different output + # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`. + new_gm = _split_autocast(gm) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node) -> None: + if _is_autocast_sub_mod(node): + _replace_with_hop(node) + else: + assert node.op == "call_module" + assert isinstance(node.target, str) + node_inline_(node) + + return _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm, graph_signature, _maybe_inline_or_replace_with_hop + ) + + +def replace_autocast_with_hop_pass( + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + return _replace_with_hop_pass_helper( + gm, + graph_signature, + _sequential_split_and_maybe_inline_subgraphs, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..2324d1f2cfa20c96003d3ae9e634784994648b10 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py @@ -0,0 +1,676 @@ +# mypy: allow-untyped-defs +import logging +import operator +from typing import Optional, Union + +import torch +import torch.export._trace +from torch._ops import OpOverload +from torch.ao.quantization.fx._decomposed import ( + dequantize_per_channel, + dequantize_per_tensor, + quantize_per_tensor, +) +from torch.ao.quantization.utils import calculate_qmin_qmax +from torch.fx.graph_module import _assign_attr + + +log = logging.getLogger(__name__) + +# Those values will need to be carried over multiple operators. +_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None +_SCALE: Optional[Union[float, torch.fx.Node]] = None +_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None + + +def int_to_valid_dtype(val: int) -> torch.dtype: + from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import. + + if isinstance(val, torch.dtype): + return val + dtype = _TORCH_ENUM_TO_DTYPE[val] + if dtype == torch.quint8: + return torch.uint8 + elif dtype == torch.qint8: + return torch.int8 + return dtype + + +def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node: + return gm.graph.call_function(int_to_valid_dtype, (val,)) + + +def insert_quantized_node( + gm: torch.fx.GraphModule, + val_node: torch.fx.Node, + scale_node: Union[float, torch.fx.Node], + zero_point_node: Union[float, torch.fx.Node], + qmin_node: Union[float, int, torch.fx.Node], + qmax_node: Union[float, int, torch.fx.Node], + dtype_node: Union[torch.dtype, torch.fx.Node], + qscheme: Optional[torch.qscheme], +) -> torch.fx.Node: + return gm.graph.call_function( + quantize_per_tensor, + ( + val_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + + +def get_dequantized( + val: torch.Tensor, + scale: Union[float, torch.Tensor], + zero_point: Union[float, torch.Tensor], + qmin: Union[float, int], + qmax: Union[float, int], + dtype: torch.dtype, + axis: Optional[int], + qscheme: Optional[torch.qscheme], +) -> torch.Tensor: + if qscheme is torch.per_tensor_affine: + return dequantize_per_tensor( + val, + scale, # type: ignore[arg-type] + zero_point, # type: ignore[arg-type] + qmin, # type: ignore[arg-type] + qmax, # type: ignore[arg-type] + dtype, + ) + elif qscheme is torch.per_channel_affine: + return dequantize_per_channel( + val, + scale, # type: ignore[arg-type] + zero_point, # type: ignore[arg-type] + axis, # type: ignore[arg-type] + qmin, # type: ignore[arg-type] + qmax, # type: ignore[arg-type] + dtype, + ) + else: + raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") + + +def insert_dequantized_node( + gm: torch.fx.GraphModule, + val_node: torch.fx.Node, + scale_node: Union[float, torch.fx.Node], + zero_point_node: Union[float, torch.fx.Node], + qmin_node: Union[float, int, torch.fx.Node], + qmax_node: Union[float, int, torch.fx.Node], + dtype_node: Union[torch.dtype, torch.fx.Node], + axis_node: Optional[Union[int, torch.fx.Node]], + qscheme: Optional[torch.qscheme], +) -> torch.fx.Node: + if qscheme is torch.per_tensor_affine: + return gm.graph.call_function( + dequantize_per_tensor, + ( + val_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + elif qscheme is torch.per_channel_affine: + return gm.graph.call_function( + dequantize_per_channel, + ( + val_node, + scale_node, + zero_point_node, + axis_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + else: + raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") + + +def get_qmin_qmax(dtype: torch.dtype) -> tuple[Union[int, float], Union[int, float]]: + return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type] + + +def insert_qmin_qmax_node( + gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node] +) -> tuple[torch.fx.Node, torch.fx.Node]: + q_min_max_node = gm.graph.call_function( + calculate_qmin_qmax, (None, None, False, dtype_node, False) + ) + qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0)) + qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1)) + return qmin_node, qmax_node + + +def get_script_object( + gm: torch.nn.Module, node: torch.fx.Node +) -> torch._C.ScriptObject: + assert isinstance(node, torch.fx.Node) + assert node.op == "get_attr" + attr_name = node.target + assert isinstance(attr_name, str) + + mod = gm + for attr in attr_name.split("."): + mod = getattr(mod, attr) + assert isinstance(mod, torch._C.ScriptObject) + return mod + + +def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm: torch.fx.GraphModule, + param_node: torch.fx.Node, +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + """Directly inline tensor from a get_attr fx node.""" + mod = get_script_object(gm, param_node) + w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined] + w_attr_name, b_attr_name = ( + f"dequantized_{param_node.target}_w", + f"dequantized_{param_node.target}_b", + ) + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm: torch.fx.GraphModule, + get_attr_to_weight_node: torch.fx.Node, + get_attr_to_bias_node: Optional[torch.fx.Node], +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + assert isinstance(get_attr_to_weight_node.target, str) + w_qtensor = getattr(gm, get_attr_to_weight_node.target) + w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w" + + if get_attr_to_bias_node is not None: + assert isinstance(get_attr_to_bias_node.target, str) + b_qtensor = getattr(gm, get_attr_to_bias_node.target) + b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b" + else: + b_qtensor, b_attr_name = None, "" + + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node( + gm: torch.fx.GraphModule, + w_qtensor: torch.Tensor, + b_qtensor: Optional[torch.Tensor], + w_attr_name: str, + b_attr_name: str, +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + w_tensor = get_tensor_from_qtensor(w_qtensor) + _assign_attr(w_tensor, gm, w_attr_name) + w_tensor_attr = gm.graph.get_attr(w_attr_name) + + if b_qtensor is not None: + b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False) + _assign_attr(b_tensor, gm, b_attr_name) + b_tensor_attr = gm.graph.get_attr(b_attr_name) + else: + b_tensor_attr = None + + return w_tensor_attr, b_tensor_attr + + +def get_tensor_from_qtensor( + qtensor: torch.Tensor, dequant: bool = True +) -> torch.Tensor: + # Manual conversion because qint8 is not used anymore. + if qtensor.dtype in [torch.qint8, torch.quint8]: + tensor = qtensor.int_repr() + else: + tensor = qtensor + + # Weights need dequantization with scaling and zero_point adjustment, but + # bias does not need that. + if dequant: + qscheme = qtensor.qscheme() + if qscheme == torch.per_channel_affine: + scale, zero_point, axis = ( + qtensor.q_per_channel_scales(), + qtensor.q_per_channel_zero_points(), + qtensor.q_per_channel_axis(), + ) + else: + scale, zero_point, axis = ( + qtensor.q_scale(), # type: ignore[assignment] + qtensor.q_zero_point(), # type: ignore[assignment] + None, + ) + dtype = tensor.dtype + qmin, qmax = get_qmin_qmax(dtype) + return get_dequantized( + tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme + ) + return tensor + + +def insert_fused_activation_node( + gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node +) -> torch.fx.Node: + if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]: + fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,)) + return fx_node + + +def _conv1d_op_with_squeeze( + inp: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +) -> torch.Tensor: + # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze + # operations before and after the conv2d operation to match the dimension of weights. + # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950 + s_inp = torch.ops.aten.unsqueeze(inp, 2) + conv1d_res = torch.ops.aten.conv2d( + s_inp, + weight, + bias, + stride, + padding, + dilation, + groups, + ) + uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2) + return uns_conv1d_res + + +def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Conv specific transformation function.""" + assert isinstance(node.target, torch._ops.OpOverload) + opname = node.target._opname + scale_node, zero_point_node = node.args[2], node.args[3] + + op_f = ( + torch.ops.aten.conv2d + if opname in ["conv2d", "conv2d_relu"] + else _conv1d_op_with_squeeze + ) + + inp_node, param_node = node.args[0], node.args[1] + assert isinstance(inp_node, torch.fx.Node) + assert isinstance(param_node, torch.fx.Node) + + if param_node.op == "call_function": + # Using Conv2dPrepackParam from conv_prepack. + # We directly skip the packing call and inline weights and bias. + w_node, b_node = param_node.args[0], param_node.args[1] + assert isinstance(w_node, torch.fx.Node) + assert b_node is None or isinstance(b_node, torch.fx.Node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm, w_node, b_node + ) + op_res_node = gm.graph.call_function( + op_f, (inp_node, param_0, param_1, *param_node.args[2:]) + ) + else: + # Using ConvPrepackedParam. + param = get_script_object(gm, param_node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm, param_node + ) # type: ignore[assignment] + op_res_node = gm.graph.call_function( + op_f, + ( + inp_node, + param_0, + param_1, + param.stride(), # type: ignore[attr-defined] + param.padding(), # type: ignore[attr-defined] + param.dilation(), # type: ignore[attr-defined] + param.groups(), # type: ignore[attr-defined] + ), + ) + return op_res_node, scale_node, zero_point_node + + +def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Linear specific transformation function.""" + scale_node, zero_point_node = node.args[2], node.args[3] + + inp_node, param_node = node.args[0], node.args[1] + assert isinstance(inp_node, torch.fx.Node) + assert isinstance(param_node, torch.fx.Node) + + if param_node.op == "call_function": + # Using LinearPrepackParam from linear_prepack. + # We directly skip the packing call and inline weights and bias. + w_node, b_node = param_node.args[0], param_node.args[1] + assert isinstance(w_node, torch.fx.Node) + assert b_node is None or isinstance(b_node, torch.fx.Node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm, w_node, b_node + ) + op_res_node = gm.graph.call_function( + torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:]) + ) + else: + # Using LinearPackedParams. + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm, param_node + ) # type: ignore[assignment] + op_res_node = gm.graph.call_function( + torch.ops.aten.linear, (inp_node, param_0, param_1) + ) + return op_res_node, scale_node, zero_point_node + + +def _transform_op_where_last_two_arguments_are_scale_and_zero_point( + gm: torch.fx.GraphModule, node: torch.fx.Node +): + """ + This transformation function can be used for function where the last two + parameters are scale and zero point. Additionally, the function's parameters + do not need any unpacking. + """ + to_standard_op = { + "mul": torch.ops.aten.mul, + "mul_relu": torch.ops.aten.mul, + "add": torch.ops.aten.add, + "add_relu": torch.ops.aten.add, + "softmax": torch.ops.aten.softmax, + "cat": torch.ops.aten.cat, + "hardswish": torch.ops.aten.hardswish, + } + + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2])) + return op_res_node, scale_node, zero_point_node + + +def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Transform scalar overload for basic arithmetic.""" + to_standard_op = { + "mul": torch.ops.aten.mul.Scalar, + "add": torch.ops.aten.add.Scalar, + } + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_res_node = gm.graph.call_function(to_standard_op[opname], args) + return op_res_node, _SCALE, _ZERO_POINT + + +def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node): + """ + Transformation for functions under prepacked namespace, where they share + the same handling logic that [...]OpContext contains all parameters. + """ + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_f = None + if opname == "conv2d_clamp_run": + op_f = torch.ops.aten.conv2d + elif opname == "linear_clamp_run": + op_f = torch.ops.aten.linear + else: + raise RuntimeError(f"Invalid operator {opname}") + + assert isinstance(args[1], torch.fx.Node) + so = get_script_object(gm, args[1]) + + func_args = [] + func_args += [args[0]] + func_args += so.unpack()[:2] # type: ignore[attr-defined] + if opname == "conv2d_clamp_run": + func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:] + + op_res_node = gm.graph.call_function(op_f, tuple(func_args)) + return op_res_node + + +def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node): + args = node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function( + torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3]) + ) + op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0)) + return op_res_node, scale_node, zero_point_node + + +def fx_transform_quantized_op_to_standard_op( + gm: torch.fx.GraphModule, node: torch.fx.Node +) -> torch.fx.Node: + global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE + + assert isinstance(node.target, torch._ops.OpOverload) + opname, overload = node.target._opname, node.target._overloadname + + key = f"{opname}.{overload}" + opname_to_transform_f = { + "conv1d.new": _transform_conv_with_packedparam, + "conv1d_relu.new": _transform_conv_with_packedparam, + "conv1d.default": _transform_conv_with_packedparam, + "conv1d_relu.default": _transform_conv_with_packedparam, + "conv2d.new": _transform_conv_with_packedparam, + "conv2d_relu.new": _transform_conv_with_packedparam, + "conv2d.default": _transform_conv_with_packedparam, + "conv2d_relu.default": _transform_conv_with_packedparam, + "linear.default": _transform_linear_with_packedparam, + "linear_relu.default": _transform_linear_with_packedparam, + "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "batch_norm2d.default": _transform_batch_norm, + "mul.Scalar": _transform_scalar_arithmetic, + "add.Scalar": _transform_scalar_arithmetic, + } + + if f"{key}" not in opname_to_transform_f: + raise RuntimeError(f"Unsupported quantized op during transformation: {key}") + + op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node) + + # Add fused activation layer. + op_res_node = insert_fused_activation_node(gm, opname, op_res_node) + _SCALE, _ZERO_POINT = scale_node, zero_point_node + + assert _INPUT_Q_DTYPE is not None + qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE) + q_fx_node = insert_quantized_node( + gm, + op_res_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + None, + torch.per_tensor_affine, + ) + return dq_fx_node + + +def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): + """ + Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with + PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv). + + Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y + + After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y + + (qd == quantized_decomposed library, q = quantize, dq = dequantize) + ^ + | + getattr(w), getattr(b) from Conv2dParamPrepack + + During each iteration, the transformation spits out the transformed operator, its quantized output, + and its dequantized value together. We did this because dequantization need to use the + scale and zero point parameters from the quantization to recover the approximate original value. After each + iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear). + + For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject. + During the transformation, we unpack those objects, get their dequantized tensor, populate those + as attributes to the module, and use getattr to access them. + + One exception in the transformation is conv_prepack and linear_prepack. Those calls pack + weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls. + During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the + quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters + to the operator by converting them to a getattr fx.node. + + For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear + without the need of doing de/quantization. + + Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization + data type, which is the same across the entire program, but it only shows up in the very first quantization + call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar. + """ + + global _INPUT_Q_DTYPE + + quantized = False + + last_quantized_node = None + # pyrefly: ignore [bad-assignment] + for node in gm.graph.nodes: + if isinstance(node.target, OpOverload): + with gm.graph.inserting_before(node): + namespace, opname = node.target.namespace, node.target._opname + if namespace == "quantized" and opname not in [ + "conv_prepack", + "linear_prepack", + ]: + quantized = True + fx_node = fx_transform_quantized_op_to_standard_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "prepacked": + quantized = True + fx_node = _transform_prepacked_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "aten" and opname == "quantize_per_tensor": + inp_node, scale_node, zero_point_node, dtype_node = node.args + dtype_node = fx_enum_to_dtype(gm, dtype_node) + _INPUT_Q_DTYPE = dtype_node + qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node) + q_fx_node = insert_quantized_node( + gm, + inp_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + None, + torch.per_tensor_affine, + ) + node.replace_all_uses_with(dq_fx_node) + last_quantized_node = dq_fx_node + elif namespace == "aten" and opname == "dequantize": + assert last_quantized_node is not None + node.replace_all_uses_with(last_quantized_node) + else: + last_quantized_node = node + + # Post-processing again to remove legacy ScriptObjects and quantizated tensors + # stored as attributes or in the buffer. This is used to clean up the GraphModule + # to not trigger tracing errors like missing __obj_flatten__ functions. + def _clean_attr(mod: torch.nn.Module): + for submod in mod.modules(): + attr_names_to_clean = set() + for k, v in submod.__dict__.items(): + if isinstance(v, torch.ScriptObject): + attr_names_to_clean.add(k) + if k == "_buffers": + buffer_name_to_clean = set() + # pyrefly: ignore [missing-attribute] + for b_name, b_value in v.items(): + if isinstance(b_value, torch.Tensor) and b_value.dtype in [ + torch.qint8, + torch.quint8, + ]: + buffer_name_to_clean.add(b_name) + for b_name in buffer_name_to_clean: + # pyrefly: ignore [missing-attribute] + v.pop(b_name, None) + for attr_name in attr_names_to_clean: + delattr(submod, attr_name) + + if quantized: + """ + TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily + bypass test cases. + + The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing + will throw errors. However, the current way of SetAttr does inplace update to attributes, so + this pass regard them as dead code and remove them. Below is an example of GraphModule before + and after the dead code elimination pass. + + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data = self.data; data = None + data_1 = self.data + add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None + data_2 = self.data + copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + """ + gm.graph.eliminate_dead_code() + _clean_attr(gm) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..5a15a5950575527b9beca532e4b0229b2603c1a0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled + +from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +if TYPE_CHECKING: + from torch.export.graph_signature import ExportGraphSignature + + +def _is_set_grad_enabled_node(node: torch.fx.Node) -> torch.fx.Node | bool: + return ( + node + and node.op == "call_function" + and node.target is torch._C._set_grad_enabled + ) + + +def _is_set_grad_enabled_sub_mod( + node: torch.fx.Node, omit_if_same_with_ambient: bool = False +) -> bool | torch.Tensor: + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target is torch._C._set_grad_enabled + ): + return ( + first_non_ph.args[0] != torch.is_grad_enabled() + if omit_if_same_with_ambient + else True + ) + return False + + +def _replace_with_hop(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node) + if len(set_grad_nodes) > 0: + assert len(set_grad_nodes) == 1 + set_grad_node = set_grad_nodes[0] + _replace_with_hop_helper(node, set_grad_node, wrap_with_set_grad_enabled) + sub_graph.erase_node(set_grad_node) + + +def _remove_set_grad_and_inline(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + nodes_map( + sub_graph.nodes, + lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n, + ) + node_inline_(node) + + +def _sequential_split_and_maybe_inline_subgraphs( + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Helper function for replace_set_grad_with_hop_pass(). + Split the graph module into multiple subgraphs based on the set_grad_enabled nodes. + For each subgraph, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module. + """ + need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes) + if not need_replacing: + return gm, graph_signature + + # sequential_split returns a new graph module that could have different output + # args names. We need to fix the graph signature. + new_gm = sequential_split(gm, _is_set_grad_enabled_node) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): + if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True): + _replace_with_hop(node) + else: + _remove_set_grad_and_inline(node) + + return _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm, graph_signature, _maybe_inline_or_replace_with_hop + ) + + +def replace_set_grad_with_hop_pass( + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + return _replace_with_hop_pass_helper( + gm, + graph_signature, + _sequential_split_and_maybe_inline_subgraphs, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..489bc19ed1d50d13f7bc8d7cd73f940bb34f451d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch._export.error import InternalError +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse +from torch._ops import HigherOrderOperator, OpOverload + + +__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] + + +_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = { + torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, +} + + +def is_view_op(schema: torch._C.FunctionSchema) -> bool: + if len(schema.arguments) == 0: + return False + alias_info = schema.arguments[0].alias_info + return (alias_info is not None) and (not alias_info.is_write) + + +def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]: + if is_view_op(schema) and schema.name.startswith("aten::"): + view_op_name = schema.name.split("::")[1] + view_op_overload = ( + schema.overload_name if schema.overload_name != "" else "default" + ) + view_copy_op_name = view_op_name + "_copy" + if not hasattr(torch.ops.aten, view_copy_op_name): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name) + + if not hasattr(view_copy_op_overload_packet, view_op_overload): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + return getattr(view_copy_op_overload_packet, view_op_overload) + + return None + + +class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Our backend expects pure functional operators. For efficiency + purposes, we keep view ops around while functionalizing the exported + program. This pass replaces view ops with view copy ops for backends that + need AOT memory planning. + """ + + def call_operator(self, op, args, kwargs, meta): + if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: + return super().call_operator( + (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta + ) + + if isinstance(op, HigherOrderOperator): + return super().call_operator(op, args, kwargs, meta) + + if view_copy_op := get_view_copy_of_view_op(op._schema): + return super().call_operator(view_copy_op, args, kwargs, meta) + + return super().call_operator(op, args, kwargs, meta) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py new file mode 100644 index 0000000000000000000000000000000000000000..862244aac8837fd10c3d86838d81db6bd0c62a7e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py @@ -0,0 +1,190 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import copy +import operator +from typing import TYPE_CHECKING + +import torch + +from ..utils import node_replace_, nodes_map + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch._ops import HigherOrderOperator + from torch.export.graph_signature import ExportGraphSignature + + +def _replace_with_hop_helper( + node: torch.fx.Node, + enter_block_node: torch.fx.Node, + wrap_hoo: HigherOrderOperator, +) -> None: + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + + def set_hoo_node_meta(call_func_node): + call_func_node.meta["nn_module_stack"] = copy.copy( + enter_block_node.meta.get("nn_module_stack", {}) + ) + call_func_node.meta["torch_fn"] = ( + f"{wrap_hoo.__name__}", + # pyrefly: ignore [missing-attribute] + f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}", + ) + if isinstance(output_args, (tuple, list)): + call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args) + elif isinstance(output_args, torch.fx.Node): + call_func_node.meta["val"] = (output_args.meta["val"],) + + with graph.inserting_before(node): + get_attr_node = graph.get_attr(node.target) + get_attr_node.meta["nn_module_stack"] = copy.copy( + enter_block_node.meta.get("nn_module_stack", {}) + ) + output_node = next(iter(reversed(sub_gm.graph.nodes)), None) + # Split_module pass intentionally doesn't add output node + # if the graph doesn't return anything. + # TODO (tmanlaibaatar) Figure out if this is right behaviour + # for split_module + if isinstance(output_node, torch.fx.Node) and output_node.op != "output": + output_node = None + if output_node is not None: + assert len(output_node.args) == 1 + output_args = output_node.args[0] + enter_block_node_args = enter_block_node.args + if isinstance(output_args, (tuple, list)): + call_func_node = graph.call_function( + wrap_hoo, + (*enter_block_node_args, get_attr_node, *node.args), + {}, + ) + # Create the metadata + set_hoo_node_meta(call_func_node) + node_replace_(node, call_func_node) + + # Rename the name of getitem nodes to the actual name of its contents + # for passing verifier and better readability, also propagate metadata + for get_item_node in call_func_node.users: + idx: int = get_item_node.args[1] # type: ignore[assignment] + output_node = output_args[idx] + get_item_node._rename(output_node.name) + get_item_node.meta = output_node.meta + + elif isinstance(output_args, torch.fx.Node): + call_func_node = graph.create_node( + "call_function", + wrap_hoo, + (*enter_block_node_args, get_attr_node, *node.args), + {}, + output_args.name, + ) + # Modify the subgraph to output a singleton list. + output_node.args = ((output_args,),) + # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph. + get_item_node = graph.create_node( + "call_function", + operator.getitem, + (call_func_node, 0), + {}, + ) + # Create the metadata + get_item_node.meta = output_args.meta + set_hoo_node_meta(call_func_node) + node_replace_(node, get_item_node) + else: + raise NotImplementedError( + f"replace_with_hop_pass doesn't support output type {type(output_args)}" + ) + else: + # TODO (shangdiy): remove this line, since the export graph can be non-functional + node.graph.erase_node(node) + + +def _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature | None, + maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None], +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Helper function for replacing graph nodse with higher order nodes. + For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`. + """ + # new_gm is a new graph module that could have different output args names. + # We need to fix the graph signature. + replace_ctx = contextlib.nullcontext() + new_signature = None + if graph_signature is not None: + # Cannot deep copy a real ScriptObject, which is referenced + # in the FakeScriptObject. Copy should be good enough to guard + # against accidental mutation to original graph_signature. + new_signature = copy.copy(graph_signature) + new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output"))) + assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len( + new_signature.output_specs + ) + for arg_node, out_spec in zip( + new_gm_out_node.args[0], new_signature.output_specs + ): + if arg_node is None: + assert out_spec.arg.value is None # type: ignore[union-attr] + elif ( + isinstance(arg_node, torch.fx.Node) + and out_spec.arg.name != arg_node.name + ): + out_spec.arg.name = arg_node.name + + replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook()) # type: ignore[assignment] + + with replace_ctx: + nodes_map( + list(new_gm.graph.nodes), + lambda node: ( + maybe_inline_or_replace_with_hop(node) + if node.op == "call_module" + else node + ), + ) + new_gm.recompile() + new_gm.graph.lint() + return new_gm, new_signature + + +def _replace_with_hop_pass_helper( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature | None, + sequential_split_and_maybe_inline_subgraphs: Callable[ + [torch.fx.GraphModule, ExportGraphSignature | None], + tuple[torch.fx.GraphModule, ExportGraphSignature | None], + ], +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs( + gm, graph_signature + ) + # recursively call + for node in new_gm.graph.nodes: + if node.op == "get_attr": + subgm = getattr(new_gm, node.target) + if not isinstance(subgm, torch.fx.GraphModule): + continue + new_subgm, _ = _replace_with_hop_pass_helper( + subgm, + None, + sequential_split_and_maybe_inline_subgraphs, + ) + setattr(new_gm, node.target, new_subgm) + + new_gm.recompile() + new_gm.graph.lint() + return new_gm, new_signature diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb3983e26089454df59793b3b36c65dcfbdc2c37 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..963ada41daa250d8b5556f8b9153082fa4b62308 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dca9a55509673181d3b23eae747952317b1197cb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d03fac3cc1e68793d415fe3166ffb5e6fe4e1083 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/union.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/union.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..218e93ff16313ed230d59ce535ec2f4916c24ada Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/__pycache__/union.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d002874d48245d2053c9bdc72bca02ebca606e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py @@ -0,0 +1,324 @@ +import dataclasses +from typing import Any, Optional, Union + +import torch +from torch._dynamo.exc import UserError, UserErrorType +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _DerivedDim, + _DimHint, + _tree_map_with_path, + Dim, +) +from torch.utils._pytree import tree_map + +from .serialize import _dataclass_to_dict + + +@dataclasses.dataclass +class RootDim: + """ + This represents a Dim object. + """ + + min: int + max: Union[int, None] + derived: list[str] + + +@dataclasses.dataclass +class DynamicShapesSpec: + """ + This stores a dynamic_shapes spec for de/serialization. + """ + + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None] + dims: dict[str, RootDim] + + +def _postprocess_serialized_shapes( + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + dims: dict[str, dict[str, Union[int, list[str], None]]], + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, dict[str, Any]]: + """ + Sorts dims and dumps to dictionary format. + """ + from torch.utils._sympy.numbers import int_oo + + dims = { + k: RootDim( + min=v["min"], # type: ignore[arg-type] + max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type] + derived=sorted(v["derived"]), # type: ignore[arg-type] + ) + for k, v in sorted(dims.items()) + } + # pyrefly: ignore [bad-argument-type] + spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims) + if to_dict: + return _dataclass_to_dict(spec) + else: + return spec + + +def _dump_dynamic_shapes( + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, dict[str, Any]]: + """ + Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec. + Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims". + Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones). + + dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export(): + - Each tensor input is represented with a list of values, non-tensor inputs with None. + - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings. + - static dimensions are represented with ints. + + dims: A dictionary mapping each symbol name to the min/max range and derived dim names. + + For example: + ``` + dx = Dim("dx", min=4, max=16) + dy = dx + 1 + + inputs = ( + [ + torch.randn(4, 4), + torch.randn(5, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (Dim.STATIC,), + "c": None, + "d": None, + } + out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True) + ``` + would generate the following output: + ``` + { + "dynamic_shapes": ( + [ + ["dx", 4], + ["dx + 1", 4], + ], + ["_DimHint.STATIC"], + ["_DimHint.STATIC", "_DimHint.STATIC"], + None, + ), + "dims": { + "dx": { + "min": 4, + "max": 16, + "derived": ["dx + 1"], + }, + }, + } + ``` + """ + dims: dict[str, dict[str, Any]] = {} + + def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] + """ + Helps standardize the dynamic_shapes tree structure we serialize, + returning lists for each tensor shape, handling tensor-level Nones. + """ + if not isinstance(tensor, torch.Tensor): + return None + if shape is None: + return [Dim.STATIC] * len(tensor.shape) + + out = [] + if isinstance(shape, dict): + for i, s in enumerate(tensor.shape): + out.append(s if shape.get(i) is None else shape.get(i)) + else: + assert isinstance(shape, (tuple, list)) + for i, s in enumerate(tensor.shape): + out.append(s if shape[i] is None else shape[i]) + return out + + def _track_dim_from_dims( + val: Union[None, int, _DimHint, Dim], + ) -> Union[None, int, str]: + """ + Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. + """ + if val is None or isinstance(val, int): # non-tensor input or static + return val + if isinstance(val, _DimHint): # store enum as string + return val.__class__.__name__ + "." + val.type.name + + assert isinstance(val, Dim) + + # track root dim + root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined] + if root.__name__ not in dims: + dims[root.__name__] = { + "min": root.min, # type: ignore[attr-defined,union-attr] + "max": root.max, # type: ignore[attr-defined,union-attr] + "derived": set(), + } + + # track derived dims + if isinstance(val, _DerivedDim): + dims[root.__name__]["derived"].add(val.__name__) + + return val.__name__ + + if dynamic_shapes is None: + return {"dynamic_shapes": None, "dims": {}} + + # convert to tuple of specs, for each arg/kwarg + kwargs = kwargs or {} + if isinstance(dynamic_shapes, dict): + dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment] + # pyrefly: ignore [bad-assignment, bad-argument-type] + dynamic_shapes = tuple(dynamic_shapes) + combined_args = tuple(args) + tuple(kwargs.values()) + + # run same check when we're processing shapes for export - is this too lazy? + _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type] + + tree_shapes = _tree_map_with_path( + _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs" + ) + serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes) + return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict) + + +def _load_dynamic_shapes( + spec: Union[DynamicShapesSpec, dict[str, Any]], + from_dict: Optional[bool] = False, +) -> Union[dict[str, Any], tuple[Any], list[Any], None]: + """ + Utility function for dynamic shapes serialization. + Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). + """ + import sympy + + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + if from_dict: + if not isinstance(spec, dict): + raise UserError( + UserErrorType.INVALID_INPUT, + f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}", + ) + if sorted(spec.keys()) != ["dims", "dynamic_shapes"]: + raise UserError( + UserErrorType.INVALID_INPUT, + "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, " + f"instead found {spec.keys()}", + ) + dims = {} + for k, v in spec["dims"].items(): + if not isinstance(k, str): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}", + ) + if sorted(v.keys()) != ["derived", "max", "min"]: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, " + f"instead found {v.keys()}", + ) + if not isinstance(v["min"], int): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}", + ) + if not isinstance(v["max"], int) or v["max"] is None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}", + ) + if not isinstance(v["derived"], list) or any( + not isinstance(d, str) for d in v["derived"] + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, " + f"got {k}: {v['derived']}", + ) + dims[k] = RootDim(**v) + dynamic_shapes = spec["dynamic_shapes"] + else: + if not isinstance(spec, DynamicShapesSpec): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}", + ) + dims = spec.dims + dynamic_shapes = spec.dynamic_shapes + + if dynamic_shapes is None: + return None + + dim_cache = {} + for name, info in dims.items(): + symbol = sympy.sympify(name) + if not isinstance(symbol, sympy.Symbol): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be symbols, got {name}", + ) + dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim + for _expr in info.derived: + expr = sympy.sympify(_expr) + if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions in to have {name} as the only free symbol, got {expr}", + ) + if not _is_supported_equivalence(expr): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions to be linear expressions, got {expr}", + ) + modulus, remainder = sympy.polys.polytools.div(expr, symbol) + ddim = dim_cache[name] + if modulus != 1: + ddim = int(modulus) * ddim # type: ignore[assignment, operator] + if remainder != 0: + ddim = ddim + int(remainder) # type: ignore[assignment, operator] + dim_cache[_expr] = ddim # cache derived dims + + def deserialize_shape( + val: Union[None, int, str], + ) -> Union[None, int, Dim, _DimHint]: + if val is None or isinstance(val, int): + return val + elif val == "_DimHint.AUTO": + return _DimHint.AUTO() + elif val == "_DimHint.DYNAMIC": + return _DimHint.DYNAMIC() + elif val == "_DimHint.STATIC": + return _DimHint.STATIC() + if not isinstance(val, str): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, " + f" or derived expressions, got {val}", + ) + if val not in dim_cache: + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + f"got {val} which is not in {dims.keys()}", + ) + return dim_cache[val] # type: ignore[return-value] + + return tree_map(deserialize_shape, dynamic_shapes) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift new file mode 100644 index 0000000000000000000000000000000000000000..155f52595740c5a1d57b8071a11b509ef16d5fce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift @@ -0,0 +1,377 @@ +// @generated by update_schema.py +// checksum<<0e870e558fb4362f69b825842ab606cf0becd10a008003ac676156becf20b65b>> + +namespace py3 torch._export +namespace cpp2 torch._export.schema + +enum ArgumentKind { + UNKNOWN = 0, + POSITIONAL = 1, + KEYWORD = 2, +} + + +enum Layout { + Unknown = 0, + SparseCoo = 1, + SparseCsr = 2, + SparseCsc = 3, + SparseBsr = 4, + SparseBsc = 5, + _mkldnn = 6, + Strided = 7, +} + + +enum MemoryFormat { + Unknown = 0, + ContiguousFormat = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + PreserveFormat = 4, +} + + +enum ScalarType { + UNKNOWN = 0, + BYTE = 1, + CHAR = 2, + SHORT = 3, + INT = 4, + LONG = 5, + HALF = 6, + FLOAT = 7, + DOUBLE = 8, + COMPLEXHALF = 9, + COMPLEXFLOAT = 10, + COMPLEXDOUBLE = 11, + BOOL = 12, + BFLOAT16 = 13, + UINT16 = 28, + FLOAT8E4M3FN = 29, + FLOAT8E5M2 = 30, + FLOAT8E4M3FNUZ = 31, + FLOAT8E5M2FNUZ = 32, +} + + +struct Device { + 10: string type; + 20: optional i64 index; +} + +union SymExprHint { + 10: i64 as_int; + 20: bool as_bool; + 30: double as_float; +} + +struct SymExpr { + 10: string expr_str; + 20: optional SymExprHint hint; +} + +union SymInt { + 10: SymExpr as_expr; + 20: i64 as_int; +} + +union SymFloat { + 10: SymExpr as_expr; + 20: double as_float; +} + +union SymBool { + 10: SymExpr as_expr; + 20: bool as_bool; +} + +struct TensorMeta { + 10: ScalarType dtype; + 20: list sizes; + 30: bool requires_grad; + 40: Device device; + 50: list strides; + 60: SymInt storage_offset; + 70: Layout layout; +} + +union SymIntArgument { + 10: string as_name; + 20: i64 as_int; +} + +union SymFloatArgument { + 10: string as_name; + 20: double as_float; +} + +union SymBoolArgument { + 10: string as_name; + 20: bool as_bool; +} + +struct TensorArgument { + 10: string name; +} + +struct TokenArgument { + 10: string name; +} + +union OptionalTensorArgument { + 20: TensorArgument as_tensor; + 10: bool as_none; +} + +struct GraphArgument { + 10: string name; + 20: Graph graph; +} + +struct CustomObjArgument { + 10: string name; + 20: string class_fqn; +} + +struct ComplexValue { + 10: double real; + 20: double imag; +} + +union Argument { + 10: bool as_none; + 20: TensorArgument as_tensor; + 30: list as_tensors; + 50: i64 as_int; + 70: list as_ints; + 80: double as_float; + 90: list as_floats; + 100: string as_string; + 101: list as_strings; + 110: SymIntArgument as_sym_int; + 120: list as_sym_ints; + 130: ScalarType as_scalar_type; + 140: MemoryFormat as_memory_format; + 150: Layout as_layout; + 160: Device as_device; + 170: bool as_bool; + 180: list as_bools; + 182: SymBoolArgument as_sym_bool; + 184: list as_sym_bools; + 200: GraphArgument as_graph; + 190: list as_optional_tensors; + 210: CustomObjArgument as_custom_obj; + 220: string as_operator; + 230: SymFloatArgument as_sym_float; + 240: list as_sym_floats; + 250: OptionalTensorArgument as_optional_tensor; + 260: ComplexValue as_complex; + 280: list> as_int_lists; + 290: map as_string_to_argument; +} + +struct NamedArgument { + 10: string name; + 20: Argument arg; + 30: optional ArgumentKind kind; +} + +struct Node { + 10: string target; + 20: list inputs; + 30: list outputs; + 40: map metadata; + 50: optional bool is_hop_single_tensor_return; +} + +struct Graph { + 10: list inputs; + 20: list outputs; + 30: list nodes; + 40: map tensor_values; + 50: map sym_int_values; + 60: map sym_bool_values; + 70: bool is_single_tensor_return; + 80: map custom_obj_values; + 90: map sym_float_values; +} + +struct UserInputSpec { + 10: Argument arg; +} + +union ConstantValue { + 10: bool as_none; + 20: i64 as_int; + 30: double as_float; + 40: string as_string; + 50: bool as_bool; +} + +struct InputToConstantInputSpec { + 10: string name; + 20: ConstantValue value; +} + +struct InputToParameterSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct InputToBufferSpec { + 10: TensorArgument arg; + 20: string buffer_name; + 30: bool persistent; +} + +struct InputToTensorConstantSpec { + 10: TensorArgument arg; + 20: string tensor_constant_name; +} + +struct InputToCustomObjSpec { + 10: CustomObjArgument arg; + 20: string custom_obj_name; +} + +struct InputTokenSpec { + 10: TokenArgument arg; +} + +union InputSpec { + 10: UserInputSpec user_input; + 20: InputToParameterSpec parameter; + 30: InputToBufferSpec buffer; + 40: InputToTensorConstantSpec tensor_constant; + 50: InputToCustomObjSpec custom_obj; + 70: InputTokenSpec token; + 60: InputToConstantInputSpec constant_input; +} + +struct UserOutputSpec { + 10: Argument arg; +} + +struct LossOutputSpec { + 10: TensorArgument arg; +} + +struct BufferMutationSpec { + 10: TensorArgument arg; + 20: string buffer_name; +} + +struct ParameterMutationSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct GradientToParameterSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct GradientToUserInputSpec { + 10: TensorArgument arg; + 20: string user_input_name; +} + +struct UserInputMutationSpec { + 10: TensorArgument arg; + 20: string user_input_name; +} + +struct OutputTokenSpec { + 10: TokenArgument arg; +} + +union OutputSpec { + 10: UserOutputSpec user_output; + 20: LossOutputSpec loss_output; + 30: BufferMutationSpec buffer_mutation; + 40: GradientToParameterSpec gradient_to_parameter; + 50: GradientToUserInputSpec gradient_to_user_input; + 60: UserInputMutationSpec user_input_mutation; + 70: OutputTokenSpec token; + 80: ParameterMutationSpec parameter_mutation; +} + +struct GraphSignature { + 10: list input_specs; + 20: list output_specs; +} + +struct RangeConstraint { + 10: optional i64 min_val; + 20: optional i64 max_val; +} + +struct ModuleCallSignature { + 10: list inputs; + 20: list outputs; + 30: string in_spec; + 40: string out_spec; + 50: optional list forward_arg_names; +} + +struct ModuleCallEntry { + 10: string fqn; + 30: optional ModuleCallSignature signature; +} + +struct NamedTupleDef { + 10: list field_names; +} + +struct GraphModule { + 10: Graph graph; + 50: GraphSignature signature; + 60: list module_call_graph; + 40: map metadata; + 70: map treespec_namedtuple_fields; +} + +struct SchemaVersion { + 10: i64 major; + 20: i64 minor; +} + +struct ExportedProgram { + 10: GraphModule graph_module; + 20: map opset_version; + 30: map range_constraints; + 60: SchemaVersion schema_version; + 70: list verifiers; + 80: string torch_version; + 90: list guards_code; +} + +struct PayloadMeta { + 10: string path_name; + 20: bool is_param; + 30: bool use_pickle; + 40: optional TensorMeta tensor_meta; +} + +struct PayloadConfig { + 10: map config; +} + +struct AOTInductorModelPickleData { + 1: string library_basename; + 2: list input_names; + 3: list output_names; + 4: optional i64 floating_point_input_dtype; + 5: optional i64 floating_point_output_dtype; + 6: optional bool aot_inductor_model_is_cpu; +} + +struct ExternKernelNode { + 10: string name; + 20: Node node; +} + +struct ExternKernelNodes { + 10: list nodes; +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..0d95ca32e6455ad2e8b13e1274a39a9ae0e78fd5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema.py @@ -0,0 +1,520 @@ +# NOTE: This is a placeholder for iterating on export serialization schema design. +# Anything is subject to change and no guarantee is provided at this point. + +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Annotated, Optional + +from torch._export.serde.union import _Union, _union_dataclass + + +# NOTE: Please update this value if any modifications are made to the schema +SCHEMA_VERSION = (8, 15) +TREESPEC_VERSION = 1 + + +# NOTE: If you updated the schema, please run `scripts/export/update_schema.py` +# to update the auto generated files. +class ScalarType(IntEnum): + UNKNOWN = 0 + BYTE = 1 + CHAR = 2 + SHORT = 3 + INT = 4 + LONG = 5 + HALF = 6 + FLOAT = 7 + DOUBLE = 8 + COMPLEXHALF = 9 + COMPLEXFLOAT = 10 + COMPLEXDOUBLE = 11 + BOOL = 12 + BFLOAT16 = 13 + UINT16 = 28 + FLOAT8E4M3FN = 29 + FLOAT8E5M2 = 30 + FLOAT8E4M3FNUZ = 31 + FLOAT8E5M2FNUZ = 32 + + +class Layout(IntEnum): + Unknown = 0 + SparseCoo = 1 + SparseCsr = 2 + SparseCsc = 3 + SparseBsr = 4 + SparseBsc = 5 + _mkldnn = 6 + Strided = 7 + + +class MemoryFormat(IntEnum): + Unknown = 0 + ContiguousFormat = 1 + ChannelsLast = 2 + ChannelsLast3d = 3 + PreserveFormat = 4 + + +@dataclass +class Device: + type: Annotated[str, 10] + index: Annotated[Optional[int], 20] = None + + +@_union_dataclass +class SymExprHint(_Union): + as_int: Annotated[int, 10] + as_bool: Annotated[bool, 20] + as_float: Annotated[float, 30] + + +# This is for storing the symbolic expressions behind symints/symfloats/symbools +# For example, we can get something like +# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4) +# if we also have the hint that s0 and s1 are both 2. +@dataclass +class SymExpr: + expr_str: Annotated[str, 10] + hint: Annotated[Optional[SymExprHint], 20] = None + + +@_union_dataclass +class SymInt(_Union): + as_expr: Annotated[SymExpr, 10] + as_int: Annotated[int, 20] + + +@_union_dataclass +class SymFloat(_Union): + as_expr: Annotated[SymExpr, 10] + as_float: Annotated[float, 20] + + +@_union_dataclass +class SymBool(_Union): + as_expr: Annotated[SymExpr, 10] + as_bool: Annotated[bool, 20] + + +@dataclass +class TensorMeta: + dtype: Annotated[ScalarType, 10] + sizes: Annotated[list[SymInt], 20] + requires_grad: Annotated[bool, 30] + device: Annotated[Device, 40] + strides: Annotated[list[SymInt], 50] + storage_offset: Annotated[SymInt, 60] + layout: Annotated[Layout, 70] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymInts. +# The "as_int" field is used in the case where we have a list containing a mix +# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to +# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints +# to the "as_int" field. +@_union_dataclass +class SymIntArgument(_Union): + as_name: Annotated[str, 10] + as_int: Annotated[int, 20] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymFloats. +# The "as_float" field is used in the case where we have a list containing a mix +# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to +# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints +# to the "as_float" field. +@_union_dataclass +class SymFloatArgument(_Union): + as_name: Annotated[str, 10] + as_float: Annotated[float, 20] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymBools. +# The "as_bool" field is used in the case where we have a list containing a mix +# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to +# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools +# to the "as_bool" field. +@_union_dataclass +class SymBoolArgument(_Union): + as_name: Annotated[str, 10] + as_bool: Annotated[bool, 20] + + +@dataclass +class TensorArgument: + name: Annotated[str, 10] + + +@dataclass +class TokenArgument: + name: Annotated[str, 10] + + +# This is use for storing the contents of a list which contain optional tensors +# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the +# type List[OptionalTensorArgument], with tensor values serialized to the +# "as_tensor" field, and None values serialized to the "as_none" field. +@_union_dataclass +class OptionalTensorArgument(_Union): + as_tensor: Annotated[TensorArgument, 20] + as_none: Annotated[bool, 10] + + +@dataclass +class GraphArgument: + name: Annotated[str, 10] + graph: Annotated["Graph", 20] + + +@dataclass +class CustomObjArgument: + name: Annotated[str, 10] + class_fqn: Annotated[str, 20] + + +@dataclass +class ComplexValue: + real: Annotated[float, 10] + imag: Annotated[float, 20] + + +# This is actually a union type +@_union_dataclass +class Argument(_Union): + as_none: Annotated[bool, 10] + as_tensor: Annotated[TensorArgument, 20] + as_tensors: Annotated[list[TensorArgument], 30] + as_int: Annotated[int, 50] + as_ints: Annotated[list[int], 70] + as_float: Annotated[float, 80] + as_floats: Annotated[list[float], 90] + as_string: Annotated[str, 100] + as_strings: Annotated[list[str], 101] + as_sym_int: Annotated[SymIntArgument, 110] + as_sym_ints: Annotated[list[SymIntArgument], 120] + as_scalar_type: Annotated[ScalarType, 130] + as_memory_format: Annotated[MemoryFormat, 140] + as_layout: Annotated[Layout, 150] + as_device: Annotated[Device, 160] + as_bool: Annotated[bool, 170] + as_bools: Annotated[list[bool], 180] + as_sym_bool: Annotated[SymBoolArgument, 182] + as_sym_bools: Annotated[list[SymBoolArgument], 184] + as_graph: Annotated[GraphArgument, 200] + as_optional_tensors: Annotated[list[OptionalTensorArgument], 190] + as_custom_obj: Annotated[CustomObjArgument, 210] + as_operator: Annotated[str, 220] + as_sym_float: Annotated[SymFloatArgument, 230] + as_sym_floats: Annotated[list[SymFloatArgument], 240] + as_optional_tensor: Annotated[OptionalTensorArgument, 250] + as_complex: Annotated[ComplexValue, 260] + as_int_lists: Annotated[list[list[int]], 280] + as_string_to_argument: Annotated[dict[str, "Argument"], 290] + + +class ArgumentKind(IntEnum): + UNKNOWN = 0 + POSITIONAL = 1 + KEYWORD = 2 + + +@dataclass +class NamedArgument: + # Argument name from the operator schema + name: Annotated[str, 10] + arg: Annotated[Argument, 20] + kind: Annotated[Optional[ArgumentKind], 30] = None + + +@dataclass +class Node: + target: Annotated[str, 10] + inputs: Annotated[list[NamedArgument], 20] + outputs: Annotated[list[Argument], 30] + metadata: Annotated[dict[str, str], 40] + is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None + + +@dataclass +class Graph: + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + nodes: Annotated[list[Node], 30] + tensor_values: Annotated[dict[str, TensorMeta], 40] + sym_int_values: Annotated[dict[str, SymInt], 50] + sym_bool_values: Annotated[dict[str, SymBool], 60] + # This is for deserializing the submodule graphs from higher order ops + # (ex. cond, map) where single tensor returns will just return a single + # tensor, rather than following export schema and returning a singleton + # list. + is_single_tensor_return: Annotated[bool, 70] = False + custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field( + default_factory=dict + ) + sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict) + + +@dataclass +class UserInputSpec: + # Actually, only tensors and SymInts are allowed here + arg: Annotated[Argument, 10] + + +@_union_dataclass +class ConstantValue(_Union): + as_none: Annotated[bool, 10] + as_int: Annotated[int, 20] + as_float: Annotated[float, 30] + as_string: Annotated[str, 40] + as_bool: Annotated[bool, 50] + + +@dataclass +class InputToConstantInputSpec: + name: Annotated[str, 10] + value: Annotated[ConstantValue, 20] + + +@dataclass +class InputToParameterSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class InputToBufferSpec: + arg: Annotated[TensorArgument, 10] + buffer_name: Annotated[str, 20] + persistent: Annotated[bool, 30] + + +@dataclass +class InputToTensorConstantSpec: + arg: Annotated[TensorArgument, 10] + tensor_constant_name: Annotated[str, 20] + + +@dataclass +class InputToCustomObjSpec: + arg: Annotated[CustomObjArgument, 10] + custom_obj_name: Annotated[str, 20] + + +@dataclass +class InputTokenSpec: + arg: Annotated[TokenArgument, 10] + + +@_union_dataclass +class InputSpec(_Union): + user_input: Annotated[UserInputSpec, 10] + parameter: Annotated[InputToParameterSpec, 20] + buffer: Annotated[InputToBufferSpec, 30] + tensor_constant: Annotated[InputToTensorConstantSpec, 40] + custom_obj: Annotated[InputToCustomObjSpec, 50] + token: Annotated[InputTokenSpec, 70] + constant_input: Annotated[InputToConstantInputSpec, 60] + + +@dataclass +class UserOutputSpec: + arg: Annotated[Argument, 10] + + +@dataclass +class LossOutputSpec: + arg: Annotated[TensorArgument, 10] + + +@dataclass +class BufferMutationSpec: + arg: Annotated[TensorArgument, 10] + buffer_name: Annotated[str, 20] + + +@dataclass +class ParameterMutationSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class GradientToParameterSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class GradientToUserInputSpec: + arg: Annotated[TensorArgument, 10] + user_input_name: Annotated[str, 20] + + +@dataclass +class UserInputMutationSpec: + arg: Annotated[TensorArgument, 10] + user_input_name: Annotated[str, 20] + + +@dataclass +class OutputTokenSpec: + arg: Annotated[TokenArgument, 10] + + +@_union_dataclass +class OutputSpec(_Union): + user_output: Annotated[UserOutputSpec, 10] + loss_output: Annotated[LossOutputSpec, 20] + buffer_mutation: Annotated[BufferMutationSpec, 30] + gradient_to_parameter: Annotated[GradientToParameterSpec, 40] + gradient_to_user_input: Annotated[GradientToUserInputSpec, 50] + user_input_mutation: Annotated[UserInputMutationSpec, 60] + token: Annotated[OutputTokenSpec, 70] + parameter_mutation: Annotated[ParameterMutationSpec, 80] + + +@dataclass +class GraphSignature: + input_specs: Annotated[list[InputSpec], 10] + output_specs: Annotated[list[OutputSpec], 20] + + +@dataclass +class RangeConstraint: + min_val: Annotated[Optional[int], 10] + max_val: Annotated[Optional[int], 20] + + +@dataclass +class ModuleCallSignature: + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + + # These are serialized by calling pytree.treespec_loads + # And deserialized by calling pytree.treespec_dumps + in_spec: Annotated[str, 30] + out_spec: Annotated[str, 40] + + # This field is used to prettify the graph placeholders + # after we Ser/Der and retrace + forward_arg_names: Annotated[Optional[list[str]], 50] = None + + +@dataclass +class ModuleCallEntry: + fqn: Annotated[str, 10] + signature: Annotated[Optional[ModuleCallSignature], 30] = None + + +@dataclass +class NamedTupleDef: + field_names: Annotated[list[str], 10] + + +@dataclass +class GraphModule: + graph: Annotated[Graph, 10] + signature: Annotated[GraphSignature, 50] + # This is used for unflattening, by tracking the calling structure of all of + # the modules in order to unflatten the modules back to the eager calling + # conventions. + module_call_graph: Annotated[list[ModuleCallEntry], 60] + metadata: Annotated[dict[str, str], 40] = field(default_factory=dict) + # Mapping of namedtuple types to namedtuple field names, used for BC + treespec_namedtuple_fields: Annotated[dict[str, NamedTupleDef], 70] = field( + default_factory=dict + ) + + +# Invariant: Every time a change is made to the schema, one of the versions +# should be updated. +@dataclass +class SchemaVersion: + major: Annotated[ + int, 10 + ] # Major version number is bumped every time a breaking change is made. + minor: Annotated[ + int, 20 + ] # Minor version number is bumped when a compatible change is made. + + +@dataclass +class ExportedProgram: + graph_module: Annotated[GraphModule, 10] + # Key is the opset namespace (ex. aten), and value is the version number + opset_version: Annotated[dict[str, int], 20] + range_constraints: Annotated[dict[str, RangeConstraint], 30] + schema_version: Annotated[SchemaVersion, 60] + verifiers: Annotated[list[str], 70] = field(default_factory=list) + torch_version: Annotated[str, 80] = "<=2.4" + guards_code: Annotated[list[str], 90] = field(default_factory=list) + + +######################################################################### +# Container types for inference tasks, not being used directly for export. +######################################################################### + + +# The metadata for payload saved in PT2 archive. +# payload includes params, buffers, tensor constants, and custom objects. +@dataclass +class PayloadMeta: + # the path of the payload in the archive file, e.g. "weight_0" + path_name: Annotated[str, 10] + is_param: Annotated[bool, 20] + # whether the payload is serialized using pickle. + # Only custom objects and tensor subclasses that are not fake tensors + # are serialized using pickle. + use_pickle: Annotated[bool, 30] + # Custom Objects don't have tensor_meta and will be serialized using pickle + tensor_meta: Annotated[Optional[TensorMeta], 40] + + +# The mapping from payload FQN to its metadata. +@dataclass +class PayloadConfig: + config: Annotated[dict[str, PayloadMeta], 10] + + +# +# The structure is used to serialize instances of AOTInductorModel to pass +# them from the publishing pipeline to the predictor. +# +# All new fields should be marked as optional. +# +@dataclass +class AOTInductorModelPickleData: + # Base name of an associated .so AOTInductor library. Typically looks like: + # "abc.so". + library_basename: Annotated[str, 1] + + # AOTInductor engine input names. + input_names: Annotated[list[str], 2] + + # AOTInductor engine output names. + output_names: Annotated[list[str], 3] + + # These fields tell whether floating point inputs/outputs should be converted to + # a certain type. If None, the dtypes that the AOTInductor engine inferred from the sample + # inputs are used. + floating_point_input_dtype: Annotated[Optional[int], 4] = None + floating_point_output_dtype: Annotated[Optional[int], 5] = None + + # Whether AOTInductor runtime is for CPU. + aot_inductor_model_is_cpu: Annotated[Optional[bool], 6] = None + + +@dataclass +class ExternKernelNode: + # name is not the unique identifier of the node + name: Annotated[str, 10] + node: Annotated[Node, 20] + + +@dataclass +class ExternKernelNodes: + nodes: Annotated[list[ExternKernelNode], 10] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema.yaml b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f13741416cb35c4a6ac482c9f95c8d87a61e9d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema.yaml @@ -0,0 +1,559 @@ +# @generated by update_schema.py +# checksum<> +AOTInductorModelPickleData: + kind: struct + fields: + library_basename: + type: str + input_names: + type: List[str] + output_names: + type: List[str] + floating_point_input_dtype: + type: Optional[int] + default: None + floating_point_output_dtype: + type: Optional[int] + default: None + aot_inductor_model_is_cpu: + type: Optional[bool] + default: None +Argument: + kind: union + fields: + as_none: + type: bool + as_tensor: + type: TensorArgument + as_tensors: + type: List[TensorArgument] + as_int: + type: int + as_ints: + type: List[int] + as_float: + type: float + as_floats: + type: List[float] + as_string: + type: str + as_strings: + type: List[str] + as_sym_int: + type: SymIntArgument + as_sym_ints: + type: List[SymIntArgument] + as_scalar_type: + type: ScalarType + as_memory_format: + type: MemoryFormat + as_layout: + type: Layout + as_device: + type: Device + as_bool: + type: bool + as_bools: + type: List[bool] + as_sym_bool: + type: SymBoolArgument + as_sym_bools: + type: List[SymBoolArgument] + as_graph: + type: GraphArgument + as_optional_tensors: + type: List[OptionalTensorArgument] + as_custom_obj: + type: CustomObjArgument + as_operator: + type: str + as_sym_float: + type: SymFloatArgument + as_sym_floats: + type: List[SymFloatArgument] + as_optional_tensor: + type: OptionalTensorArgument + as_complex: + type: ComplexValue + as_int_lists: + type: List[List[int]] + as_string_to_argument: + type: Dict[str, Argument] +ArgumentKind: + kind: enum + fields: + UNKNOWN: 0 + POSITIONAL: 1 + KEYWORD: 2 +BufferMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str +ComplexValue: + kind: struct + fields: + real: + type: float + imag: + type: float +ConstantValue: + kind: union + fields: + as_none: + type: bool + as_int: + type: int + as_float: + type: float + as_string: + type: str + as_bool: + type: bool +CustomObjArgument: + kind: struct + fields: + name: + type: str + class_fqn: + type: str +Device: + kind: struct + fields: + type: + type: str + index: + type: Optional[int] + default: None +ExportedProgram: + kind: struct + fields: + graph_module: + type: GraphModule + opset_version: + type: Dict[str, int] + range_constraints: + type: Dict[str, RangeConstraint] + schema_version: + type: SchemaVersion + verifiers: + type: List[str] + default: '[]' + torch_version: + type: str + default: <=2.4 + guards_code: + type: List[str] + default: '[]' +ExternKernelNode: + kind: struct + fields: + name: + type: str + node: + type: Node +ExternKernelNodes: + kind: struct + fields: + nodes: + type: List[ExternKernelNode] +GradientToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +GradientToUserInputSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +Graph: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + nodes: + type: List[Node] + tensor_values: + type: Dict[str, TensorMeta] + sym_int_values: + type: Dict[str, SymInt] + sym_bool_values: + type: Dict[str, SymBool] + is_single_tensor_return: + type: bool + default: 'False' + custom_obj_values: + type: Dict[str, CustomObjArgument] + default: '{}' + sym_float_values: + type: Dict[str, SymFloat] + default: '{}' +GraphArgument: + kind: struct + fields: + name: + type: str + graph: + type: Graph +GraphModule: + kind: struct + fields: + graph: + type: Graph + signature: + type: GraphSignature + module_call_graph: + type: List[ModuleCallEntry] + metadata: + type: Dict[str, str] + default: '{}' + treespec_namedtuple_fields: + type: Dict[str, NamedTupleDef] + default: '{}' +GraphSignature: + kind: struct + fields: + input_specs: + type: List[InputSpec] + output_specs: + type: List[OutputSpec] +InputSpec: + kind: union + fields: + user_input: + type: UserInputSpec + parameter: + type: InputToParameterSpec + buffer: + type: InputToBufferSpec + tensor_constant: + type: InputToTensorConstantSpec + custom_obj: + type: InputToCustomObjSpec + token: + type: InputTokenSpec + constant_input: + type: InputToConstantInputSpec +InputToBufferSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str + persistent: + type: bool +InputToConstantInputSpec: + kind: struct + fields: + name: + type: str + value: + type: ConstantValue +InputToCustomObjSpec: + kind: struct + fields: + arg: + type: CustomObjArgument + custom_obj_name: + type: str +InputToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +InputToTensorConstantSpec: + kind: struct + fields: + arg: + type: TensorArgument + tensor_constant_name: + type: str +InputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +Layout: + kind: enum + fields: + Unknown: 0 + SparseCoo: 1 + SparseCsr: 2 + SparseCsc: 3 + SparseBsr: 4 + SparseBsc: 5 + _mkldnn: 6 + Strided: 7 +LossOutputSpec: + kind: struct + fields: + arg: + type: TensorArgument +MemoryFormat: + kind: enum + fields: + Unknown: 0 + ContiguousFormat: 1 + ChannelsLast: 2 + ChannelsLast3d: 3 + PreserveFormat: 4 +ModuleCallEntry: + kind: struct + fields: + fqn: + type: str + signature: + type: Optional[ModuleCallSignature] + default: None +ModuleCallSignature: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + in_spec: + type: str + out_spec: + type: str + forward_arg_names: + type: Optional[List[str]] + default: None +NamedArgument: + kind: struct + fields: + name: + type: str + arg: + type: Argument + kind: + type: Optional[ArgumentKind] + default: None +NamedTupleDef: + kind: struct + fields: + field_names: + type: List[str] +Node: + kind: struct + fields: + target: + type: str + inputs: + type: List[NamedArgument] + outputs: + type: List[Argument] + metadata: + type: Dict[str, str] + is_hop_single_tensor_return: + type: Optional[bool] + default: None +OptionalTensorArgument: + kind: union + fields: + as_tensor: + type: TensorArgument + as_none: + type: bool +OutputSpec: + kind: union + fields: + user_output: + type: UserOutputSpec + loss_output: + type: LossOutputSpec + buffer_mutation: + type: BufferMutationSpec + gradient_to_parameter: + type: GradientToParameterSpec + gradient_to_user_input: + type: GradientToUserInputSpec + user_input_mutation: + type: UserInputMutationSpec + token: + type: OutputTokenSpec + parameter_mutation: + type: ParameterMutationSpec +OutputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +ParameterMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +PayloadConfig: + kind: struct + fields: + config: + type: Dict[str, PayloadMeta] +PayloadMeta: + kind: struct + fields: + path_name: + type: str + is_param: + type: bool + use_pickle: + type: bool + tensor_meta: + type: Optional[TensorMeta] +RangeConstraint: + kind: struct + fields: + min_val: + type: Optional[int] + max_val: + type: Optional[int] +ScalarType: + kind: enum + fields: + UNKNOWN: 0 + BYTE: 1 + CHAR: 2 + SHORT: 3 + INT: 4 + LONG: 5 + HALF: 6 + FLOAT: 7 + DOUBLE: 8 + COMPLEXHALF: 9 + COMPLEXFLOAT: 10 + COMPLEXDOUBLE: 11 + BOOL: 12 + BFLOAT16: 13 + UINT16: 28 + FLOAT8E4M3FN: 29 + FLOAT8E5M2: 30 + FLOAT8E4M3FNUZ: 31 + FLOAT8E5M2FNUZ: 32 +SchemaVersion: + kind: struct + fields: + major: + type: int + minor: + type: int +SymBool: + kind: union + fields: + as_expr: + type: SymExpr + as_bool: + type: bool +SymBoolArgument: + kind: union + fields: + as_name: + type: str + as_bool: + type: bool +SymExpr: + kind: struct + fields: + expr_str: + type: str + hint: + type: Optional[SymExprHint] + default: None +SymExprHint: + kind: union + fields: + as_int: + type: int + as_bool: + type: bool + as_float: + type: float +SymFloat: + kind: union + fields: + as_expr: + type: SymExpr + as_float: + type: float +SymFloatArgument: + kind: union + fields: + as_name: + type: str + as_float: + type: float +SymInt: + kind: union + fields: + as_expr: + type: SymExpr + as_int: + type: int +SymIntArgument: + kind: union + fields: + as_name: + type: str + as_int: + type: int +TensorArgument: + kind: struct + fields: + name: + type: str +TensorMeta: + kind: struct + fields: + dtype: + type: ScalarType + sizes: + type: List[SymInt] + requires_grad: + type: bool + device: + type: Device + strides: + type: List[SymInt] + storage_offset: + type: SymInt + layout: + type: Layout +TokenArgument: + kind: struct + fields: + name: + type: str +UserInputMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +UserInputSpec: + kind: struct + fields: + arg: + type: Argument +UserOutputSpec: + kind: struct + fields: + arg: + type: Argument +SCHEMA_VERSION: +- 8 +- 15 +TREESPEC_VERSION: 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema_check.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema_check.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec1fdb9026b9e2f2dec6d9f13ca0d6246904f3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/schema_check.py @@ -0,0 +1,741 @@ +# mypy: allow-untyped-defs +import dataclasses +import hashlib +import inspect +import re +import typing +from enum import IntEnum +from typing import Annotated, Any, ForwardRef, Optional, Union + +from torch._export.serde import schema +from torch._export.serde.union import _Union + + +class SchemaUpdateError(Exception): + pass + + +def _check(x, msg): + if not x: + raise SchemaUpdateError(msg) + + +_CPP_TYPE_MAP = { + str: "std::string", + int: "int64_t", + float: "F64", + bool: "bool", +} + +_THRIFT_TYPE_MAP = { + str: "string", + int: "i64", + float: "double", + bool: "bool", +} + + +def _staged_schema(): + yaml_ret: dict[str, Any] = {} + defs = {} + cpp_enum_defs: dict[str, str] = {} + cpp_class_defs: dict[str, str] = {} + cpp_type_decls: list[str] = [] + cpp_json_defs: list[str] = [] + thrift_enum_defs: list[str] = [] + thrift_type_defs: dict[str, str] = {} + + def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + def dump_type(t, level: int) -> tuple[str, str, str]: + if getattr(t, "__name__", None) in cpp_enum_defs: + return t.__name__, "int64_t", t.__name__ + elif t in _CPP_TYPE_MAP: + return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t]) + elif isinstance(t, str): + assert t in defs + assert t not in cpp_enum_defs + assert "[" not in t + return t, f"ForwardRef<{t}>", t + elif isinstance(t, ForwardRef): + return ( + t.__forward_arg__, + f"ForwardRef<{t.__forward_arg__}>", + t.__forward_arg__, + ) + elif o := typing.get_origin(t): + # Lemme know if there's a better way to do this. + if o is list: + yaml_head, cpp_head, thrift_head, thrift_tail = ( + "List", + "std::vector", + "list<", + ">", + ) + elif o is dict: + yaml_head, cpp_head, thrift_head, thrift_tail = ( + "Dict", + "std::unordered_map", + "map<", + ">", + ) + elif o == Union: + assert level == 0, "Optional is only supported at the top level." + args = typing.get_args(t) + assert len(args) == 2 and args[1] is type(None) + yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1) + return ( + f"Optional[{yaml_type}]", + f"std::optional<{cpp_type}>", + f"optional {thrift_type}", + ) + elif o is Annotated: + return dump_type(t.__origin__, level) + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + yaml_arg_types, cpp_arg_types, thrift_arg_types = zip( + *[dump_type(x, level + 1) for x in typing.get_args(t)] + ) + return ( + (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), + (f"{cpp_head}<{', '.join(cpp_arg_types)}>"), + f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}", + ) + elif isinstance(t, type): + return (t.__name__, t.__name__, t.__name__) + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + + def dump_cpp_value(v) -> str: + if v is None: + return "std::nullopt" + elif v is True: + return "true" + elif v is False: + return "false" + elif v == {}: + return "{}" + elif v == []: + return "{}" + elif v == (): + return "{}" + elif isinstance(v, str): + return f'"{v}"' + else: + raise AssertionError( + f"Default value {v} is not supported yet in export schema." + ) + + def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: + t, cpp_type, thrift_type = dump_type(f.type, 0) + ret = {"type": t} + cpp_default: Optional[str] = None + assert typing.get_origin(f.type) is Annotated, ( + f"Field {f.name} must be annotated with an integer id." + ) + thrift_id = f.type.__metadata__[0] + assert type(thrift_id) is int, ( + f"Field {f.name} must be annotated with an integer id." + ) + + value = dataclasses.MISSING + if f.default is not dataclasses.MISSING: + value = f.default + elif f.default_factory is not dataclasses.MISSING: + value = f.default_factory() + + if value is not dataclasses.MISSING: + default = str(value) + ret["default"] = default + cpp_default = dump_cpp_value(value) + + if t.startswith("Optional[") and value is not None: + raise AssertionError( + f"Optional field {ty.__name__}.{f.name} must have default value to be None." + ) + + return ret, cpp_type, cpp_default, thrift_type, thrift_id + + yaml_ret = {} + cpp_ret = {} + thrift_ret = {} + thrift_ids = set() + for f in dataclasses.fields(ty): + yaml_res, cpp_type, cpp_default, thrift_type, thrift_id = dump_field(f) + yaml_ret[f.name] = yaml_res + cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default} + thrift_ret[f.name] = {"thrift_type": thrift_type, "thrift_id": thrift_id} + if thrift_id in thrift_ids: + raise AssertionError( + f"Duplicate thrift id {thrift_id} for field {f.name} in {ty.__name__}." + ) + thrift_ids.add(thrift_id) + return yaml_ret, cpp_ret, thrift_ret + + def _handle_int_enum(name, ty): + yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + cpp_enum_defs[name] = f""" +enum class {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}}; + +inline std::string_view printEnum(const {name}& e) {{ + switch (e) {{ +{chr(10).join([f" case {name}::{x.name}: return {chr(34)}{x.name}{chr(34)};" for x in ty])} + default: + throw std::runtime_error("Unknown enum value"); + }} +}} + +inline void parseEnum(std::string_view s, {name}& t) {{ +{chr(10).join([f" if (s == {chr(34)}{x.name}{chr(34)}) {{ t = {name}::{x.name}; return; }}" for x in ty])} + throw std::runtime_error("Unknown enum value: " + std::string{{s}}); +}} +""" + thrift_enum_defs.append( + f""" +enum {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}} +""" + ) + + def _handle_struct(name, ty): + fields, cpp_fields, thrift_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "struct", "fields": fields} + field_decls = "\n".join( + f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};" + for name, f in cpp_fields.items() + ) + + def accessor(name, ty): + type_name = fields[name]["type"] + if type_name in cpp_enum_defs: + return f""" + {type_name} get_{name}() const {{ + return static_cast<{type_name}>({name}); + }} + + void set_{name}({type_name} def) {{ + {name} = static_cast(def); + }} +""" + return f""" + const {ty}& get_{name}() const {{ + return {name}; + }} + + void set_{name}({ty} def) {{ + {name} = std::move(def); + }} +""" + + to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)" + to_json_def = f"""{{ +{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])} +}} +""" + from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)" + + from_json_def = f"""{{ + {name} nlohmann_json_default_obj; +{ + chr(10).join( + [ + f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items() + ] + ) + } +}} +""" + cpp_class_defs[name] = f""" +class {name} {{ + private: +{field_decls} + + public: +{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])} + friend {to_json_decl}; + friend {from_json_decl}; +}}; +""" + cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}") + cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") + cpp_type_decls.append(f"class {name};") + + thrift_type_defs[name] = f""" +struct {name} {{ +{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} +}}""" + + def _handle_union(name, ty): + fields, cpp_fields, thrift_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "union", "fields": fields} + + def accessor(name, ty, idx): + return f""" + const {ty}& get_{name}() const {{ + return std::get<{idx + 1}>(variant_); + }} + + void set_{name}({ty} def) {{ + variant_.emplace<{idx + 1}>(std::move(def)); + tag_ = Tag::{name.upper()}; + }} +""" + + to_json_branches = "".join( + [ + f""" + if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{ + nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}(); + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + from_json_branches = "".join( + [ + f""" + if (nlohmann_json_j.contains("{name}")) {{ + nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>()); + nlohmann_json_t.tag_ = Tag::{name.upper()}; + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + + cpp_class_defs[name] = f""" +class {name} {{ + struct Void {{}}; + + public: + enum class Tag {{ + {", ".join([name.upper() for name in cpp_fields])} + }}; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const {{ + return tag_; + }} +{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])} + friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{ +{to_json_branches} + }} + + friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{ +{from_json_branches} + }} +}}; + +inline std::string_view printEnum(const {name}::Tag& e) {{ + switch (e) {{ +{chr(10).join([f" case {name}::Tag::{x.upper()}: return {chr(34)}{x.upper()}{chr(34)};" for x in cpp_fields])} + default: + throw std::runtime_error("Unknown enum value"); + }} +}} + +inline void parseEnum(std::string_view s, {name}::Tag& t) {{ +{chr(10).join([f" if (s == {chr(34)}{x.upper()}{chr(34)}) {{ t = {name}::Tag::{x.upper()}; return; }}" for x in cpp_fields])} + throw std::runtime_error("Unknown enum value: " + std::string{{s}}); +}} + +""" + cpp_type_decls.append(f"class {name};") + + thrift_type_defs[name] = f""" +union {name} {{ +{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} +}}""" + + for name in dir(schema): + if name.startswith("_"): + continue + + value = getattr(schema, name) + + if hasattr(value, "__module__") and value.__module__ != schema.__name__: + continue + + defs[name] = value + + class_ordering = {} + for name, value in defs.items(): + if isinstance(value, type): + if issubclass(value, IntEnum): + _handle_int_enum(name, value) + elif dataclasses.is_dataclass(value): + class_ordering[name] = inspect.findsource(value)[1] + if issubclass(value, _Union): + _handle_union(name, value) + else: + _handle_struct(name, value) + else: + raise AssertionError(f"Unknown schema type {name}: {value}") + elif isinstance(value, (int, tuple)): + assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION") + else: + raise AssertionError(f"Unknown variable {name}: {value}") + + yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"]) + yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert yaml_ret["TREESPEC_VERSION"] > 0 + + cpp_header = f""" +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{ +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END }} +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> {{ + static void to_json(json& j, const std::optional& opt) {{ + if (opt == std::nullopt) {{ + j = nullptr; + }} else {{ + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + }} + }} + + static void from_json(const json& j, std::optional& opt) {{ + if (j.is_null()) {{ + opt = std::nullopt; + }} else {{ + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + }} + }} +}}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch {{ +namespace _export {{ + +template +class ForwardRef {{ + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {{}} + ForwardRef(ForwardRef&&); + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {{}} + ForwardRef& operator=(ForwardRef&&); + ForwardRef& operator=(const ForwardRef& other) {{ + ptr_ = std::make_unique(*other.ptr_); + return *this; + }} + ~ForwardRef(); + const T& operator*() const {{ + return *ptr_; + }} + + const T* operator->() const {{ + return ptr_.get(); + }} + + void emplace(T&& t) {{ + ptr_ = std::make_unique(std::move(t)); + }} + + private: + std::unique_ptr ptr_; +}}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) {{ + j = *p; +}} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) {{ + p.emplace(j.template get()); +}} + +class F64 {{ + public: + double get() const {{ + return value_; + }} + + void set(double value) {{ + value_ = value; + }} + + private: + double value_; +}}; + +inline void to_json(nlohmann::json& j, const F64& f) {{ + if (std::isinf(f.get())) {{ + j = "Infinity"; + }} else if (std::isinf(-f.get())) {{ + j = "-Infinity"; + }} else if (std::isnan(f.get())) {{ + j = "NaN"; + }} else {{ + j = f.get(); + }} +}} + +inline void from_json(const nlohmann::json& j, F64& f) {{ + if (j == "Infinity") {{ + f.set(std::numeric_limits::infinity()); + }} else if (j == "-Infinity") {{ + f.set(-std::numeric_limits::infinity()); + }} else if (j == "NaN") {{ + f.set(std::numeric_limits::quiet_NaN()); + }} else {{ + f.set(j.get()); + }} +}} + +{chr(10).join(cpp_type_decls)} +{"".join(cpp_enum_defs.values())} +{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +{chr(10).join(cpp_json_defs)} + +template ForwardRef::ForwardRef(ForwardRef&&) = default; +template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +template ForwardRef::~ForwardRef() = default; +}} // namespace _export +}} // namespace torch +""" + thrift_schema = f""" +namespace py3 torch._export +namespace cpp2 torch._export.schema +{chr(10).join(thrift_enum_defs)} +{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +""" + return yaml_ret, cpp_header, thrift_schema + + +def _diff_schema(dst, src): + additions = {key: src[key] for key in src.keys() - dst.keys()} + subtractions = {key: dst[key] for key in dst.keys() - src.keys()} + + common_keys = src.keys() & dst.keys() + + versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} + common_keys -= versions + + for key in common_keys: + src_kind = src[key]["kind"] + src_fields = src[key]["fields"] + dst_kind = dst[key]["kind"] + dst_fields = dst[key]["fields"] + _check( + src_kind == dst_kind, + f"Type {key} changed kind from {dst_kind} to {src_kind}", + ) + assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) + added_fields = { + key: src_fields[key] for key in src_fields.keys() - dst_fields.keys() + } + subtracted_fields = { + key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() + } + common_fields = src_fields.keys() & dst_fields.keys() + + for field in common_fields: + src_field = src_fields[field] + dst_field = dst_fields[field] + if src_kind == "struct": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + if "default" in src_field and "default" not in dst_field: + added_fields[field] = {} + added_fields[field]["default"] = src_field["default"] + if "default" not in src_field and "default" in dst_field: + subtracted_fields[field] = {} + subtracted_fields[field]["default"] = dst_field["default"] + elif src_kind == "enum": + _check( + src_field == dst_field, + f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", + ) + elif src_kind == "union": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + else: + raise AssertionError(f"Unknown kind {src_kind}: {key}") + if len(added_fields) > 0: + assert key not in additions + additions[key] = {} + additions[key]["fields"] = added_fields + if len(subtracted_fields) > 0: + assert key not in subtractions + subtractions[key] = {} + subtractions[key]["fields"] = subtracted_fields + + return additions, subtractions + + +def _hash_content(s: str): + return hashlib.sha256(s.strip().encode("utf-8")).hexdigest() + + +@dataclasses.dataclass +class _Commit: + result: dict[str, Any] + checksum_next: str + yaml_path: str + additions: dict[str, Any] + subtractions: dict[str, Any] + base: dict[str, Any] + checksum_head: Optional[str] + cpp_header: str + cpp_header_path: str + thrift_checksum_head: Optional[str] + thrift_checksum_real: Optional[str] + thrift_checksum_next: str + thrift_schema: str + thrift_schema_path: str + + +def update_schema(): + import importlib.resources + + # pyrefly: ignore [bad-argument-type] + if importlib.resources.is_resource(__package__, "schema.yaml"): + # pyrefly: ignore [bad-argument-type] + content = importlib.resources.read_text(__package__, "schema.yaml") + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) + _check(match is not None, "checksum not found in schema.yaml") + assert match is not None + checksum_head = match.group(1) + + thrift_content = importlib.resources.read_text( + # pyrefly: ignore [bad-argument-type] + __package__, + "export_schema.thrift", + ) + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content) + _check(match is not None, "checksum not found in export_schema.thrift") + assert match is not None + thrift_checksum_head = match.group(1) + thrift_content = thrift_content.splitlines() + assert thrift_content[0].startswith("// @" + "generated") + assert thrift_content[1].startswith("// checksum<<") + thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) + + from yaml import load, Loader + + dst = load(content, Loader=Loader) + assert isinstance(dst, dict) + else: + checksum_head = None + thrift_checksum_head = None + thrift_checksum_real = None + dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} + + src, cpp_header, thrift_schema = _staged_schema() + additions, subtractions = _diff_schema(dst, src) + # pyrefly: ignore [missing-attribute] + yaml_path = __package__.replace(".", "/") + "/schema.yaml" + # pyrefly: ignore [missing-attribute] + thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift" + torch_prefix = "torch/" + assert yaml_path.startswith(torch_prefix) # sanity check + assert thrift_schema_path.startswith(torch_prefix) # sanity check + + return _Commit( + result=src, + checksum_next=_hash_content(repr(src)), + yaml_path=yaml_path, + additions=additions, + subtractions=subtractions, + base=dst, + checksum_head=checksum_head, + cpp_header=cpp_header, + cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", + thrift_checksum_head=thrift_checksum_head, + thrift_checksum_real=thrift_checksum_real, + thrift_checksum_next=_hash_content(thrift_schema), + thrift_schema=thrift_schema, + thrift_schema_path=thrift_schema_path, + ) + + +def check(commit: _Commit, force_unsafe: bool = False): + next_version = None + reason = "" + # Step 1: Detect major schema updates. + if len(commit.additions) > 0: + for k, v in commit.additions.items(): + if k not in commit.base: + continue + kind = commit.result[k]["kind"] + fields = v["fields"] + for f, d in fields.items(): + if kind == "struct" and "default" not in d: + reason += ( + f"Field {k}.{f} is added to schema.py without a default value as an incompatible change " + + "which requires major version bump.\n" + ) + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + if k not in commit.result: + continue + for f in v["fields"]: + reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n" + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if force_unsafe: + reason += "--force-unsafe is used." + next_version = commit.result["SCHEMA_VERSION"] + else: + # Step 2: Detect minor schema updates. + if next_version is None and len(commit.additions) > 0: + for k, v in commit.additions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is added to schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + if next_version is None and len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is removed from schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + + return next_version, reason diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/serialize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..c64aaff9ae1f2b693c753a3b26fa94462cfca870 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/serialize.py @@ -0,0 +1,3936 @@ +# mypy: allow-untyped-defs +import base64 +import copy +import copyreg +import dataclasses +import heapq +import inspect +import io +import json +import keyword +import logging +import math +import operator +import re +import traceback +import typing +from collections import namedtuple, OrderedDict +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Annotated, Any, cast, final, Optional, Union + +import sympy + +import torch +import torch.export.exported_program as ep +from torch._export.non_strict_utils import _enable_graph_inputs_of_type_nn_module +from torch._export.verifier import load_verifier +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.experimental import symbolic_shapes +from torch.utils import _pytree as pytree +from torch.utils._pytree import treespec_dumps, treespec_loads +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.symbol import prefix_str, SymT +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._traceback import CapturedTraceback +from torch.utils._triton import has_triton + +from ..utils import remove_proxy_from_state_dict +from . import schema +from .schema import ( # type: ignore[attr-defined] + Argument, + ArgumentKind, + BufferMutationSpec, + ComplexValue, + ConstantValue, + CustomObjArgument, + Device, + ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, + Graph, + GraphArgument, + GraphModule, + GraphSignature, + InputSpec, + InputToBufferSpec, + InputToConstantInputSpec, + InputToCustomObjSpec, + InputTokenSpec, + InputToParameterSpec, + InputToTensorConstantSpec, + Layout, + LossOutputSpec, + MemoryFormat, + ModuleCallEntry, + ModuleCallSignature, + NamedArgument, + NamedTupleDef, + Node, + OptionalTensorArgument, + OutputSpec, + OutputTokenSpec, + ParameterMutationSpec, + RangeConstraint, + ScalarType, + SCHEMA_VERSION, + SchemaVersion, + SymBool, + SymBoolArgument, + SymExpr, + SymExprHint, + SymFloat, + SymFloatArgument, + SymInt, + SymIntArgument, + TensorArgument, + TensorMeta, + TokenArgument, + TREESPEC_VERSION, + UserInputMutationSpec, + UserInputSpec, + UserOutputSpec, +) +from .union import _Union + + +__all__ = [ + "serialize", + "GraphModuleSerializer", + "ExportedProgramSerializer", + "GraphModuleDeserializer", + "ExportedProgramDeserializer", +] + +log = logging.getLogger(__name__) + + +class SerializeError(RuntimeError): + pass + + +def _reverse_map(d: dict[Any, Enum]): + return {v.value: k for k, v in d.items()} + + +MetaType = Union[ + FakeTensor, + int, + torch.SymInt, + float, + torch.SymFloat, + bool, + torch.SymBool, + ep.CustomObjArgument, +] + +DEFAULT_PICKLE_PROTOCOL = 2 + +ST_DELIMITER = ";" + +_TORCH_TO_SERIALIZE_DTYPE = { + torch.uint8: ScalarType.BYTE, + torch.int8: ScalarType.CHAR, + torch.uint16: ScalarType.UINT16, + torch.int16: ScalarType.SHORT, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.float64: ScalarType.DOUBLE, + torch.complex32: ScalarType.COMPLEXHALF, + torch.complex64: ScalarType.COMPLEXFLOAT, + torch.complex128: ScalarType.COMPLEXDOUBLE, + torch.bool: ScalarType.BOOL, + torch.bfloat16: ScalarType.BFLOAT16, + torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, + torch.float8_e5m2: ScalarType.FLOAT8E5M2, + torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ, + torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ, +} + + +_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_LAYOUT = { + torch.sparse_coo: Layout.SparseCoo, + torch.sparse_csr: Layout.SparseCsr, + torch.sparse_csc: Layout.SparseCsc, + torch.sparse_bsr: Layout.SparseBsr, + torch.sparse_bsc: Layout.SparseBsc, + torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] + torch.strided: Layout.Strided, +} + + +_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { + torch.contiguous_format: MemoryFormat.ContiguousFormat, + torch.channels_last: MemoryFormat.ChannelsLast, + torch.channels_last_3d: MemoryFormat.ChannelsLast3d, + torch.preserve_format: MemoryFormat.PreserveFormat, +} + + +_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] + +_SYM_OPS = { + operator.eq, + operator.ne, + operator.le, + operator.ge, + operator.lt, + operator.gt, + operator.neg, + operator.pos, + operator.and_, + operator.or_, + math.trunc, + torch.sym_not, + operator.mul, + operator.add, + operator.sub, + operator.floordiv, + operator.mod, + operator.pow, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_sqrt, + operator.truediv, + operator.and_, +} + + +assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_OPS) + + +@dataclass +class SerializedArtifact: + exported_program: bytes + state_dict: bytes + constants: bytes + example_inputs: bytes + + +@dataclass +class _SerializedProgram: + exported_program: ExportedProgram + state_dict: bytes + constants: bytes + example_inputs: bytes + + +class LazyMap(dict): + """ + Dictionary class for deferred instantiation of node metadata values. + Purpose is to avoid creation of symbolic-shape tensors before relevant shape guards are parsed. + """ + + def __init__(self): + self.map = {} + self.evaluated = set() + + def __setitem__(self, k, v): + self.map[k] = v + + def __getitem__(self, k): + out = self.map[k] + if k in self.evaluated: + return out + self.evaluated.add(k) + self.map[k] = out() + return self.map[k] + + def __repr__(self): + return self.map.__repr__() + + +def deserialize_device(d: Device) -> torch.device: + if d.index is None: + return torch.device(type=d.type) # type: ignore[call-overload] + return torch.device(type=d.type, index=d.index) + + +def deserialize_size(sizes: Sequence[SymInt]) -> tuple[int, ...]: + for sym_int_size in sizes: + assert sym_int_size.type == "as_int", ( + f"Only as_int is supported, got {sym_int_size.type}" + ) + return tuple(sym_int_size.as_int for sym_int_size in sizes) + + +def deserialize_stride(strides: Sequence[SymInt]) -> tuple[int, ...]: + for sym_int_stride in strides: + assert sym_int_stride.type == "as_int", ( + f"Only as_int is supported, got {sym_int_stride.type}" + ) + return tuple(sym_int_stride.as_int for sym_int_stride in strides) + + +def deserialize_scalar_type(st: ScalarType) -> torch.dtype: + return _SERIALIZE_TO_TORCH_DTYPE[st] + + +def deserialize_storage_offset(offset: SymInt) -> int: + assert offset.type == "as_int", f"Only as_int is supported, got {offset.type}" + return offset.as_int + + +def _print_sympy(s: Union[torch.SymInt, torch.SymBool, torch.SymFloat, sympy.Expr]): + if isinstance(s, (torch.SymInt, torch.SymBool, torch.SymFloat)): + s = s.node.expr + return sympy.printing.repr.srepr(s) + + +def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: + if isinstance(s, (torch.SymInt, sympy.Symbol, int)): + if symbolic_shapes.is_concrete_int(s): + return SymInt.create(as_int=int(s)) + else: + assert isinstance(s, (torch.SymInt, sympy.Symbol)) + if s.node.hint is None: + return SymInt.create(as_expr=SymExpr(_print_sympy(s))) + else: + return SymInt.create( + as_expr=SymExpr( + _print_sympy(s), + hint=SymExprHint.create(as_int=s.node.hint), + ) + ) + else: + raise SerializeError( + f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_float(s: Union[float, torch.SymFloat]) -> SymFloat: + if isinstance(s, (torch.SymFloat, sympy.Symbol, float)): + if symbolic_shapes.is_concrete_float(s): + return SymFloat.create(as_float=float(s)) + else: + assert isinstance(s, (torch.SymFloat, sympy.Symbol)) + if s.node.hint is None: + return SymFloat.create(as_expr=SymExpr(_print_sympy(s))) + else: + return SymFloat.create( + as_expr=SymExpr( + _print_sympy(s), + hint=SymExprHint.create(as_float=s.node.hint), + ) + ) + else: + raise SerializeError( + f"SymFloat should be either symbol or float, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: + if isinstance(s, (torch.SymBool, bool)): + if symbolic_shapes.is_concrete_bool(s): + return SymBool.create(as_bool=bool(s)) + else: + return SymBool.create(as_expr=SymExpr(expr_str=_print_sympy(s))) + else: + raise SerializeError( + f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" + ) + + +def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: + """ + Extract a TensorMeta describing `t`. + """ + return TensorMeta( + dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], + sizes=[serialize_sym_int(s) for s in t.shape], + requires_grad=t.requires_grad, + device=Device(type=t.device.type, index=t.device.index), + strides=[serialize_sym_int(s) for s in t.stride()], + storage_offset=serialize_sym_int(t.storage_offset()), + layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], + ) + + +_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None + + +def _reduce_fake_tensor(fake_tensor: FakeTensor): + is_parameter = isinstance(fake_tensor, torch.nn.Parameter) + tensor_meta = serialize_tensor_meta(fake_tensor) + tensor_meta_bytes = json.dumps( + _dataclass_to_dict(tensor_meta), cls=EnumEncoder + ).encode("utf-8") + return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) + + +def _reconstruct_fake_tensor( + serialized_tensor_meta: bytes, is_parameter: bool +) -> FakeTensor: + # Deserialize the bytes into a TensorMeta + json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) + tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) + # Find the current fake mode + assert _CURRENT_DESERIALIZER is not None, ( + "Need access to current deserializer state" + ) + fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) + if is_parameter: + fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] + # pyrefly: ignore [bad-return] + return fake_tensor + + +def serialize_torch_artifact( + artifact: Optional[Any], pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL +) -> bytes: + if artifact is None: + return b"" + + assert FakeTensor not in copyreg.dispatch_table, ( + "Refusing to stomp on existing FakeTensor reducer" + ) + try: + copyreg.pickle(FakeTensor, _reduce_fake_tensor) + buffer = io.BytesIO() + # This is a workaround for backend's tensor deserialization problem: + # unpickleTensor() always create a tensor on the device where it was originally saved + # This behavior is bad for multi-gpu training, as we wish to directly load the tensor + # on the designated device. + # For now, we simply move the tensor to cpu before saving. + # TODO: this should be fixed by deserialization instead. + torch.save(artifact, buffer, pickle_protocol=pickle_protocol) + return buffer.getvalue() + finally: + del copyreg.dispatch_table[FakeTensor] + + +def deserialize_torch_artifact( + serialized: Union[dict[str, Any], tuple[Any, ...], bytes], +): + if isinstance(serialized, (dict, tuple)): + return serialized + if len(serialized) == 0: + return {} + buffer = io.BytesIO(serialized) + buffer.seek(0) + # weights_only=False as we want to load custom objects here (e.g. ScriptObject) + try: + artifact = torch.load(buffer, weights_only=True) + except Exception as e: + buffer.seek(0) + artifact = torch.load(buffer, weights_only=False) + log.warning( + "Fallback to weights_only=False succeeded. " + "Loaded object of type %s after initial failure: %s", + type(artifact), + exc_info=e, + ) + assert isinstance(artifact, (tuple, dict)) + return artifact + + +def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]: + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return None + if val in (-sympy.oo, -int_oo): + return None + if isinstance(val, sympy.Integer): + return int(val) + + # TODO: Remove this adjustment when Ed gets rid of fractional ranges + log.warning( + "Export constraints cannot be non-integer expressions. Found " + "type %s, and value %s. We will attempt to %s " + "this value.", + type(val), + val, + adjust, + ) + + if adjust == "floor": + return math.floor(val) + elif adjust == "ceil": + return math.ceil(val) + else: + raise RuntimeError(f"Got invalid adjustment {adjust}") + + +def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr: + # Convert concrete int into simple sympy Integers + if val is None: + return default + if val in [-int_oo, int_oo]: + return val + if val == math.inf: + return int_oo + if val == -math.inf: + return -int_oo + return sympy.Integer(val) + + +def _symbol_index(sym: sympy.Symbol, sym_type: SymT): + return int(str(sym)[len(prefix_str[sym_type]) :]) + + +def serialize_range_constraints( + range_constraints: dict[sympy.Symbol, ValueRanges], +) -> dict[str, RangeConstraint]: + return { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] + _sympy_int_to_int(v.upper, "floor"), # type: ignore[arg-type] + ) + for k, v in range_constraints.items() + } + + +def _get_schema_from_target(target): + if isinstance(target, torch._ops.OpOverload): + return target._schema + elif type(target) in _serialization_registry: + return _serialization_registry[type(target)].op_schema(target) + raise RuntimeError(f"Cannot find schema for {type(target)}") + + +@dataclass +class GraphState: + inputs: list[Argument] = field(default_factory=list) + outputs: list[Argument] = field(default_factory=list) + nodes: list[Node] = field(default_factory=list) + tensor_values: dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: dict[str, SymBool] = field(default_factory=dict) + sym_float_values: dict[str, SymFloat] = field(default_factory=dict) + is_single_tensor_return: bool = False + custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict) + + +class Final(type): + def __new__(metacls, name, bases, classdict): + for b in bases: + if isinstance(b, Final): + raise TypeError(f"type '{b.__name__}' is not an acceptable base type") + return type.__new__(metacls, name, bases, dict(classdict)) + + +def is_metadata_matched(config, entry_metadata): + metadata_attrs = ["num_cpu_threads", "num_warps", "num_stages", "num_ctas"] + for attr in metadata_attrs: + if hasattr(config, attr) and hasattr(entry_metadata, attr): + if getattr(config, attr) != getattr(entry_metadata, attr): + return False + return True + + +def get_triton_kernel_and_cache_entry(node: torch.fx.Node): + assert ( + node.target + is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional + ) + + assert has_triton(), "triton required to serialize triton kernels" + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + assert isinstance(node.kwargs["kernel_idx"], int) + kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel( + node.kwargs["kernel_idx"] + ) + + # For Autotuner, we need to look at the underlying JITFunction's cache + # since the Autotuner itself doesn't have a cache + is_autotuner = isinstance(kernel, Autotuner) + # pyrefly: ignore [missing-attribute] + actual_kernel = kernel.fn if is_autotuner else kernel + + if hasattr(actual_kernel, "device_caches"): + caches = actual_kernel.device_caches + assert len(caches.keys()) == 1 + cache = next(iter(caches.values()))[0] + elif hasattr(actual_kernel, "cache"): + # old path, still used for cpu triton builds + caches = actual_kernel.cache + assert len(caches.keys()) == 1 + cache = next(iter(caches.values())) + else: + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"kernel caches not found for kernel {actual_kernel.__name__}" + ) + + if len(cache.keys()) == 1: + return actual_kernel, next(iter(cache.values())) + + has_constexprs = ( + isinstance(actual_kernel, JITFunction) + and hasattr(actual_kernel, "constexprs") + and len(actual_kernel.constexprs) > 0 + ) + + if has_constexprs: + constexpr_vals = {} + # pyrefly: ignore [missing-attribute] + for constexpr_idx in actual_kernel.constexprs: + # pyrefly: ignore [missing-attribute] + if constexpr_idx < len(actual_kernel.arg_names): + # pyrefly: ignore [missing-attribute] + param_name = actual_kernel.arg_names[constexpr_idx] + kwargs_dict = node.kwargs.get("kwargs", {}) + if isinstance(kwargs_dict, dict): + if param_name in kwargs_dict: + constexpr_vals[param_name] = kwargs_dict[param_name] + + expected_values = [ + # pyrefly: ignore [missing-attribute] + constexpr_vals[actual_kernel.arg_names[idx]] + # pyrefly: ignore [missing-attribute] + for idx in actual_kernel.constexprs + # pyrefly: ignore [missing-attribute] + if actual_kernel.arg_names[idx] in constexpr_vals + ] + + matching_entries = [] + for sig_key, cache_entry in cache.items(): + constexpr_matches = re.findall(r"\('constexpr',\s*([^)]+)\)", sig_key) + if constexpr_matches: + constexpr_values = [] + for match in constexpr_matches: + if match in ("True", "False"): + constexpr_values.append(match == "True") + elif "." in match or "e" in match or "E" in match: + constexpr_values.append(float(match)) + else: + constexpr_values.append(int(match)) + + if constexpr_values == expected_values: + matching_entries.append((sig_key, cache_entry)) + else: + matching_entries = list(cache.items()) + + if len(matching_entries) == 0: + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {actual_kernel.__name__}. " + f"Available cache keys: {list(cache.keys())}" + ) + + if len(matching_entries) == 1: + return actual_kernel, matching_entries[0][1] + + if is_autotuner: + for _sig_key, cache_entry in matching_entries: + entry_metadata = cache_entry.metadata + # pyrefly: ignore [missing-attribute] + for config in kernel.configs: + if is_metadata_matched(config, entry_metadata): + return actual_kernel, cache_entry + + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"Multiple cache entries found for autotuned kernel {actual_kernel.__name__} " + f"{'with same constexpr values' if has_constexprs else 'with no constexpr'} " + f"and couldn't disambiguate using configs. " + ) + + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"Multiple cache entries found for non-autotuned kernel {actual_kernel.__name__} " + f"{'with same constexpr values' if has_constexprs else 'with no constexpr'}. " + f"This should not happen. Available cache keys: {[key for key, _ in matching_entries]}" + ) + + +@final +class GraphModuleSerializer(metaclass=Final): + def __init__( + self, + graph_signature: ep.ExportGraphSignature, + module_call_graph: list[ep.ModuleCallEntry], + ): + self.graph_state = GraphState() + self.graph_signature = graph_signature + self.module_call_graph = module_call_graph + self.custom_objs: dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: dict[str, str] = {} + self.treespec_namedtuple_fields: dict[str, NamedTupleDef] = {} + + @contextmanager + def save_graph_state(self): + saved = self.graph_state + self.graph_state = GraphState() + try: + yield + finally: + self.graph_state = saved + + def handle_placeholder(self, node: torch.fx.Node): + assert node.op == "placeholder" + val = node.meta["val"] + log.debug("[handle_placeholder] %s: %s", node.name, val) + if isinstance(val, torch.Tensor): + graph_input = Argument.create( + as_tensor=self.serialize_tensor_output(node.name, val) + ) + elif isinstance(val, torch.SymInt): + graph_input = Argument.create( + as_sym_int=self.serialize_sym_int_output(node.name, val) + ) + elif isinstance(val, torch.SymFloat): + raise AssertionError("SymFloat graph input is not implemented yet.") + elif isinstance(val, (int, bool, str, float, type(None))): + graph_input = self.serialize_input(val) + elif isinstance(val, ep.CustomObjArgument): + class_fqn = val.class_fqn + graph_input = Argument.create( + as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) + ) + self.graph_state.custom_obj_values[node.name] = ( + self.serialize_script_obj_meta(val) + ) + else: + raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") + self.graph_state.inputs.append(graph_input) + + def handle_output(self, node: torch.fx.Node): + assert node.op == "output" + assert len(node.args) == 1, "FX.Node's args should have one arg" + node_args = node.args[0] + log.debug("[handle_output] %s: %s", node.name, node_args) + if isinstance(node_args, torch.fx.Node): + # For singleton tensor returns + self.graph_state.is_single_tensor_return = True + self.graph_state.outputs = [self.serialize_input(node_args)] + else: + assert isinstance(node_args, (tuple, list)) + self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] + + def serialize_operator(self, target) -> str: + if isinstance(target, str): + return target + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. + return f"{target.__module__}.{target.__name__}" + + def handle_call_function(self, node: torch.fx.Node): + assert node.op == "call_function" + meta_val = node.meta.get("val") + log.debug( + "[handle_call_function] %s: %s(%s, {%s}) -> %s", + node.name, + node.target, + node.args, + node.kwargs, + meta_val, + ) + + # getitem has been handled in the producer node, skip it here + if node.target is operator.getitem: + return + + if node.target in _SYM_OPS or ( + meta_val is not None + and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)) + ): + assert len(node.kwargs) == 0 + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[self.serialize_output(node.name, meta_val)], + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.OpOverload): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + # TODO: create a new tensor_values here, meta might have faketensor info + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.HigherOrderOperator): + + def _is_hop_single_tensor_return(node) -> bool: + assert isinstance(node.target, torch._ops.HigherOrderOperator) + # HOP schema is not always available, so we look at node.meta["val"] + meta_val = node.meta.get("val", None) + return meta_val is not None and isinstance(meta_val, torch.Tensor) + + # Special handle serialization for aoti_call_delegate + if node.target is torch._higher_order_ops.aoti_call_delegate: + serializable_args = list(node.args) + + # AOTI lowered module is not serializable, serialize the aoti_path instead + lowered_module_name: str = node.args[0].name # type: ignore[assignment, no-untyped-def, union-attr] + assert hasattr(node.graph.owning_module, lowered_module_name) + lowered_module = getattr(node.graph.owning_module, lowered_module_name) # type: ignore[no-untyped-def] + serializable_args[0] = lowered_module.aoti_path + + # AOTI compiled graph module in node.args[0] is stateful, and will fail the verifier check + # Skip serializing original_gm as a workaround + serializable_args[1] = None + + serializable_weight_nodes = [] + if serializable_args[2] is not None and isinstance( + serializable_args[2], Iterable + ): + for weight_node in serializable_args[2]: + # skip passing custom obj into the weight arg as an hack + # The schema of weight input is a list of Tensors. + # Downstream runtime is not actively consuming the weighs arg for anything meaningful. + if isinstance(weight_node, torch.fx.Node) and isinstance( + weight_node.meta.get("val", None), ep.CustomObjArgument + ): + continue + serializable_weight_nodes.append(weight_node) + serializable_args[2] = serializable_weight_nodes + + def serialize_tensor_list_output(node): + meta_val = node.meta.get("val", None) + tensor_args = [] + for idx, meta in enumerate(meta_val): + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(serializable_args, node.kwargs), + outputs=serialize_tensor_list_output(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=False, + ) + elif ( + node.target + is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional + ): + kernel, kernel_cache_entry = get_triton_kernel_and_cache_entry(node) + kernel_cache_metadata = kernel_cache_entry.metadata + + meta_val = node.meta["val"] + assert isinstance(meta_val, dict) + + output_keys = meta_val.keys() + output_indices = [] + + constexpr_keys = {p.name for p in kernel.params if p.is_constexpr} + found_constexpr = False + args_new = () + i = 0 + + assert isinstance(node.kwargs["kwargs"], dict) + for k, v in node.kwargs["kwargs"].items(): + # don't serialize constexpr since they will + # be embedded into the binary and don't + # need to be passed around as attributes + if k in constexpr_keys: + found_constexpr = True + continue + + assert not found_constexpr, ( + "non-constexpr args found after constexpr arg(s)" + ) + + if k in output_keys: + output_indices.append(i) + args_new += (v,) # type: ignore[assignment] + i += 1 + + assert isinstance(node.kwargs["grid"], list) + + kernel_name_with_hash = ( + f"{kernel.fn.__name__}_{kernel_cache_metadata.hash}" + ) + kwargs_new = { + "name": kernel_name_with_hash, + "grid": node.kwargs["grid"][0], + "output_indices": output_indices, + "num_warps": kernel_cache_metadata.num_warps, + } + if hasattr(kernel_cache_metadata, "num_cpu_threads"): + kwargs_new["num_cpu_threads"] = ( + kernel_cache_metadata.num_cpu_threads + ) + + if hasattr(kernel_cache_metadata, "shared"): + kwargs_new["shared_memory_bytes"] = kernel_cache_metadata.shared + + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(args_new, kwargs_new), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=_is_hop_single_tensor_return(node), + ) + else: + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(node.args, node.kwargs), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=_is_hop_single_tensor_return(node), + ) + elif type(node.target) in _serialization_registry: + # Sanity check for unhandled serialization. + assert type(node.target) in _serialization_registry, ( + f"{type(node.target)} is not supported in export serialization." + ) + + handler = _serialization_registry[type(node.target)] + namespace = handler.namespace() + op_name = handler.to_op_name(node.target) + assert isinstance(namespace, str) and isinstance(op_name, str) + assert ":" not in namespace and ":" not in op_name + ex_node = Node( + target=f"#{namespace}:{op_name}", + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + metadata=self.serialize_metadata(node), + ) + else: + raise SerializeError(f"Serializing {node.target} is not supported") + + self.graph_state.nodes.append(ex_node) + + def handle_get_attr(self, node): + log.debug("[handle_get_attr] %s", node.name) + + def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]: + user_node = None + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + if user_node is None: + user_node = user + else: + # We want to deduplicate getitem nodes that are trying to + # index to the same index + self.duplicate_getitem_nodes[user.name] = user_node.name + return user_node + + def _output_node_name_at_index(self, node, index) -> str: + user_node = self._output_node_at_index(node, index) + if user_node is None: + return f"{node.name}_unused_{index}" + else: + return user_node.name + + def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]: + ret = {} + + if stack_trace := node.meta.get("stack_trace"): + ret["stack_trace"] = stack_trace + + if nn_module_stack := node.meta.get("nn_module_stack"): + + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + + assert isinstance(path, str) + assert isinstance(ty, str) + + return path + "," + ty + + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) + + if source_fn_st := node.meta.get("source_fn_stack"): + source_fn_list = [ + f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" + for source_fn in source_fn_st + ] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) + + if torch_fn := node.meta.get("torch_fn"): + ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn)) + + if custom := node.meta.get("custom"): + try: + ret["custom"] = json.dumps(custom) + except Exception as e: + raise SerializeError( + f"Failed to serialize custom metadata for node {node.name} with error {e}" + ) from e + + return ret + + def serialize_script_obj_meta( + self, script_obj_meta: ep.CustomObjArgument + ) -> CustomObjArgument: + log.debug("[serialize_script_obj_meta] %s", script_obj_meta) + return CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]: + if isinstance(op, torch._ops.OpOverload): + args_names = [arg.name for arg in op._schema.arguments] + else: + assert op in _SYM_OPS + args_names = list(inspect.signature(op).parameters.keys()) + serialized_args = [] + for args_name, arg in zip(args_names, args): + serialized_args.append( + NamedArgument( + name=args_name, + arg=self.serialize_input(arg), + kind=ArgumentKind.POSITIONAL, + ) + ) + return serialized_args + + def serialize_inputs( + self, + target: Any, # torch._ops.OpOverload and other custom operator types. + args, + kwargs=None, + ) -> list[NamedArgument]: + schema = None + serialized_args = [] + + if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): + obj = args[0] + method = args[1] + schema = target.schema(obj, method) + else: + assert isinstance( + target, (torch._ops.OpOverload, *_registered_extension_types()) + ) + schema = _get_schema_from_target(target) + assert schema is not None + kwargs = kwargs or {} + + for i, schema_arg in enumerate(schema.arguments): + if schema_arg.name in kwargs: + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input( + kwargs[schema_arg.name], schema_arg.type + ), + kind=ArgumentKind.KEYWORD, + ) + ) + elif not schema_arg.kwarg_only and i < len(args): + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(args[i], schema_arg.type), + kind=ArgumentKind.POSITIONAL, + ) + ) + else: + # We intentionally don't serialize the missing arguments + # with default values + pass + + return serialized_args + + def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]: + """ + For serializing HOO inputs since HOOs do not have a schema. + """ + inputs = [ + NamedArgument( + name="", arg=self.serialize_input(a), kind=ArgumentKind.POSITIONAL + ) + for a in args + ] + inputs.extend( + [ + NamedArgument( + name=name, + arg=self.serialize_input(a), + kind=ArgumentKind.KEYWORD, + ) + for name, a in kwargs.items() + ] + ) + return inputs + + def is_inductor_sym_int_arg(self, arg) -> bool: + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node and should be + # verified with is_sym_int_arg() + return type(arg) is int or isinstance(arg, torch.SymInt) + + def is_sym_int_arg(self, arg) -> bool: + return type(arg) is int or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_int_values + ) + + def is_sym_float_arg(self, arg) -> bool: + return isinstance(arg, float) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_float_values + ) + + def is_sym_bool_arg(self, arg) -> bool: + return isinstance(arg, bool) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_bool_values + ) + + # should be torch._C.JitType but that annotation is busted + def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: + import torch._inductor.ir as inductor_ir + + inductor_tensor_buffers = ( + inductor_ir.Buffer, + inductor_ir.ReinterpretView, + ) + + if isinstance(arg, torch.fx.Node): + if arg.op == "get_attr": + assert isinstance(arg.target, str) + attr = getattr(arg.graph.owning_module, arg.target) + + if isinstance(attr, torch.Tensor): + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + elif isinstance(attr, torch.fx.GraphModule): + with self.save_graph_state(): + graph = self.serialize_graph(attr) + return Argument.create( + as_graph=GraphArgument(name=arg.target, graph=graph) + ) + elif type(attr).__name__ == "LoweredBackendModule": + # Special handling for executorch_call_delegate HOP + # It's first argument is a LoweredBackendModule, for which we + # serialize name and backend id of the lowered module + module_name = getattr(attr, "module_name", None) + backend_id = getattr(attr, "backend_id", None) + assert module_name is not None, "module_name should not be None" + assert backend_id is not None, "backend_id should not be None" + return Argument.create(as_string=f"{module_name}-{backend_id}") + else: + raise SerializeError( + f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" + ) + elif self.is_sym_int_arg(arg): + return Argument.create( + as_sym_int=SymIntArgument.create(as_name=arg.name) + ) + elif self.is_sym_float_arg(arg): + return Argument.create( + as_sym_float=SymFloatArgument.create(as_name=arg.name) + ) + elif self.is_sym_bool_arg(arg): + return Argument.create( + as_sym_bool=SymBoolArgument.create(as_name=arg.name) + ) + elif isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn + ) + ) + elif arg.name in self.duplicate_getitem_nodes: + dedup_name = self.duplicate_getitem_nodes[arg.name] + return Argument.create(as_tensor=TensorArgument(name=dedup_name)) + else: + return Argument.create(as_tensor=TensorArgument(name=arg.name)) + elif isinstance(arg, inductor_tensor_buffers): + # Other branches are for arguments in fx node. + # This is a special branch for handling buffers (representing tensor arguments) + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + return Argument.create(as_tensor=TensorArgument(name=arg_name)) + elif isinstance(arg, inductor_ir.TorchBindObject): + # This is a special branch for handling TorchBindObject + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + arg_val = arg.get_real_obj() + class_fqn = arg_val._type().qualified_name() + self.custom_objs[arg_name] = arg_val + return Argument.create(as_custom_obj=CustomObjArgument(arg_name, class_fqn)) + elif isinstance(arg, torch.SymInt): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) + elif isinstance(arg, torch.SymFloat): + # This is a special branch for handling SymFloat args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_float_arg(arg) being true + return Argument.create( + as_sym_float=SymFloatArgument.create(as_name=str(arg)) + ) + elif type(arg) is bool: + return Argument.create(as_bool=arg) + elif type(arg) is str: + return Argument.create(as_string=arg) + elif type(arg) is int: + return Argument.create(as_int=arg) + elif type(arg) is float: + return Argument.create(as_float=arg) + elif type(arg) is complex: + return Argument.create( + as_complex=ComplexValue(real=arg.real, imag=arg.imag) + ) + elif arg is None: + return Argument.create(as_none=True) + elif isinstance(arg, dict): + serialized_dict = {} + for key, value in arg.items(): + if not isinstance(key, str): + raise SerializeError(f"Dict keys must be strings, got {type(key)}") + serialized_dict[key] = self.serialize_input(value) + return Argument.create(as_string_to_argument=serialized_dict) + elif isinstance(arg, (list, tuple)): + if len(arg) == 0: + if arg_type is not None: + if isinstance(arg_type, torch.OptionalType): + arg_type = arg_type.getElementType() # type: ignore[assignment] + assert isinstance(arg_type, torch.ListType) + elem_type = arg_type.getElementType() + if isinstance(elem_type, torch.OptionalType): + elem_type = elem_type.getElementType() + + if isinstance(elem_type, torch.BoolType): + return Argument.create(as_bools=[]) + elif isinstance(elem_type, torch.IntType): + return Argument.create(as_ints=[]) + elif isinstance(elem_type, torch.FloatType): + return Argument.create(as_floats=[]) + elif isinstance(elem_type, torch.StringType): + return Argument.create(as_strings=[]) + elif isinstance(elem_type, torch.TensorType): + return Argument.create(as_tensors=[]) + else: + # I believe empty symint lists default to ints, but + # please file an issue if this is not the case + raise SerializeError(f"Empty list with type {elem_type} nyi.") + else: + # We could serialize this by default to a tensor list. This + # is needed in the HOO case + log.warning( + "Unsure how to serialize the given empty list, " + "as we don't know what is the type of this argument. " + "Serializing it as a tensor list by default." + ) + return Argument.create(as_tensors=[]) + + if all(type(a) is bool for a in arg): + return Argument.create(as_bools=list(arg)) + elif all(type(a) is int for a in arg): + return Argument.create(as_ints=list(arg)) + elif all(type(a) is float for a in arg): + return Argument.create(as_floats=list(arg)) + elif all(type(a) is str for a in arg): + return Argument.create(as_strings=list(arg)) + elif all(self.is_inductor_sym_int_arg(a) for a in arg): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node + values = [] + for a in arg: + if isinstance(a, torch.SymInt): + values.append(SymIntArgument.create(as_name=str(a))) + elif type(a) is int: + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(isinstance(a, torch.SymFloat) for a in arg): + return Argument.create( + as_sym_floats=[SymFloatArgument.create(as_name=str(a)) for a in arg] + ) + elif all(self.is_sym_int_arg(a) for a in arg): + # list of sym_ints + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymIntArgument.create(as_name=a.name)) + elif type(a) is int: + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(self.is_sym_float_arg(a) for a in arg): + # list of sym_float + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymFloatArgument.create(as_name=a.name)) + elif isinstance(a, float): + values.append(SymFloatArgument.create(as_float=a)) + return Argument.create(as_sym_floats=values) + elif all(self.is_sym_bool_arg(a) for a in arg): + # list of sym_bools + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymBoolArgument.create(as_name=a.name)) + elif isinstance(a, bool): + values.append(SymBoolArgument.create(as_bool=a)) + return Argument.create(as_sym_bools=values) + elif all(isinstance(a, torch.fx.Node) for a in arg): + # list of tensors + arguments = [] + for a in arg: + if a.op == "get_attr": + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + arguments.append(TensorArgument(name=a.name)) + return Argument.create(as_tensors=arguments) + elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): + # list of optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=True) + elif isinstance(a, torch.fx.Node): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.name) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all(isinstance(a, inductor_tensor_buffers) for a in arg): + # list of inductor buffers + return Argument.create( + as_tensors=[TensorArgument(name=a.get_name()) for a in arg], + ) + elif all( + isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg + ): + # list of inductor buffers as optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=True) + elif isinstance(a, inductor_tensor_buffers): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.get_name()) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all( + isinstance(a, tuple) and all(type(x) is int for x in a) for a in arg + ): + # list of int tuples + return Argument.create(as_int_lists=[list(t) for t in arg]) + else: + raise SerializeError( + f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" + ) + elif isinstance(arg, torch.dtype): + return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) + elif isinstance(arg, torch.device): + return Argument.create(as_device=Device(type=arg.type, index=arg.index)) + elif isinstance(arg, torch.memory_format): + return Argument.create( + as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg] + ) + elif isinstance(arg, torch.layout): + return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) + elif isinstance(arg, torch._C.ScriptObject): + if not ( + arg._has_method("__getstate__") # type: ignore[attr-defined] + and arg._has_method("__setstate__") # type: ignore[attr-defined] + ): + raise SerializeError( + f"Unable to serialize custom class {arg}. Please define " + "serialization methods via def_pickle()." + ) + # Custom objects through torchind are serializable with pickle, + # through implementing the .def_pickle function. This should result + # in the object containing a __getstate__ and __setstate__ + # serialize/deserialize function. + custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create( + as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) + ) + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return Argument.create(as_operator=self.serialize_operator(arg)) + else: + raise SerializeError( + f"Unsupported argument type: {type(arg)} with schema arg_type {arg_type}" + ) + + def serialize_tensor_output(self, name, meta_val) -> TensorArgument: + assert name not in self.graph_state.tensor_values + self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) + return TensorArgument(name=name) + + def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_int_values + self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) + return SymIntArgument.create(as_name=name) + + def serialize_sym_float_output(self, name, meta_val) -> SymFloatArgument: + assert name not in self.graph_state.sym_float_values + self.graph_state.sym_float_values[name] = serialize_sym_float(meta_val) + return SymFloatArgument.create(as_name=name) + + def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_bool_values + self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) + return SymBoolArgument.create(as_name=name) + + def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + log.debug("[serialize_input_spec] %s", spec) + if spec.kind == ep.InputKind.USER_INPUT: + if isinstance(spec.arg, ep.ConstantArgument): + if type(spec.arg.value) is int: + constant_spec = ConstantValue.create(as_int=spec.arg.value) + elif type(spec.arg.value) is bool: + constant_spec = ConstantValue.create(as_bool=spec.arg.value) + elif type(spec.arg.value) is str: + constant_spec = ConstantValue.create(as_string=spec.arg.value) + elif type(spec.arg.value) is float: + constant_spec = ConstantValue.create(as_float=spec.arg.value) + elif spec.arg.value is None: + constant_spec = ConstantValue.create(as_none=True) + else: + raise SerializeError( + f"Unhandled constant input {spec.arg.value} to serialize" + ) + return InputSpec.create( + constant_input=InputToConstantInputSpec( + name=spec.arg.name, value=constant_spec + ) + ) + else: + return InputSpec.create( + user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + assert spec.persistent is not None + return InputSpec.create( + buffer=InputToBufferSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + persistent=spec.persistent, + ) + ) + elif spec.kind == ep.InputKind.CONSTANT_TENSOR: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + tensor_constant=InputToTensorConstantSpec( + arg=TensorArgument(name=spec.arg.name), + tensor_constant_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.CUSTOM_OBJ: + assert spec.target is not None + assert isinstance(spec.arg, ep.CustomObjArgument) + return InputSpec.create( + custom_obj=InputToCustomObjSpec( + arg=CustomObjArgument( + name=spec.arg.name, class_fqn=spec.arg.class_fqn + ), + custom_obj_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return InputSpec.create( + token=InputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + log.debug("[serialize_output_spec] %s", spec) + if spec.kind == ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name)) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.PARAMETER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + parameter_mutation=ParameterMutationSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + user_input_mutation=UserInputMutationSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return OutputSpec.create( + token=OutputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + log.debug("\n[serialize_signature]") + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], + ) + + def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: + if isinstance(x, ep.TensorArgument): + return Argument.create(as_tensor=TensorArgument(name=x.name)) + elif isinstance(x, ep.SymIntArgument): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) + elif isinstance(x, ep.SymFloatArgument): + return Argument.create(as_sym_float=SymFloatArgument.create(as_name=x.name)) + elif isinstance(x, ep.ConstantArgument): + return self.serialize_input(x.value) + elif isinstance(x, ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn) + ) + else: + raise AssertionError("TODO") + + def serialize_treespec(self, treespec: pytree.TreeSpec) -> str: + # We want to additionally save all the field names of the namedtuples in + # case users want to check that the treespec types are equivalent + def store_namedtuple_fields(ts: pytree.TreeSpec) -> None: + if ts.type is None: + return + if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): + serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ + ts.context + ].serialized_type_name + if serialized_type_name in self.treespec_namedtuple_fields: + field_names = self.treespec_namedtuple_fields[ + serialized_type_name + ].field_names + if field_names != ts.context._fields: + raise SerializeError( + f"The given TreeSpec's namedtuple type {ts.context} " + f"was found to have field names {ts.context._fields} " + f"but somehow previously was found to have field names {field_names}." + ) + else: + self.treespec_namedtuple_fields[serialized_type_name] = ( + NamedTupleDef(field_names=ts.context._fields) + ) + + for child in ts.children(): + store_namedtuple_fields(child) + + serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION) + store_namedtuple_fields(treespec) + return serialized_treespec + + def serialize_module_call_signature( + self, module_call_signature: ep.ModuleCallSignature + ) -> ModuleCallSignature: + log.debug("[serialize_module_call_signature] %s", module_call_signature) + return ModuleCallSignature( + inputs=[ + self.serialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.serialize_argument_spec(x) for x in module_call_signature.outputs + ], + in_spec=self.serialize_treespec(module_call_signature.in_spec), + out_spec=self.serialize_treespec(module_call_signature.out_spec), + forward_arg_names=names + if (names := module_call_signature.forward_arg_names) + else None, + ) + + def serialize_module_call_graph( + self, module_call_graph: list[ep.ModuleCallEntry] + ) -> list[ModuleCallEntry]: + log.debug("\n[serialize_module_call_graph]") + return [ + ModuleCallEntry( + fqn=entry.fqn, + signature=( + self.serialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph + ] + + def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]: + """For a given node, return the dataclass representing its output values. + + [NOTE: Multiple outputs] We handle aggregates differently than FX. For + FX, it looks like: + + x = call_function("multiple_return", ...) + element0 = call_function(getitem, x, 0) + foo = call_function("use_output", element0) + + We do not want the intermediate `getitem` call, so our serialized thing looks like: + + element0, element1, element2 = call_function("multiple_return", ...) + foo = call_function("use_output", element0) + + We want names to be consistent across these two schemes, so that we can + mostly reuse the names coming from FX. This function computes a mapping from + the FX representation to our representation, preserving the names. + """ + + def _is_single_tensor_list_return(target: Any) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + + if len(returns) != 1: + return False + return_type = returns[0].real_type + return isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ) + + assert node.op == "call_function" and isinstance( + node.target, (torch._ops.OpOverload, *_registered_extension_types()) + ) + + schema = _get_schema_from_target(node.target) + returns = schema.returns + + if len(returns) == 0: + return [] + + meta_val = node.meta["val"] + + # Check single value return + if _is_single_tensor_list_return(node.target): + # e.g "-> Tensor[]" + tensor_args = [] + for idx, meta in enumerate(meta_val): + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + elif len(returns) == 1: + return [self.serialize_output(node.name, meta_val)] + + # There are a two possibilities at this point: + # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)" + # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])" + # + # Either way, start by gathering a list of TensorArguments with the correct names. + # For consistent naming with FX, consult the downstream `getitem` node and + # make sure our outputs have the same name. + + output_arguments = [] + for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): + if meta is None: + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) + # When the return type is annotated as Tensor type, the op can also return an + # undefined Tensor which will be implicitly converted to None in Python. + output_arguments.append(Argument.create(as_none=True)) + elif isinstance(meta, FakeTensor): + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) + name = self._output_node_name_at_index(node, idx) + output_arguments.append(self.serialize_output(name, meta)) + elif isinstance(meta, list): + # for List[Tensor] return type + assert isinstance( + return_schema.real_type, torch.ListType + ) and isinstance( + return_schema.real_type.getElementType(), torch.TensorType + ) + user_node = self._output_node_at_index(node, idx) + assert user_node is not None + + args = [] + for i, m in enumerate(meta): + if m is None: + continue + sub_user_node_name = self._output_node_name_at_index(user_node, i) + args.append(self.serialize_tensor_output(sub_user_node_name, m)) + output_arguments.append(Argument.create(as_tensors=args)) + elif isinstance(meta, (int, SymInt, float, SymFloat)): + user_node_name = self._output_node_name_at_index(node, idx) + output_arguments.append(self.serialize_output(user_node_name, meta)) + else: + raise ValueError( + f"Unhandled output type {type(meta)} from node {node.format_node()}" + ) + + return output_arguments + + def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]: + """ + For serializing HOO outputs since HOOs do not have a schema. + """ + meta_val = node.meta["val"] + + if isinstance(meta_val, tuple): + outputs = [] + for i, element_meta_val in enumerate(meta_val): + user_node = self._output_node_at_index(node, i) + if isinstance(element_meta_val, list): + # e.g "-> Tensor[]" + assert user_node is not None + + tensors = [] + for j, m in enumerate(element_meta_val): + if not isinstance(m, torch.Tensor): + raise SerializeError( + f"Serialize list output with type {type(m)} nyi" + ) + + name = self._output_node_name_at_index(user_node, j) + tensors.append(self.serialize_tensor_output(name, m)) + outputs.append(Argument.create(as_tensors=tensors)) + + else: + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{i}" + ) + + outputs.append(self.serialize_output(name, element_meta_val)) + + return outputs + elif isinstance(meta_val, dict): + tensor_args = [] + # use the dict key as the idx + for idx, meta in meta_val.items(): + if not isinstance(meta, torch.Tensor): + raise SerializeError( + f"Serialize list output with type {type(meta)} nyi" + ) + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + else: + return [self.serialize_output(node.name, meta_val)] + + def serialize_output(self, name: str, meta_val: Any) -> Argument: + # Check single value return + if meta_val is None: + return Argument.create(as_none=True) + if isinstance(meta_val, torch.Tensor): + # e.g "-> Tensor" + return Argument.create( + as_tensor=self.serialize_tensor_output(name, meta_val) + ) + elif isinstance(meta_val, (bool, torch.SymBool)): + # e.g "-> SymBool" + return Argument.create( + as_sym_bool=self.serialize_sym_bool_output(name, meta_val) + ) + elif isinstance(meta_val, (int, torch.SymInt)): + # e.g "-> SymInt" + assert not isinstance(meta_val, bool) + return Argument.create( + as_sym_int=self.serialize_sym_int_output(name, meta_val) + ) + elif isinstance(meta_val, (float, torch.SymFloat)): + # e.g "-> SymFloat" + return Argument.create( + as_sym_float=self.serialize_sym_float_output(name, meta_val) + ) + + # list outputs should've been handled earlier + raise SerializeError(f"Unable to serialize output {meta_val}") + + def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]: + meta_val = node.meta["val"] + + idx_to_name = {} + for user in node.users: + assert user.target is operator.getitem, ( + f"User node {user} of {node} is incorrect" + ) + idx_to_name[user.args[1]] = user.name + + for idx, _ in enumerate(meta_val): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. + if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + arg_list = [] + for i, element_meta_val in enumerate(meta_val): + arg_list.append( + self.serialize_tensor_output(idx_to_name[i], element_meta_val) + ) + + return arg_list + + def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: + assert isinstance(graph_module, torch.fx.GraphModule) + log.debug( + "[serialize_graph]\n\n%s", graph_module.print_readable(print_output=False) + ) + + for node in graph_module.graph.nodes: + try: + getattr(self, f"handle_{node.op}")(node) + except Exception as e: + raise SerializeError( + f"Failed serializing node {node} in graph: {node.format_node()}\n Original exception {traceback.format_exc()}" + ) from e + + return Graph( + inputs=self.graph_state.inputs, + nodes=self.graph_state.nodes, + tensor_values=self.graph_state.tensor_values, + sym_int_values=self.graph_state.sym_int_values, + sym_float_values=self.graph_state.sym_float_values, + sym_bool_values=self.graph_state.sym_bool_values, + custom_obj_values=self.graph_state.custom_obj_values, + outputs=self.graph_state.outputs, + is_single_tensor_return=self.graph_state.is_single_tensor_return, + ) + + def serialize_graph_module_metadata(self, meta: dict[str, Any]): + ret = {} + if custom := meta.get("custom"): + log.debug("\n[serialize_graph_module_metadata] %s", custom) + try: + ret["custom"] = json.dumps(custom) + except Exception as e: + raise SerializeError( + f"Failed to serialize custom metadata for graph with error {e}" + ) from e + + return ret + + def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule: + log.debug("\n[serialize]") + graph = self.serialize_graph(graph_module) + + return GraphModule( + graph=graph, + signature=self.serialize_signature(self.graph_signature), + module_call_graph=self.serialize_module_call_graph(self.module_call_graph), + metadata=self.serialize_graph_module_metadata(graph_module.meta), + treespec_namedtuple_fields=self.treespec_namedtuple_fields, + ) + + +@final +class ExportedProgramSerializer(metaclass=Final): + def __init__( + self, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, + ): + self.opset_version: dict[str, int] = {} + if opset_version: + self.opset_version.update(opset_version) + if "aten" not in self.opset_version: + self.opset_version["aten"] = torch._C._get_max_operator_version() + + self.pickle_protocol = pickle_protocol + + def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: + """ + Args: + exported_program: Exported Program to serialize + """ + exported_program.validate() + + gm_serializer = GraphModuleSerializer( + exported_program.graph_signature, exported_program.module_call_graph + ) + serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) + serialized_range_constraints = serialize_range_constraints( + exported_program.range_constraints + ) + + # TODO: Directly serialize exported_program.constants once + # CustomClassHolders get stored in the ExportedProgram rather than in + # the graph + constants: dict[str, Any] = gm_serializer.custom_objs.copy() + for n, t in exported_program.constants.items(): + assert n not in constants + constants[n] = t + + serialized_ep = ExportedProgram( + graph_module=serialized_graph_module, + opset_version=self.opset_version, + range_constraints=serialized_range_constraints, + schema_version=SchemaVersion( + major=SCHEMA_VERSION[0], + minor=SCHEMA_VERSION[1], + ), + verifiers=[v.dialect for v in exported_program.verifiers], + torch_version=torch.__version__, + guards_code=exported_program._guards_code, + ) + + # Test canonical form is well defined. + canonicalize(serialized_ep, set(constants.keys())) + + # Proxy cannot be dumped, so we remove them. + new_state_dict = remove_proxy_from_state_dict( + exported_program.state_dict, in_place=False + ) + return _SerializedProgram( + serialized_ep, + serialize_torch_artifact(new_state_dict, self.pickle_protocol), + serialize_torch_artifact(constants, self.pickle_protocol), + serialize_torch_artifact( + exported_program.example_inputs, self.pickle_protocol + ), + ) + + +@final +class GraphModuleDeserializer(metaclass=Final): + @dataclasses.dataclass + class Result: + graph_module: torch.fx.GraphModule + signature: ep.ExportGraphSignature + module_call_graph: list[ep.ModuleCallEntry] + names_to_symbols: dict[str, sympy.Symbol] + state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]] + constants: dict[str, _ConstantAttributeType] + example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]] + + def __init__(self) -> None: + self.serialized_name_to_node: dict[str, torch.fx.Node] = {} + self.serialized_name_to_meta: LazyMap = LazyMap() # str -> MetaType + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + + @contextmanager + def save_graph_module(self) -> Iterator[None]: + saved = ( + self.graph, + self.module, + self.serialized_name_to_node, + self.serialized_name_to_meta, + self.unbacked_symbols, + ) + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + self.serialized_name_to_node = {} + self.serialized_name_to_meta = LazyMap() + self.unbacked_symbols: set[sympy.Symbol] = set() + try: + yield + finally: + ( + self.graph, + self.module, + self.serialized_name_to_node, + self.serialized_name_to_meta, + self.unbacked_symbols, + ) = saved + + def deserialize_extension_operator(self, serialized_target: str): + namespace, op_name = serialized_target.split(":") + namespace = namespace[1:] # starting with # + handler = _deserialization_registry[namespace] + return handler.from_op_name(op_name) + + def deserialize_operator(self, serialized_target: str): + if serialized_target.startswith( + "_operator" + ): # TODO(zhxchen17) Follow up on this. + module = operator + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("torch"): + module = torch # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("math"): + module = math # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("#"): + return self.deserialize_extension_operator(serialized_target) + else: # TODO(zhxchen17) Don't catch all here. + return serialized_target + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + def _parse_sym_expr( + self, expr_str: str, hint: Optional[Union[int, bool, float]] = None + ) -> sympy.Expr: + """ + Parses and does bottom-up processing of sympy.Expr nodes, + populating ShapeEnv & caching symbols as needed. + """ + + def _process_sym_expr( + sym: sympy.Expr, hint: Optional[Union[int, bool, float]] = None + ) -> sympy.Expr: + if sym.is_Integer or sym.is_Float or sym.is_Boolean: # base case + return sym + else: # recursive case + # important to use str(expr) and not _print_sympy(), + # str(expr) is key for self.symbol_name_to_range + expr_str = str(sym) + for arg in sym.args: + self._parse_sym_expr(arg) + # symbol caching + if expr_str in self.symbol_name_to_symbol: + sym = self.symbol_name_to_symbol[expr_str] + else: + self.symbol_name_to_symbol[expr_str] = sym + if isinstance(sym, sympy.Symbol) and symbolic_shapes.symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT) + ): + self.unbacked_symbols.add(sym) + # hints + if hint is not None and sym not in self.shape_env.var_to_val: + self.shape_env.add_var_to_val(sym, hint) # type: ignore[arg-type] + # ValueRanges + if vr := self.symbol_name_to_range.get(expr_str): + self.shape_env.constrain_symbol_range( + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + # ShapeEnv meta + if isinstance(sym, sympy.Symbol): + self.shape_env.var_to_stack[sym] = CapturedTraceback.extract(skip=1) + return sym + + expr = sympy.sympify( + expr_str, + locals={**self.sympy_functions, **self.symbol_name_to_symbol}, + ) + return _process_sym_expr(expr, hint) + + def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: + val = s.value + if s.type == "as_expr": + if val.hint is None: + hint = None + else: + assert val.hint.type == "as_int" + hint = val.hint.value + + sym = self._parse_sym_expr(val.expr_str, hint) + return self.shape_env.create_symintnode(sym, hint=hint) + elif s.type == "as_int": + assert type(val) is int + return val + else: + raise SerializeError( + f"SymInt has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_float(self, s: SymFloat) -> Union[float, torch.SymFloat]: + val = s.value + if s.type == "as_expr": + hint = val.hint.as_float if val.hint else None + sym = self._parse_sym_expr(val.expr_str, hint) + return self.shape_env.create_symfloatnode(sym, hint=hint) + elif s.type == "as_float": + assert isinstance(val, float) + return val + else: + raise SerializeError( + f"SymFloat has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]: + val = s.value + if s.type == "as_expr": + expr = self._parse_sym_expr(val.expr_str) + return self.shape_env.create_symboolnode(expr) + elif s.type == "as_bool": + assert isinstance(val, bool) + return val + else: + raise SerializeError( + f"SymBool has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_tensor_meta( + self, + tensor_meta: TensorMeta, + ) -> FakeTensor: + with self.fake_tensor_mode: + return cast( + FakeTensor, + torch.empty_strided( + tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc] + tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc] + device=deserialize_device(tensor_meta.device), + dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], + requires_grad=tensor_meta.requires_grad, + ), + ) + + def deserialize_script_obj_meta( + self, script_obj_meta: CustomObjArgument + ) -> ep.CustomObjArgument: + return ep.CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]: + if output.type == "as_tensor": + return self.serialized_name_to_node[output.as_tensor.name] + elif output.type == "as_sym_int": + return self.serialized_name_to_node[output.as_sym_int.as_name] + elif output.type == "as_sym_bool": + return self.serialized_name_to_node[output.as_sym_bool.as_name] + elif output.type == "as_sym_float": + return self.serialized_name_to_node[output.as_sym_float.as_name] + elif output.type == "as_int": + return output.as_int + elif output.type == "as_float": + return output.as_float + elif output.type == "as_bool": + return output.as_bool + elif output.type == "as_none": + return None + else: + raise SerializeError(f"Unable to deserialize output node {output}") + + def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: + log.debug("\n[deserialize_graph]") + + # Handle the tensor metas. + for name, tensor_value in serialized_graph.tensor_values.items(): + log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value) + self.serialized_name_to_meta[name] = ( + lambda v=tensor_value: self.deserialize_tensor_meta(v) + ) + + for name, sym_int_value in serialized_graph.sym_int_values.items(): + log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value) + self.serialized_name_to_meta[name] = ( + lambda v=sym_int_value: self.deserialize_sym_int(v) + ) + + for name, sym_float_value in serialized_graph.sym_float_values.items(): + log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value) + self.serialized_name_to_meta[name] = ( + lambda v=sym_float_value: self.deserialize_sym_float(v) + ) + + for name, sym_bool_value in serialized_graph.sym_bool_values.items(): + log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value) + self.serialized_name_to_meta[name] = ( + lambda v=sym_bool_value: self.deserialize_sym_bool(v) + ) + + for name, script_obj_meta in serialized_graph.custom_obj_values.items(): + log.debug("[deserialize_script_obj_meta] %s", script_obj_meta) + self.serialized_name_to_meta[name] = ( + lambda v=script_obj_meta: self.deserialize_script_obj_meta(v) + ) + + log.debug("\n[deserialize graph nodes]") + # Inputs: convert to placeholder nodes in FX. + for i, input_ in enumerate(serialized_graph.inputs): + log.debug("[deserialize input] %s", input_) + if input_.type in ("as_tensor", "as_custom_obj"): + node_name = input_.value.name + placeholder_node = self.graph.placeholder(node_name) + # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments) + # we will overwrite it + placeholder_node.name = node_name + self.sync_fx_node(node_name, placeholder_node) + elif input_.type == "as_sym_int": + if input_.value.type == "as_name": + node_name = input_.value.as_name + placeholder_node = self.graph.placeholder(node_name) + # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments) + # we will overwrite it + placeholder_node.name = node_name + self.sync_fx_node(node_name, placeholder_node) + else: + raise SerializeError( + f"Deserializing a constant symint {input_.value} as an input" + ) + elif input_.type in ( + "as_int", + "as_float", + "as_bool", + "as_none", + "as_string", + ): + node_name = self.signature.input_specs[i].arg.name or f"arg{i}" + placeholder_node = self.graph.placeholder(node_name) + placeholder_node.meta["val"] = self.deserialize_input(input_) + else: + raise SerializeError(f"Invalid input type {input_}") + + # Nodes: convert to call_function nodes. + for serialized_node in serialized_graph.nodes: + try: + target = self.deserialize_operator(serialized_node.target) + self.deserialize_node(serialized_node, target) + + except Exception as e: + raise SerializeError( + f"Failed deserializing node {serialized_node}\n Original exception {traceback.format_exc()}" + ) from e + + # Outputs: convert to a single `output` node. + outputs = [] + for output in serialized_graph.outputs: + log.debug("[deserialize output] %s", output) + outputs.append(self.deserialize_graph_output(output)) + + if serialized_graph.is_single_tensor_return: + assert len(outputs) == 1 + outputs = outputs[0] # type: ignore[assignment] + else: + outputs = tuple(outputs) # type: ignore[assignment] + + output_node = self.graph.output(outputs) + + if serialized_graph.is_single_tensor_return: + output_node.meta["val"] = output_node.args[0].meta["val"] + else: + output_node.meta["val"] = tuple( + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ) + + # recompute unbacked bindings + for node in self.graph.nodes: + if (val := node.meta.get("val")) is not None and ( + unbacked_bindings := symbolic_shapes._free_unbacked_symbols_with_path( + val, + (), + shape_env=self.shape_env, + pending=self.unbacked_symbols, + simplify=True, + ) + ): + node.meta["unbacked_bindings"] = unbacked_bindings + + assert len(self.unbacked_symbols) == 0 + return self.graph + + def deserialize_node(self, serialized_node: Node, target: Callable) -> None: + def _is_single_tensor_return(target) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + return len(returns) == 1 and isinstance( + returns[0].real_type, torch.TensorType + ) + + if ( + target in _SYM_OPS + or target + == torch.ops.aten.item.default # this can produce either SymInt or SymBool + ): + name = serialized_node.outputs[0].value.as_name + args = self.deserialize_sym_op_inputs(serialized_node.inputs) + + fx_node = self.graph.create_node("call_function", target, args, {}, name) + self.deserialize_sym_op_outputs(serialized_node, fx_node) + elif ( + target + is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional + ): + raise SerializeError( + "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional" + ) + elif isinstance(target, torch._ops.HigherOrderOperator): + args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) + metadata = self.deserialize_metadata(serialized_node.metadata) + for x in (*args, *kwargs.values()): + if isinstance(x, torch.fx.Node) and x.op == "get_attr": + # this means that we have deserialized a graph argument, but + # unfortunately the schema for it does not include metadata; + # so we reuse the metadata of the HOP call for such arguments + x.meta.update(metadata) + # If a serialized HOP node has a length=1 outputs of type `as_tensor``. + # There could be two cases: + # (1) The HOP node returns a single tensor + # (2) The HOP node returns a tuple containing a single tensor + # We distinguish (1) and (2) by the `is_single_tensor_return` + # field in the schema of Node + # For BC, getattr() will return True if `is_single_tensor_return` doesn't + # exist. This is because prior to adding `is_single_tensor_return`, + # only (1) could happen as we handle (2) with type `as_tensors` + name = ( + serialized_node.outputs[0].as_tensor.name + if len(serialized_node.outputs) == 1 + and hasattr(serialized_node.outputs[0], "as_tensor") + and getattr(serialized_node, "is_hop_single_tensor_return", True) + else None + ) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + fx_node.meta.update(metadata) + + elif isinstance( + target, (torch._ops.OpOverload, *_registered_extension_types()) + ): + # For convenience: if this node returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + name = ( + serialized_node.outputs[0].as_tensor.name + if _is_single_tensor_return(target) + else None # FX will generate a name for us. + ) + args, kwargs = self.deserialize_inputs(target, serialized_node) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + else: + _additional_msg = ( + ( + f"We failed to resolve {target} to an operator. " + + "If it's a custom op/custom triton op, this is usually because the custom op is not registered" + + " when deserializing. Please import the custom op to register it before deserializing." + + " Otherwise, please file an issue on github." + ) + if isinstance(target, str) + else "" + ) + raise SerializeError( + _additional_msg + + f" Unsupported target type for node {serialized_node}: {type(target)}." + ) + + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + log.debug( + "[deserialize_node] %s: %s(%s, {%s}) -> %s", + fx_node.name, + fx_node.target, + fx_node.args, + fx_node.kwargs, + fx_node.meta.get("val"), + ) + + # handle ShapeEnv asserts + if target is torch.ops.aten._assert_scalar.default: + if not isinstance((arg := fx_node.args[0]), bool): + expr = arg.meta["val"] # type: ignore[union-attr] + if isinstance(expr, torch.SymBool): + self.shape_env.guard_or_defer_runtime_assert( + expr.node.expr, "", fx_node + ) + elif target is torch.ops.aten.sym_constrain_range_for_size.default: + sym = fx_node.args[0].meta["val"] # type: ignore[union-attr] + if isinstance(sym, torch.SymInt): + self.shape_env._constrain_range_for_size(sym.node.expr) + + # handle nn_module_stack; serialization throws away empty dicts + if ( + fx_node.op not in ["placeholder", "output"] + and "nn_module_stack" not in fx_node.meta + ): + fx_node.meta["nn_module_stack"] = {} + + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + log.debug("[deserialize_input_spec] %s", i) + if i.type == "user_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None, + ) + elif i.type == "parameter": + return ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=ep.TensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.type == "buffer": + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=ep.TensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + persistent=i.buffer.persistent, + ) + elif i.type == "tensor_constant": + return ep.InputSpec( + kind=ep.InputKind.CONSTANT_TENSOR, + arg=ep.TensorArgument(name=i.tensor_constant.arg.name), + target=i.tensor_constant.tensor_constant_name, + ) + elif i.type == "custom_obj": + return ep.InputSpec( + kind=ep.InputKind.CUSTOM_OBJ, + arg=ep.CustomObjArgument( + name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn + ), + target=i.custom_obj.custom_obj_name, + ) + elif i.type == "token": + return ep.InputSpec( + kind=ep.InputKind.TOKEN, + arg=ep.TokenArgument(name=i.token.arg.name), + target=None, + ) + elif i.type == "constant_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=ep.ConstantArgument( + name=i.constant_input.name, + value=self.deserialize_constant_input(i.constant_input.value), + ), + target=None, + ) + else: + raise AssertionError(f"Unknown input spec {i}") + + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + log.debug("[deserialize_output_spec] %s", o) + if o.type == "user_output": + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.type == "loss_output": + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + arg=ep.TensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.type == "buffer_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name, + ) + elif o.type == "parameter_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.PARAMETER_MUTATION, + arg=ep.TensorArgument(name=o.parameter_mutation.arg.name), + target=o.parameter_mutation.parameter_name, + ) + elif o.type == "gradient_to_parameter": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name, + ) + elif o.type == "gradient_to_user_input": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name, + ) + elif o.type == "user_input_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.USER_INPUT_MUTATION, + arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), + target=o.user_input_mutation.user_input_name, + ) + elif o.type == "token": + return ep.OutputSpec( + kind=ep.OutputKind.TOKEN, + arg=ep.TokenArgument(name=o.token.arg.name), + target=None, + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: + log.debug("\n[deserialize_signature]") + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs], + ) + + def deserialize( + self, + serialized_graph_module: GraphModule, + serialized_state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, Any], bytes], + example_inputs: Optional[ + Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes] + ] = None, + symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Result: + global _CURRENT_DESERIALIZER + assert _CURRENT_DESERIALIZER is None + _CURRENT_DESERIALIZER = self + try: + log.debug("\n[deserialize]") + self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.sympy_functions = { + # all torch.utils._sympy.functions should go here + # TODO(avik): find a better way to keep this collection in sync; + # e.g.., `exec('from torch.utils._sympy.functions import *', ...)` + # would work as long as the public API of that module is complete + "FloorDiv": torch.utils._sympy.functions.FloorDiv, + "ModularIndexing": torch.utils._sympy.functions.ModularIndexing, + "Where": torch.utils._sympy.functions.Where, + "PythonMod": torch.utils._sympy.functions.PythonMod, + "Mod": torch.utils._sympy.functions.Mod, + "CleanDiv": torch.utils._sympy.functions.CleanDiv, + "CeilToInt": torch.utils._sympy.functions.CeilToInt, + "FloorToInt": torch.utils._sympy.functions.FloorToInt, + "CeilDiv": torch.utils._sympy.functions.CeilDiv, + "LShift": torch.utils._sympy.functions.LShift, + "RShift": torch.utils._sympy.functions.RShift, + "PowByNatural": torch.utils._sympy.functions.PowByNatural, + "FloatPow": torch.utils._sympy.functions.FloatPow, + "FloatTrueDiv": torch.utils._sympy.functions.FloatTrueDiv, + "IntTrueDiv": torch.utils._sympy.functions.IntTrueDiv, + "IsNonOverlappingAndDenseIndicator": torch.utils._sympy.functions.IsNonOverlappingAndDenseIndicator, + "TruncToFloat": torch.utils._sympy.functions.TruncToFloat, + "TruncToInt": torch.utils._sympy.functions.TruncToInt, + "RoundToInt": torch.utils._sympy.functions.RoundToInt, + "RoundDecimal": torch.utils._sympy.functions.RoundDecimal, + "ToFloat": torch.utils._sympy.functions.ToFloat, + "Identity": torch.utils._sympy.functions.Identity, + } + self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {} + self.constants = deserialize_torch_artifact(constants) + self.signature = self.deserialize_signature( + serialized_graph_module.signature + ) + + # deserialization does analysis with checks on 0/1, so we create fake range constraints and + # restore the original range constraints afterwards + self.symbol_name_to_range = {} + # we also need to bump unbacked sym[float,int] counters in the + # shape env to accommodate unbacked symbols in the exported program + self.unbacked_symbols = set() + count_unbacked_symfloat, count_unbacked_symint = -1, -1 + unbacked_symfloat_prefix, unbacked_symint_prefix = ( + prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT] + ) + if symbol_name_to_range: + for k, vr in symbol_name_to_range.items(): + lower = vr.lower + self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( + _int_to_sympy_int(lower, -int_oo), vr.upper + ) + if k.startswith(unbacked_symfloat_prefix): + i = int(k[len(unbacked_symfloat_prefix) :]) + count_unbacked_symfloat = max(count_unbacked_symfloat, i) + elif k.startswith(unbacked_symint_prefix): + i = int(k[len(unbacked_symint_prefix) :]) + count_unbacked_symint = max(count_unbacked_symint, i) + + # TODO(pianpwk): if we can clean up unused symbols in range_constraints, + # then this logic can just be handled with self.unbacked_symbols alone + for _ in range(count_unbacked_symfloat + 1): + self.shape_env.unbacked_symfloat_counter += 1 + for _ in range(count_unbacked_symint + 1): + self.shape_env.unbacked_symint_counter += 1 + + if example_inputs is not None and len(example_inputs) > 0: + self.example_inputs = deserialize_torch_artifact(example_inputs) + else: + self.example_inputs = None + self.deserialize_graph(serialized_graph_module.graph) + + with _enable_graph_inputs_of_type_nn_module(self.example_inputs): + module_call_graph = self.deserialize_module_call_graph( + serialized_graph_module.module_call_graph + ) + graph_module = ep._create_graph_module_for_export(self.module, self.graph) + meta = {} + if custom := serialized_graph_module.metadata.get("custom"): + meta["custom"] = json.loads(custom) + if hasattr(serialized_graph_module, "treespec_namedtuple_fields"): + meta["treespec_namedtuple_fields"] = {} + for ( + type_, + fields, + ) in serialized_graph_module.treespec_namedtuple_fields.items(): + meta["treespec_namedtuple_fields"][type_] = fields.field_names + graph_module.meta = meta + return GraphModuleDeserializer.Result( + graph_module=graph_module, + signature=self.signature, + module_call_graph=module_call_graph, + names_to_symbols=self.symbol_name_to_symbol, + state_dict=deserialize_torch_artifact(serialized_state_dict), + constants=self.constants, + example_inputs=self.example_inputs, + ) + finally: + _CURRENT_DESERIALIZER = None + + def sync_fx_node(self, name: str, fx_node: torch.fx.Node): + if name in self.serialized_name_to_node: + raise SerializeError(f"Node {name} has already been deserialized before.") + # overwrite name + fx_node.name = name + self.serialized_name_to_node[name] = fx_node + assert "val" not in fx_node.meta + fx_node.meta["val"] = self.serialized_name_to_meta[name] + + def deserialize_sym_op_inputs(self, inputs): + return tuple(self.deserialize_input(input.arg) for input in inputs) + + def deserialize_inputs(self, target, serialized_node: Node): + schema_args = _get_schema_from_target(target).arguments + argument_kinds = {input.name: input.kind for input in serialized_node.inputs} + actual_args = { + input.name: self.deserialize_input(input.arg) + for input in serialized_node.inputs + } + args = [] + kwargs: OrderedDict[str, Any] = OrderedDict() + for schema_arg in schema_args: + if schema_arg.name in actual_args: + arg = actual_args[schema_arg.name] + kind = argument_kinds[schema_arg.name] + if kind == ArgumentKind.POSITIONAL: + args.append(arg) + continue + elif kind == ArgumentKind.KEYWORD and not keyword.iskeyword( + schema_arg.name + ): + kwargs[schema_arg.name] = arg + continue + + # If there's no ArgumentKind found, fallback to the old cases. + is_positional = ( + not schema_arg.has_default_value() and not schema_arg.kwarg_only + ) + if is_positional: + args.append(actual_args[schema_arg.name]) + elif keyword.iskeyword(schema_arg.name): + assert not schema_arg.kwarg_only + if len(kwargs) > 0: + kwargs = OrderedDict() + args.extend(list(kwargs.values())) + args.append(actual_args[schema_arg.name]) + else: + if schema_arg.name in actual_args: + kwargs[schema_arg.name] = actual_args[schema_arg.name] + return tuple(args), kwargs + + def deserialize_hoo_inputs(self, inputs: list[NamedArgument]): + """ + For deserializing HOO inputs since HOOs do not have a schema. + """ + args = [] + kwargs = {} + for input_ in inputs: + if input_.name != "": + kwargs[input_.name] = self.deserialize_input(input_.arg) + else: + args.append(self.deserialize_input(input_.arg)) + return (tuple(args), kwargs) + + def deserialize_input(self, inp: Argument) -> Any: + value = inp.value + typ_ = inp.type + if typ_ == "as_none": + # None should converted as None, but is encoded as bool in serialized + # Convert serialized object to torch equivalent + return None + elif typ_ == "as_tensor": + return self.serialized_name_to_node[inp.as_tensor.name] + elif typ_ == "as_scalar_type": + return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] + elif typ_ == "as_memory_format": + return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] + elif typ_ == "as_layout": + return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = ep._create_graph_module_for_export(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + name=value.name, + ) + elif typ_ == "as_device": + return deserialize_device(inp.as_device) + elif typ_ == "as_int": + return inp.as_int + elif typ_ == "as_float": + return inp.as_float + elif typ_ == "as_bool": + return inp.as_bool + elif typ_ == "as_string": + return inp.as_string + elif typ_ == "as_complex": + return complex(inp.as_complex.real, inp.as_complex.imag) + elif typ_ == "as_sym_int": + return self.deserialize_sym_argument(inp.as_sym_int) + elif typ_ == "as_sym_float": + return self.deserialize_sym_argument(inp.as_sym_float) + elif typ_ == "as_sym_bool": + return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, dict): + if typ_ == "as_string_to_argument": + # Deserialize dict[str, Argument] recursively + return {k: self.deserialize_input(v) for k, v in value.items()} + else: + raise SerializeError(f"Unknown dict type: {typ_}") + elif isinstance(value, list): + if len(value) == 0: + return [] + elif typ_ == "as_tensors": + result = [self.serialized_name_to_node[arg.name] for arg in value] + return result + elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): + # convert from serialized.python.types.List to python list + return list(value) + elif typ_ == "as_int_lists": + # Convert list of lists back to list of tuples for Triton grids + return [tuple(dims) for dims in value] + elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"): + return [self.deserialize_sym_argument(arg) for arg in value] + elif typ_ == "as_optional_tensors": + + def deserialize_optional_tensor_args(a): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return self.serialized_name_to_node[a.value.name] + else: + raise SerializeError(f"Unhandled argument {inp}") + + return list(map(deserialize_optional_tensor_args, value)) + else: + raise SerializeError(f"Unhandled argument {inp}") + elif typ_ == "as_custom_obj": + if inp.as_custom_obj.name in self.serialized_name_to_node: + # Custom object has been lifted as an input + return self.serialized_name_to_node[inp.as_custom_obj.name] + return self.constants[inp.as_custom_obj.name] + elif typ_ == "as_operator": + return self.deserialize_operator(inp.as_operator) + else: + raise SerializeError(f"Unhandled argument {inp}") + + def deserialize_constant_input(self, inp: ConstantValue) -> Any: + if inp.type == "as_int": + return int(inp.as_int) + elif inp.type == "as_float": + return float(inp.as_float) + elif inp.type == "as_string": + return str(inp.as_string) + elif inp.type == "as_bool": + return bool(inp.as_bool) + elif inp.type == "as_none": + return None + else: + raise SerializeError(f"Unhandled constant argument {inp} to deserialize") + + def deserialize_sym_argument(self, sym_arg): + if isinstance(sym_arg, SymIntArgument): + if sym_arg.type == "as_int": + return sym_arg.as_int + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymFloatArgument): + if sym_arg.type == "as_float": + return sym_arg.as_float + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymBoolArgument): + if sym_arg.type == "as_bool": + return sym_arg.as_bool + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") + + def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + + def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + # Check single value return + if len(serialized_node.outputs) == 0: + return + + if ( + len(serialized_node.outputs) == 1 + and "torch.ops.higher_order" in serialized_node.target + and not getattr(serialized_node, "is_hop_single_tensor_return", True) + and serialized_node.outputs[0].type != "as_none" + ): + + def _deserialize_hop_with_single_return(serialized_node, fx_node): + meta_val: list[Any] = [] + arg = None + if serialized_node.outputs[0].type == "as_tensor": + arg = serialized_node.outputs[0].as_tensor + elif isinstance( + serialized_node.outputs[0].value, + (SymIntArgument, SymBoolArgument, SymFloatArgument), + ): + arg = serialized_node.outputs[0].value + deserialized_metadata = self.deserialize_metadata( + serialized_node.metadata + ) + assert arg is not None + # pyrefly: ignore [bad-argument-type] + self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + return + + return _deserialize_hop_with_single_return(serialized_node, fx_node) + + if ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_tensor" + ): + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return + elif len(serialized_node.outputs) == 1 and isinstance( + serialized_node.outputs[0].value, + (SymIntArgument, SymBoolArgument, SymFloatArgument), + ): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return + elif ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_none" + ): + # manually rename the node to a unused name to avoid naming conflicts + fx_node.meta["val"] = None + fx_node._rename(f"{self.graph._target_to_str(fx_node.target)}_unused") + return + + self.deserialize_multiple_outputs(serialized_node, fx_node) + + def generate_getitem( + self, + meta_val, + fx_node: torch.fx.Node, + arg: Union[TensorArgument, SymIntArgument, SymFloatArgument], + idx: int, + deserialized_metadata: dict[str, Any], + ): + if isinstance(arg, TensorArgument): + name = arg.name + elif isinstance(arg, SymIntArgument): + name = arg.as_name + elif isinstance(arg, SymFloatArgument): + name = arg.as_name + else: + raise AssertionError( + f"generate_getitem got unknown argument type {type(arg)}" + ) + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name=name, + ) + self.sync_fx_node(name, individual_output) + meta_val.append(self.serialized_name_to_meta[name]) + # The derived `getitem` nodes should have the same stacktrace as the + # original `fx_node` + individual_output.meta.update(deserialized_metadata) + + def generate_getitems( + self, + meta_val, + fx_node: torch.fx.Node, + args, + deserialized_metadata: dict[str, Any], + ): + for idx, arg in enumerate(args): + if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)): + self.generate_getitem( + meta_val, fx_node, arg, idx, deserialized_metadata + ) + continue + + assert isinstance(arg, Argument) + if arg.type in ("as_tensor", "as_sym_int", "as_sym_float"): + self.generate_getitem( + meta_val, fx_node, arg.value, idx, deserialized_metadata + ) + elif arg.type in ( + "as_tensors", + "as_sym_ints", + "as_sym_floats", + "as_ints", + "as_floats", + "as_strings", + "as_bools", + "as_sym_bools", + ): + list_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + ) + meta_val.append([]) + self.generate_getitems( + meta_val[-1], list_output, arg.value, deserialized_metadata + ) + list_output.meta.update(deserialized_metadata) + list_output.meta["val"] = meta_val[-1] + elif arg.type == "as_none": + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name="as_none", + ) + meta_val.append(None) + individual_output.meta["val"] = None + individual_output.meta.update(deserialized_metadata) + else: + raise NotImplementedError(f"Unimplemented node output type: {arg}") + + def deserialize_multiple_outputs( + self, serialized_node: Node, fx_node: torch.fx.Node + ) -> None: + deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) + + # Convert multiple return types to FX format. + # In FX, each node only returns one value. So in order to represent + # multiple return values, we have to emit a `getitem` node for each + # return value. + # This performs the inverse mapping of the `serialize_outputs` call in + # serialization, see [NOTE: Multiple outputs] + meta_val: list[Any] = [] + if len(serialized_node.outputs) == 1: + assert isinstance(serialized_node.outputs[0].value, list) + assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + self.generate_getitems( + meta_val, + fx_node, + serialized_node.outputs[0].as_tensors, + deserialized_metadata, + ) + else: + self.generate_getitems( + meta_val, fx_node, serialized_node.outputs, deserialized_metadata + ) + + # also update the metaval for `fx_node` to be a list(meta) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + + def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + ret: dict[str, Any] = {} + if stack_trace := metadata.get("stack_trace"): + ret["stack_trace"] = stack_trace + + def deserialize_meta_func(serialized_target: str): + module = None + if serialized_target.startswith("torch.nn"): + module = torch.nn + serialized_target_names = serialized_target.split(".")[2:] + elif serialized_target.startswith("torch"): + module = torch + serialized_target_names = serialized_target.split(".")[1:] + else: + return self.deserialize_operator(serialized_target) + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + if nn_module_stack_str := metadata.get("nn_module_stack"): + # Originally serialized to "key,orig_path,type_str" + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + + # Helper function to split string by commas, accounting for nested parentheses/brackets + def metadata_split(metadata): + out = [] + start, n = 0, 0 + a, b = "[(", ")]" + for end, c in enumerate(metadata): + if c in a: + n += 1 + elif c in b: + n -= 1 + elif c == "," and n == 0: + out.append(metadata[start:end]) + start = end + 1 + out.append(metadata[start:]) + assert len(out) == 3 + return out + + nn_module_stack = dict( + import_nn_module_stack(*metadata_split(item)) + for item in nn_module_stack_str.split(ST_DELIMITER) + ) + ret["nn_module_stack"] = nn_module_stack + + if source_fn_st_str := metadata.get("source_fn_stack"): + # Originally serializes to "fx_node_name,op_str" + source_fn_st = [] + for source_fn_str in source_fn_st_str.split(ST_DELIMITER): + name, target_str = source_fn_str.split(",") + source_fn_st.append((name, deserialize_meta_func(target_str))) + ret["source_fn_stack"] = source_fn_st + + if torch_fn_str := metadata.get("torch_fn"): + ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER)) + + if custom_str := metadata.get("custom"): + ret["custom"] = json.loads(custom_str) + + return ret + + def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: + log.debug("[deserialize_argument_spec] %s", x) + if x.type == "as_tensor": + return ep.TensorArgument(name=x.as_tensor.name) + elif x.type == "as_sym_int": + return ep.SymIntArgument(name=x.as_sym_int.as_name) + elif x.type == "as_sym_float": + return ep.SymFloatArgument(name=x.as_sym_float.as_name) + elif x.type == "as_custom_obj": + return ep.ConstantArgument( + name=x.as_custom_obj.name, value=self.deserialize_input(x) + ) + else: + return ep.ConstantArgument(name="", value=self.deserialize_input(x)) + + def deserialize_module_call_signature( + self, module_call_signature: ModuleCallSignature + ) -> ep.ModuleCallSignature: + return ep.ModuleCallSignature( + inputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.outputs + ], + in_spec=treespec_loads(module_call_signature.in_spec), + out_spec=treespec_loads(module_call_signature.out_spec), + forward_arg_names=names + if (names := module_call_signature.forward_arg_names) + else None, + ) + + def deserialize_module_call_graph( + self, module_call_graph: list[ModuleCallEntry] + ) -> list[ep.ModuleCallEntry]: + log.debug("\n[deserialize_module_call_graph]") + return [ + ep.ModuleCallEntry( + fqn=entry.fqn, + signature=( + self.deserialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph + ] + + +@final +class ExportedProgramDeserializer(metaclass=Final): + def __init__(self, expected_opset_version: Optional[dict[str, int]] = None): + self.expected_opset_version: dict[str, int] = {} + if expected_opset_version: + self.expected_opset_version.update(expected_opset_version) + if "aten" not in self.expected_opset_version: + self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + + def deserialize_range_constraints( + self, + symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: dict[str, sympy.Symbol], + ) -> dict[sympy.Symbol, ValueRanges]: + log.debug("\n[deserialize_range_constraints]") + range_constraints = {} + for k, v in symbol_name_to_range.items(): + if symbol := symbol_name_to_symbol.get(k): + log.debug("[deserialize_range_constraints] %s -> %s", k, v) + range_constraints[symbol] = v # type: ignore[arg-type] + else: + log.warning( + "Symbol %s did not appear in the graph that was deserialized", k + ) + return range_constraints + + def deserialize( + self, + exported_program: ExportedProgram, + state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, torch.Tensor], bytes], + example_inputs: Optional[ + Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes] + ] = None, + *, + _unsafe_skip_version_check=False, + ) -> ep.ExportedProgram: + assert isinstance(exported_program, ExportedProgram) + version = exported_program.schema_version + + # TODO(zhxchen17) blocked on thrift schema refactor + if version.major != SCHEMA_VERSION[0] and not ( + version.major == 0 and version.minor == 0 + ): + if not _unsafe_skip_version_check: + raise SerializeError( + f"Serialized schema version {exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." + ) + + symbol_name_to_range = { + k: symbolic_shapes.ValueRanges( + _int_to_sympy_int(v.min_val, -int_oo), + _int_to_sympy_int(v.max_val, int_oo), + ) + for k, v in exported_program.range_constraints.items() + } + res = GraphModuleDeserializer().deserialize( + exported_program.graph_module, + state_dict, + constants, + example_inputs, + symbol_name_to_range, + ) + range_constraints = self.deserialize_range_constraints( + symbol_name_to_range, + res.names_to_symbols, + ) + + result = ep.ExportedProgram( + root=res.graph_module, + graph=res.graph_module.graph, + graph_signature=res.signature, + state_dict=res.state_dict, # type: ignore[arg-type] + range_constraints=range_constraints, + module_call_graph=res.module_call_graph, + example_inputs=res.example_inputs, + constants=res.constants, + verifiers=[load_verifier(v) for v in exported_program.verifiers], + ) + result._guards_code = exported_program.guards_code + log.debug("\n[deserialize]: %s", result) + return result + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + return super().default(obj) + + +def _dataclass_to_dict(obj): + if isinstance(obj, _Union): + return {obj.type: _dataclass_to_dict(obj.value)} + elif dataclasses.is_dataclass(obj): + return { + f.name: _dataclass_to_dict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + } + elif isinstance(obj, list): + return [_dataclass_to_dict(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_dataclass_to_dict(x) for x in obj) + elif isinstance(obj, dict): + return {k: _dataclass_to_dict(v) for k, v in obj.items()} + elif isinstance(obj, float): + if obj == math.inf: + return "Infinity" + elif obj == -math.inf: + return "-Infinity" + elif math.isnan(obj): + return "NaN" + else: + return obj + else: + return obj + + +def _to_json_bytes(obj: Any) -> bytes: + return json.dumps(_dataclass_to_dict(obj), cls=EnumEncoder, allow_nan=False).encode( + "utf-8" + ) + + +def serialize( + exported_program: ep.ExportedProgram, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> SerializedArtifact: + with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs): + serialized_program = ExportedProgramSerializer( + opset_version, pickle_protocol + ).serialize(exported_program) + assert isinstance(serialized_program.exported_program, ExportedProgram) + + json_bytes = _to_json_bytes(serialized_program.exported_program) + artifact = SerializedArtifact( + json_bytes, + serialized_program.state_dict, + serialized_program.constants, + serialized_program.example_inputs, + ) + return artifact + + +def _resolve_schema_cls(cls): + if isinstance(cls, str): + resolved = getattr(schema, cls, None) + if resolved is not None: + return resolved + if isinstance(cls, typing.ForwardRef): + return _resolve_schema_cls(cls.__forward_arg__) + return cls + + +def _dict_to_dataclass(cls, data): + cls = _resolve_schema_cls(cls) + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." + if typing.get_origin(cls) is Annotated: + return _dict_to_dataclass(cls.__origin__, data) + if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): + if data is None: + return None + ty_args = typing.get_args(cls) + assert len(ty_args) == 2 + return _dict_to_dataclass(ty_args[0], data) + elif isinstance(cls, type) and issubclass(cls, _Union): + assert isinstance(data, dict) + assert len(data) == 1 + _type = next(iter(data.keys())) + _value = next(iter(data.values())) + assert isinstance(_type, str) + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) + field_type = type_hints[_type] + # pyrefly: ignore [missing-attribute] + return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) + elif dataclasses.is_dataclass(cls): + fields = {} + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) + # For forward compatibility consideration, we ignore all the keys + # that are not showing up in the dataclass definition. + for f in dataclasses.fields(cls): + name = f.name + if name not in data: + continue + new_field_obj = _dict_to_dataclass(type_hints[name], data[name]) + fields[name] = new_field_obj + return cls(**fields) # type: ignore[operator] + elif isinstance(data, list): + if len(data) == 0: + return data + d_type = typing.get_args(cls)[0] + return [_dict_to_dataclass(d_type, d) for d in data] + elif isinstance(data, dict): + v_type = typing.get_args(cls)[1] + return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()} + elif cls is float: + return float(data) + return data + + +def _bytes_to_dataclass(cls: Any, artifact_bytes: bytes) -> Any: + artifact_str = artifact_bytes.decode("utf-8") + artifact_dict = json.loads(artifact_str) + artifact_dataclass = _dict_to_dataclass(cls, artifact_dict) + return artifact_dataclass + + +def deserialize( + artifact: SerializedArtifact, + expected_opset_version: Optional[dict[str, int]] = None, + *, + _unsafe_skip_version_check=False, +) -> ep.ExportedProgram: + assert isinstance(artifact.exported_program, bytes) + serialized_exported_program = _bytes_to_dataclass( + ExportedProgram, artifact.exported_program + ) + return ExportedProgramDeserializer(expected_opset_version).deserialize( + serialized_exported_program, + artifact.state_dict, + artifact.constants, + artifact.example_inputs, + _unsafe_skip_version_check=_unsafe_skip_version_check, + ) + + +def _canonicalize_graph( + sorted_inputs, sorted_outputs, graph, constants +) -> tuple[Graph, dict[str, str]]: + def _get_argument(a: Argument): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return a.as_tensor + elif a.type == "as_tensors": + return a.as_tensors + elif a.type == "as_int": + return None + elif a.type == "as_ints": + return None + elif a.type == "as_float": + return None + elif a.type == "as_floats": + return None + elif a.type == "as_string": + return None + elif a.type == "as_strings": + return None + elif a.type == "as_complex": + return None + elif a.type == "as_sym_int": + return a.as_sym_int + elif a.type == "as_sym_ints": + return a.as_sym_ints + elif a.type == "as_sym_float": + return a.as_sym_float + elif a.type == "as_sym_floats": + return a.as_sym_floats + elif a.type == "as_scalar_type": + return None + elif a.type == "as_memory_format": + return None + elif a.type == "as_layout": + return None + elif a.type == "as_device": + return None + elif a.type == "as_bool": + return None + elif a.type == "as_bools": + return None + elif a.type == "as_sym_bool": + return a.as_sym_bool + elif a.type == "as_sym_bools": + return a.as_sym_bools + elif a.type == "as_graph": + return None + elif a.type == "as_optional_tensors": + return a.as_optional_tensors + elif a.type == "as_custom_obj": + return a.as_custom_obj + elif a.type == "as_operator": + return None + elif a.type == "as_int_lists": + return None + elif a.type == "as_string_to_argument": + return None + else: + raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") + + # Stage 1: Reorder named items. + def for_args(f, a): + assert isinstance(a, Argument) + pytree.tree_map(f, _get_argument(a)) + + def sort_nodes(nodes): + @dataclass + class Edges: + outs: list[int] + ins: int + + graph_inputs: set[str] = set() + def_table: dict[str, int] = {} + edges: dict[int, Edges] = {} + candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = [] + rank: dict[str, int] = {} + ret: list[Node] = [] + + def get_name(a) -> Optional[str]: + if a is None: + return None + if isinstance(a, TensorArgument): + return a.name + elif isinstance(a, (SymIntArgument, SymBoolArgument, SymFloatArgument)): + if a.type == "as_name": + return a.as_name + elif a.type in ("as_int", "as_bool", "as_float"): + return None + else: + raise AssertionError(f"Unknown argument type: {a}") + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + return a.as_tensor.name + elif a.type == "as_none": + return None + else: + raise AssertionError(f"Unknown optional tensor type: {a}") + elif isinstance(a, CustomObjArgument): + return a.name + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + + def add_input(a): + if s := get_name(a): + graph_inputs.add(s) + + for_args(add_input, i) + + for idx, node in enumerate(nodes): + + def add_def(a): + if s := get_name(a): + assert s not in def_table + def_table[s] = idx + + for o in node.outputs: + for_args(add_def, o) + + edges[idx] = Edges([], 0) + + for idx, user in enumerate(nodes): + + def add_edge(a): + if s := get_name(a): + if s in constants: + return + if s not in def_table: + assert s in graph_inputs + return + src = def_table[s] + edges[src].outs.append(idx) + edges[idx].ins += 1 + + for i in user.inputs: + for_args(add_edge, i.arg) + + def add_rank(a): + if s := get_name(a): + assert s not in rank + rank[s] = len(rank) + + def get_rank(a): + s = get_name(a) + if s and s not in constants: + return rank[s] + else: + return -1 + + for i in sorted_inputs: + for_args(add_rank, i) + + def add_candidate(idx: int): + def get_ranks(i): + ranks = [] + for_args(lambda x: ranks.append(get_rank(x)), i) + return ranks + + node = nodes[idx] + args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] + heapq.heappush(candidates, (node.target, args_rank, idx)) + + for idx, e in edges.items(): + if e.ins == 0: + add_candidate(idx) + + while len(candidates) > 0: + _, _, idx = heapq.heappop(candidates) + node = nodes[idx] + for o in node.outputs: + for_args(add_rank, o) + ret.append(node) + assert idx in edges + for user in edges[idx].outs: + e = edges[user] + assert e.ins > 0 + e.ins -= 1 + if e.ins == 0: + add_candidate(user) + edges[idx].outs.clear() + + return ret + + sorted_nodes = sort_nodes(graph.nodes) + assert len(sorted_nodes) == len(graph.nodes) + + # Stage 2: Rename nodes. + name_table: dict[str, str] = {} + + def rename_def(a): + def _rename(arg_name, values): + new_name = f"_{len(name_table)}" + assert arg_name not in name_table + name_table[arg_name] = new_name + assert arg_name in values + values[new_name] = values.pop(arg_name) + return new_name + + if a is None: + return + if isinstance(a, TensorArgument): + a.name = _rename(a.name, graph.tensor_values) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_int_values) + elif isinstance(a, SymFloatArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_float_values) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_bool_values) + elif isinstance(a, CustomObjArgument): + a.name = _rename(a.name, graph.custom_obj_values) + else: + raise AssertionError(f"Unknown argument type: {a}") + + def replace_use(a): + if a is None: + return + if isinstance(a, TensorArgument): + a.name = name_table.get(a.name, a.name) + elif isinstance(a, (SymIntArgument, SymFloatArgument)): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name) + elif isinstance(a, CustomObjArgument): + a.name = name_table.get(a.name, a.name) + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + for_args(rename_def, i) + + for n in sorted_nodes: + for o in n.outputs: + for_args(rename_def, o) + + for n in sorted_nodes: + for i in n.inputs: + for_args(replace_use, i.arg) + + for o in sorted_outputs: + for_args(replace_use, o) + + # Stage 3: Remove unstable fields. + for n in sorted_nodes: + n.metadata.clear() + + # Stage 4: Aggregate values. + # pyrefly: ignore [no-matching-overload] + sorted_tensor_values = dict( + sorted(graph.tensor_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_sym_int_values = dict( + sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_sym_float_values = dict( + sorted(graph.sym_float_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_sym_bool_values = dict( + sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) + ) + # pyrefly: ignore [no-matching-overload] + sorted_custom_obj_values = dict( + sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0)) + ) + + # Stage 5: Recurse in subgraphs. + counter = 0 + for node in sorted_nodes: + for i in node.inputs: + a = i.arg + if a.type == "as_graph": + a.as_graph.graph, _ = _canonicalize_graph( + a.as_graph.graph.inputs, + a.as_graph.graph.outputs, + a.as_graph.graph, + constants, + ) + a.as_graph.name = f"_g{counter}" + counter += 1 + + graph = Graph( + inputs=sorted_inputs, + outputs=sorted_outputs, + nodes=sorted_nodes, + tensor_values=sorted_tensor_values, + sym_int_values=sorted_sym_int_values, + sym_float_values=sorted_sym_float_values, + sym_bool_values=sorted_sym_bool_values, + is_single_tensor_return=graph.is_single_tensor_return, + custom_obj_values=sorted_custom_obj_values, + ) + return graph, name_table + + +def canonicalize( + ep: ExportedProgram, constants: Optional[set[str]] = None +) -> ExportedProgram: + """ + Normalize a serialized ExportedProgram, so that different eager program which + shares the same semantics can get a single representation on disk. + + This function canonicalizes an ExportedProgram by: + + 1. Sorting nodes in topological order. + 2. Rename nodes to have unique names. + 3. Remove unstable fields. + 4. Aggregate the above program fields. + 5. Recurse in subgraphs. + + Args: + ep (ExportedProgram): The ExportedProgram to canonicalize. + constants (Optional[set[str]]): Set of constants names + + Returns: + ExportedProgram: The canonicalized exported program. + """ + ep = copy.deepcopy(ep) + # pyrefly: ignore [annotation-mismatch] + constants: set[str] = constants or set() + + opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) + range_constraints = dict( + sorted(ep.range_constraints.items(), key=operator.itemgetter(0)) + ) + guards_code = sorted(ep.guards_code) + module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) + signature = ep.graph_module.signature + graph = ep.graph_module.graph + + assert len(graph.inputs) == len(signature.input_specs) + assert len(graph.outputs) == len(signature.output_specs) + + def rank_input(inp) -> tuple[int, Optional[str], int]: + idx, (_arg, spec) = inp + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + return 5, None, idx + elif spec.type == "parameter": + return 1, spec.parameter.parameter_name, idx + elif spec.type == "buffer": + return 2, spec.buffer.buffer_name, idx + elif spec.type == "tensor_constant": + return 3, spec.tensor_constant.tensor_constant_name, idx + elif spec.type == "custom_obj": + return 4, spec.custom_obj.custom_obj_name, idx + elif spec.type == "token": + return 0, None, idx + elif spec.type == "constant_input": + return 6, spec.constant_input.name, idx + else: + raise AssertionError(f"Unknown input type: {spec}") + + def rank_output(out) -> tuple[int, Optional[str], int]: + idx, (_arg, spec) = out + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + return 4, None, idx + elif spec.type == "loss_output": + return 4, None, idx + elif spec.type == "parameter_mutation": + return 1, spec.parameter_mutation.parameter_name, idx + elif spec.type == "buffer_mutation": + return 2, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 5, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 6, None, idx + elif spec.type == "user_input_mutation": + return 3, None, idx + elif spec.type == "token": + return 0, None, idx + else: + raise AssertionError(f"Unknown output type: {spec}") + + sorted_ins = sorted( + enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input + ) + + if len(sorted_ins) > 0: + sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] + else: + sorted_inputs = () + input_specs = () + + sorted_outs = sorted( + enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output + ) + sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] + + sorted_graph, replace_table = _canonicalize_graph( + sorted_inputs, sorted_outputs, graph, constants + ) + + def replace_input(spec): + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + arg = spec.user_input.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type == "as_sym_float": + f = arg.as_sym_float + if f.type == "as_name": + f.as_name = replace_table[f.as_name] + elif f.type == "as_float": + pass + else: + raise AssertionError(f"Unknown sym_float type: {f}") + elif arg.type in ( + "as_none", + "as_bool", + "as_int", + "as_float", + "as_string", + "as_custom_obj", + ): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "parameter": + t = spec.parameter.arg + t.name = replace_table[t.name] + elif spec.type == "buffer": + t = spec.buffer.arg + t.name = replace_table[t.name] + elif spec.type == "tensor_constant": + t = spec.tensor_constant.arg + t.name = replace_table[t.name] + elif spec.type == "custom_obj": + t_custom_obj = spec.custom_obj.arg + t_custom_obj.name = replace_table[t_custom_obj.name] + return + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + elif spec.type == "constant_input": + return + else: + raise AssertionError(f"Unknown input type: {spec}") + + def replace_output(out): + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + arg = spec.user_output.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type == "as_sym_float": + f = arg.as_sym_float + if f.type == "as_name": + f.as_name = replace_table[f.as_name] + elif f.type == "as_float": + pass + else: + raise AssertionError(f"Unknown sym_float type: {f}") + elif arg.type in ("as_none", "as_bool", "as_int", "as_float", "as_string"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "loss_output": + t = spec.loss_output.arg + t.name = replace_table[t.name] + elif spec.type == "buffer_mutation": + t = spec.buffer_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "parameter_mutation": + t = spec.parameter_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_parameter": + t = spec.gradient_to_parameter.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_user_input": + g = spec.gradient_to_user_input + g.arg.name = replace_table[g.arg.name] + g.user_input_name = replace_table[g.user_input_name] + elif spec.type == "user_input_mutation": + u = spec.user_input_mutation + u.arg.name = replace_table[u.arg.name] + u.user_input_name = replace_table[u.user_input_name] + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + else: + raise AssertionError(f"Unknown output type: {spec}") + + for spec in input_specs: + replace_input(spec) + + for spec in output_specs: + replace_output(spec) + + return ExportedProgram( + graph_module=GraphModule( + graph=sorted_graph, + signature=GraphSignature( + input_specs=list(input_specs), + output_specs=list(output_specs), + ), + module_call_graph=module_call_graph, + ), + opset_version=opset_version, + range_constraints=range_constraints, + schema_version=ep.schema_version, + verifiers=ep.verifiers, + torch_version=ep.torch_version, + guards_code=guards_code, + ) + + +class ExtensionHandler: + """ + Base class for handling extension operators. + """ + + @classmethod + def namespace(cls) -> str: + raise NotImplementedError(f"{cls.__class__} namespace() must be implemented") + + @classmethod + def to_op_name(cls, op) -> str: + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def from_op_name(cls, name: str): + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def op_schema(cls, op) -> torch.FunctionSchema: + raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented") + + +def register_extension( + op_type: type[Any], + extension_handler: type[ExtensionHandler], +): + """Register custom de/serialization method for a node with non-standard type.""" + assert issubclass(extension_handler, ExtensionHandler), ( + f"Expected ExtensionHandler, got {extension_handler}." + ) + assert op_type not in _serialization_registry, f"{op_type} is already registered." + assert isinstance(op_type, type) # Maybe a good idea to enforce this first. + assert not ( + op_type.__module__.startswith("torch") + or op_type.__module__.startswith("builtins") + ) + assert extension_handler.namespace() not in _deserialization_registry + _serialization_registry[op_type] = extension_handler + _deserialization_registry[extension_handler.namespace()] = extension_handler + + +def _registered_extension_types(): + return tuple(_serialization_registry.keys()) + + +# Registry to store all custom serialization implementations. +# The registry maps a operation to its serialization function (a callable), in their own +# namespace to avoid conflicts. +# Serialization: Op type --> custom handler. +# De-serialization: Namespace --> custom handler. +_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {} +_deserialization_registry: dict[str, type[ExtensionHandler]] = {} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/union.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/union.py new file mode 100644 index 0000000000000000000000000000000000000000..c65ad38d337fea7631c122003e263a94cc4870dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_export/serde/union.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Hashable +from dataclasses import dataclass, fields +from typing import TypeVar +from typing_extensions import dataclass_transform + + +T = TypeVar("T", bound="_Union") + + +class _UnionTag(str): + __slots__ = ("_cls",) + _cls: Hashable + + @staticmethod + def create(t, cls): + tag = _UnionTag(t) + assert not hasattr(tag, "_cls") + tag._cls = cls + return tag + + def __eq__(self, cmp) -> bool: + assert isinstance(cmp, str) + other = str(cmp) + assert other in _get_field_names(self._cls), ( + f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + ) + return str(self) == other + + def __hash__(self): + return hash(str(self)) + + +@functools.cache +def _get_field_names(cls) -> set[str]: + return {f.name for f in fields(cls)} + + +# If you turn a schema class that inherits from union into a dataclass, please use +# this decorator to configure it. It's safe, faster and allows code sharing. +# +# For example, _union_dataclass customizes the __eq__ method to only check the type +# and value property instead of default implementation of dataclass which goes +# through every field in the dataclass. +@dataclass_transform(eq_default=False) +def _union_dataclass(cls: type[T]) -> type[T]: + assert issubclass(cls, _Union), f"{cls} must inheirt from {_Union}." + return dataclass(repr=False, eq=False)(cls) + + +class _Union: + _type: _UnionTag + + @classmethod + def create(cls, **kwargs): + assert len(kwargs) == 1 + obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] + obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) + return obj + + def __post_init__(self): + assert not any( + f.name in ("type", "_type", "create", "value") + for f in fields(self) # type: ignore[arg-type, misc] + ) + + @property + def type(self) -> str: + try: + return self._type + except AttributeError as e: + raise RuntimeError( + f"Please use {type(self).__name__}.create to instantiate the union type." + ) from e + + @property + def value(self): + return getattr(self, self.type) + + def __getattribute__(self, name): + attr = super().__getattribute__(name) + if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] + raise AttributeError(f"Field {name} is not set.") + return attr + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Union): + return False + return self.type == other.type and self.value == other.value + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..902040163610ab5d1dd39b819b363b780172a6be Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbf134f3349dbda9e9d7fb2b757d79c63cffdfb8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/apis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/apis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f5671ba835d4cb90276e21ca4ac8ecfb9b08fbe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/apis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b861ef2848c9e581c72651e04e512f0c0aed50ad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33622f7d6ea861ac95a9345d256860a3e2a78598 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..189721661edf27f671bccc1e344d2af8a3415efa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27d8284cfd244118336b201bb522dbef16be8410 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/compilers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/compilers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee141c8ab76edcf4f67baf9bff8bf059bd827542 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/compilers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3778ab9863ffd9e2f596a49fdc0a934b2b218d33 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/deprecated.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/deprecated.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a36e631f7ea481a60e6b9806bec4444e10799822 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/deprecated.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e641aae11efacdf8765d91e3119961dd66184d76 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/functional_call.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/functional_call.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2a2fe684f9b70e4450b0616468b2cd918614dda Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/functional_call.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6ef903b2e3e35dba21cedb64dab91ec019fd360 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/make_functional.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/make_functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52f06a936df07f858a12b6c62a76ead66881045f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/make_functional.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/predispatch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/predispatch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bf741b588276a5f1ca4784b8c0d03fe06279116 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/predispatch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83003782e45a76e038a2ca9b51308e9d4c0091a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/python_key.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/python_key.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80bcddf8a3ec07c1e5ce728c233a37e1b17e8d38 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/python_key.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a6609e4dbcb8320f0834fe75a4388f10911db6f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..745d789c4e301529bb220ae7e2328f22d86a3666 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c528a5dfbc8e55d4c40c675dc60f1486c7566cec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/vmap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/vmap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16d42a23cd3b4d1cc1b5dae8d5d08e7785ca3f4c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/__pycache__/vmap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7e2d7f27b85e158192dfafadaeb4cb62a12a037 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cbf8110de1f56bafb729144d5650779f6f488fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a139486128041a39bd2fae1dfb16a1e7fec53ba5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c44d51f4c22e94e216021de08d7433c03e68232 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c95b9565b440822bc0ed2a25f47f666a809a330 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/remat_using_tags_for_fwd_loss_bwd_graph_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/remat_using_tags_for_fwd_loss_bwd_graph_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37d7c54f286aeea5aaa4e620ad12c7f59f8bd55a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/remat_using_tags_for_fwd_loss_bwd_graph_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02e1c945ed76cd27b8d4c467075b541e958addc4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/activation_offloading.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/activation_offloading.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4548d6745bf2343ef5fefd6fe21441df9f754285 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/activation_offloading.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/activation_offloading.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/activation_offloading.py new file mode 100644 index 0000000000000000000000000000000000000000..0a209ef4d824b524564709475eff9954c59cf126 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/activation_offloading.py @@ -0,0 +1,824 @@ +""" +Activation offloading for memory optimization in (more like post) partitioners. + +This module provides functionality to offload activations to CPU during forward pass +and reload them during backward pass, reducing GPU memory usage. + +Additional TODO: +* given the fact that PT2 stream support is in active development, testings should + be done once that is more finalized. A issue currently known is that with streams, + each iteration will have its own offload streams, but the streams should be shared + across the iterations. +""" + +import logging +import operator +from dataclasses import dataclass + +import torch +import torch.fx as fx +from torch._dynamo.variables.streams import get_current_stream, new_event, new_stream +from torch._inductor import config as inductor_config +from torch._inductor.fx_passes.overlap_scheduling import benchmark_node, is_compute_node +from torch._subclasses.fake_tensor import extract_tensor_metadata +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..partitioners import _size_of, get_default_op_list, OpTypes + + +log: logging.Logger = logging.getLogger(__name__) + + +# Node name prefixes for offload/reload operations +# NOTE: right now we are using these prefixes as identifiers for offload/reload +CPU_OFFLOAD_PREFIX = "cpu_offload_" +GPU_RELOAD_PREFIX = "gpu_reload_" + + +@dataclass +class ReloadNodeInfo: + """ + Information about backward reload related nodes for each reload operation. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This pattern is divided into two logical groups for optimization purposes: + - Reload group (fork → wait_stream → device_put → record_event → join): + Performs the actual asynchronous data transfer on a separate stream. + These nodes can be moved earlier in the graph to overlap with computation. + - Wait group (wait_event): + Synchronization point that blocks until the data transfer completes. + This must remain at the point where the reloaded data is first needed. + """ + + reload_group_nodes: list[fx.Node] + wait_event_node: fx.Node + transfer_size_bytes: int + transfer_time_ms: float + + +@dataclass +class ReloadQueueEntry: + """ + Entry in the reload queue for prefetch scheduling. + + Attributes: + pattern: The reload pattern information + remaining_time_ms: Remaining overlap time needed in milliseconds + """ + + pattern: ReloadNodeInfo + remaining_time_ms: float + + +def offload_activation_fw(graph: fx.Graph) -> None: + """ + Insert CPU offload operations in the forward pass graph. + + Offload operations are placed after the last effective use of each tensor marked + for offloading. This ensures the tensor is no longer needed on the GPU before + transferring it to CPU memory. + + NOTE: An alternative approach would offload tensors immediately after generation + to maximize compute-communication overlap. However, this requires additional + synchronization to ensure tensor deletion (which occurs on the default stream) + waits for the asynchronous offload operation to complete. This would necessitate + more complex tracking to separate operation scheduling from memory cleanup. + + Args: + graph: The forward graph to modify + """ + + op_types: OpTypes = get_default_op_list() + + def find_all_effective_users(node: fx.Node) -> OrderedSet[fx.Node]: + """ + Find all effective users of a node, where view ops extend the lifetime of the + original node. If a user is a view op, recursively find users of the view. + """ + effective_users: OrderedSet[fx.Node] = OrderedSet() + for user in node.users: + if user.op == "output": + continue + effective_users.add(user) + if op_types.is_view(user): + effective_users.update(find_all_effective_users(user)) + + return effective_users + + output_node: fx.Node = graph.find_nodes(op="output")[0] + fwd_outputs: tuple[fx.Node, ...] = output_node.args[ + 0 + ] # pyrefly: ignore [bad-assignment] + node_to_offload: dict[fx.Node, fx.Node] = dict() + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + + for node in fwd_outputs: + if node.meta.get("saved_for_offloading", False) is False: + continue + + # Find insertion point, which is the last use + all_effective_users: OrderedSet[fx.Node] = find_all_effective_users(node) + if all_effective_users := find_all_effective_users(node): + last_user = max(all_effective_users, key=lambda n: node_to_index[n]) + else: + last_user: fx.Node = node + + # Insert the CPU offload operation after the last user + with graph.inserting_after(last_user): + cpu_node: fx.Node = graph.call_function( + torch.ops.prims.device_put.default, + args=(node, torch.device("cpu")), + kwargs={"non_blocking": True}, + name=CPU_OFFLOAD_PREFIX + str(node.name), + ) + cpu_node.meta["val"] = node.meta["val"].to(torch.device("cpu")) + cpu_node.meta["tensor_meta"] = extract_tensor_metadata(cpu_node.meta["val"]) + + node_to_offload[node] = cpu_node + + # Update the return node args + output_node.update_arg( + 0, tuple(node_to_offload.get(node, node) for node in fwd_outputs) + ) + + +def reload_activation_bw(graph: fx.Graph) -> None: + """ + Insert GPU reload operations in the backward pass graph. + + Reload operations are placed before the first use of each offloaded tensor, + transferring it from CPU back to GPU memory before it's needed for computation. + + Args: + graph: The backward graph to modify + """ + + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + output_node: fx.Node = graph.find_nodes(op="output")[0] + + for node in graph.find_nodes(op="placeholder"): + if node.meta.get("saved_for_offloading", False) is False: + continue + + # Find insertion point, which is the first use or output node if no users + # The later should not happen, but inserting before output node is safe + insert_point: fx.Node = ( + min(node.users.keys(), key=lambda n: node_to_index[n]) + if node.users + else output_node + ) + + # Insert the GPU reload operation before the first user + original_device: torch.Device = node.meta["original_device"] + with graph.inserting_before(insert_point): + gpu_node: fx.Node = graph.call_function( + torch.ops.prims.device_put.default, + args=(node, original_device), + kwargs={"non_blocking": True}, + name=str(node.name).replace(CPU_OFFLOAD_PREFIX, GPU_RELOAD_PREFIX), + ) + gpu_node.meta["val"] = node.meta["val"].to(original_device) + gpu_node.meta["tensor_meta"] = extract_tensor_metadata(gpu_node.meta["val"]) + + # Replace all uses of the CPU tensor with the GPU tensor + for user in list(node.users.keys()): + if user != gpu_node: + user.replace_input_with(node, gpu_node) + + +def can_offload( + node: fx.Node, + fwd_outputs: OrderedSet[fx.Node], + model_outputs: OrderedSet[fx.Node], + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> bool: + """ + Determine if a node can be offloaded to CPU. + + Args: + node: The node to check + fwd_outputs: Forward module outputs, including model outputs and activations + model_outputs: Model outputs + + NOTE: Additional context for the logic behind these offloading checks: + + * fwd_outputs: Only saved intermediate tensors should be offloaded. + + * model_outputs / static_lifetime_input_nodes: Tensors that may be accessed outside + the compiled region (e.g., model outputs, static inputs) cannot be offloaded as + they must remain accessible beyond the scope of the compiled graph. + + * views / getitems: Offloading such nodes can lead to segmentation faults. + + * contiguous: Offloading non-contiguous tensors causes CPU-side stride changes + during both forward and backward passes when using the Inductor backend. While + these stride changes cancel each other out, they introduce significant compute + overhead. This is due to the contiguity check in ir.py (see link below). + TODO: This restriction could potentially be bypassed in the future. + Reference: https://github.com/pytorch/pytorch/blob/44ac69388a4a5eb463dbd2a13f00d1e3b924566c/torch/_inductor/ir.py#L3214 + + Additional criteria to consider for offloading optimization: + + * Tensor size: Small tensors may not fully utilize available bandwidth, reducing the + efficiency gains from offloading. + + * Position in forward/backward graph: Activations generated near the end of the forward + pass are typically consumed near the beginning of the backward pass. Offloading such + tensors may be counterproductive since they are quickly reloaded, not having sufficient + time to overlap the transfer with computation. + """ + + log.debug(f"Checking node {node.name} for offloading...") # noqa: G004 + + op_types: OpTypes = get_default_op_list() + + if node not in fwd_outputs: + log.debug("\tSkipped! Can only offload nodes in fwd_module_outputs.") + return False + if node in model_outputs: + log.debug("\tSkipped! Cannot offload model outputs.") + return False + if node in static_lifetime_input_nodes: + log.debug("\tSkipped! Cannot offload static input nodes.") + return False + if op_types.is_view(node): + log.debug("\tSkipped! Cannot offload views.") + return False + if node.target == operator.getitem: + log.debug("\tSkipped! Cannot offload getitems.") + return False + if hasattr(node, "meta") and "val" in node.meta: + if ( + isinstance(val := node.meta["val"], torch.Tensor) + and not val.is_contiguous() + ): + log.debug("\tSkipped! Cannot offload non-contiguous tensors.") + return False + + log.debug("\tGood!") + return True + + +def choose_offload_sets( + fwd_module: fx.GraphModule, + num_fwd_outputs: int, + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> bool: + """ + Decide which nodes will be offloaded based on the marked nodes and feasibility. + Marks nodes with "saved_for_offloading" if they should and can be offloaded. + + Args: + fwd_module: Forward graph module + bwd_module: Backward graph module + num_fwd_outputs: Number of forward outputs + + Returns: + bool: Whether activation offloading should be performed + """ + + fwd_outputs: OrderedSet[fx.Node] = OrderedSet( + fwd_module.graph.find_nodes(op="output")[0].args[0] + ) + model_outputs: OrderedSet[fx.Node] = OrderedSet( + fwd_module.graph.find_nodes(op="output")[0].args[0][:num_fwd_outputs] + ) + + should_perform_offloading = False + for node in fwd_module.graph.nodes: + if node.meta.get("should_offload", False) and can_offload( + node, fwd_outputs, model_outputs, static_lifetime_input_nodes + ): + node.meta["saved_for_offloading"] = True + node.meta["original_device"] = node.meta["val"].device + should_perform_offloading = True + + return should_perform_offloading + + +def offload_chosen_sets( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: + """ + Add offload and reload nodes to the forward and backward graphs. + This function adds device_put operations without any stream handling. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + """ + + # Add offload nodes in forward graph + offload_activation_fw(fwd_module.graph) + + # Update backward graph inputs to be offloaded tensors + bwd_inputs: dict[str, fx.Node] = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + for fwd_node in fwd_module.graph.find_nodes(op="output")[0].args[0]: + if CPU_OFFLOAD_PREFIX not in fwd_node.name: + continue + + bwd_node: fx.Node = bwd_inputs[fwd_node.name.replace(CPU_OFFLOAD_PREFIX, "")] + with bwd_module.graph.inserting_after(bwd_node): + bwd_offload_node: fx.Node = bwd_module.graph.placeholder(name=fwd_node.name) + + bwd_offload_node.meta.update(fwd_node.meta) + bwd_offload_node.meta["saved_for_offloading"] = True + bwd_offload_node.meta["original_device"] = bwd_node.meta["val"].device + bwd_node.replace_all_uses_with(bwd_offload_node) + bwd_module.graph.erase_node(bwd_node) + + # Add reload nodes in backward graph + reload_activation_bw(bwd_module.graph) + + +def add_forward_offload_stream_ops(graph: fx.Graph) -> None: + """ + Add stream operations for forward pass CPU offloading. + + Pattern: record_event → fork → wait_event → record_stream → device_put → record_event_2 → join → wait_event_2 + + This ensures that: + 1. Offloading waits for the last use to complete (record_event on default stream) + 2. Offloading happens on a separate stream (fork → wait_event → device_put) + 3. The tensor is marked as used in the offload stream (record_stream) + 4. Execution returns to the default stream after offloading and + waits for offload to complete (record_event_2 → join → wait_event_2) + + NOTE: For stream optimization and overlapping compute with communication, + the "wait_event_2" ops can be sinked to the end of the graph. + + Args: + graph: The forward graph to modify + """ + + # Find all CPU offload nodes + offload_nodes: list[fx.Node] = [ + node + for node in graph.nodes + if CPU_OFFLOAD_PREFIX in node.name and node.op == "call_function" + ] + if not offload_nodes: + return + + # Get default stream id and offload stream id + current_stream_id: int = get_current_stream( + offload_nodes[0].args[0].meta["val"].device # type: ignore[assignment] + ) + offload_stream_id: int = new_stream() + + for offload_node in offload_nodes: + offload_ready_event_id: int = new_event() + offload_completion_event_id: int = new_event() + + # Get the tensor being offloaded + tensor_node: fx.Node = offload_node.args[0] # type: ignore[assignment] + + with graph.inserting_before(offload_node): + # Record event on default stream to ensure last use completes + graph.call_function( + torch.ops.streams.record_event.default, + args=(offload_ready_event_id, current_stream_id), + ) + # Fork to offload stream + graph.call_function( + torch.ops.streams.fork.default, + args=(current_stream_id, offload_stream_id), + name=f"stream_in_{offload_node.name}", + ) + # Wait for the event on offload stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(offload_ready_event_id, offload_stream_id), + ) + # Inform the CUDA Caching Allocator that this tensor will be accessed in the + # offload stream. Without this, the program may prematurely free its memory + # even though the async offload operation is still in progress, and this can + # lead to memory corruption, especially with reordering for compute and + # communication overlaps. + graph.call_function( + torch.ops.streams.record_stream.default, + args=(tensor_node, offload_stream_id), + name=f"record_stream_{tensor_node.name}", + ) + with graph.inserting_after(offload_node): + # Record event on offload stream after device_put completes + record_event_node = graph.call_function( + torch.ops.streams.record_event.default, + args=(offload_completion_event_id, offload_stream_id), + ) + with graph.inserting_after(record_event_node): + # Join back to default stream + join_node = graph.call_function( + torch.ops.streams.join.default, + args=(offload_stream_id, current_stream_id), + name=f"stream_out_{offload_node.name}", + ) + with graph.inserting_after(join_node): + # Wait for the offload to complete on default stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(offload_completion_event_id, current_stream_id), + ) + + +def add_backward_reload_stream_ops(graph: fx.Graph) -> None: + """ + Add stream operations for backward pass GPU reloading. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This ensures that: + 1. Reloading doesn't start prematurely (fork → wait_stream) + 2. Reloading happens on a separate stream (device_put) + 3. First use waits for reload completion (record_event → join → wait_event) + + NOTE: The pattern consists of two logical groups: + - First group (fork → wait_stream → device_put → record_event → join): + Performs asynchronous data transfer on a separate stream + - Second group (wait_event): + Data transfer completion check when the data is actually needed + + For prefetch optimization, the first group can be moved earlier in the graph + to overlap computation with data transfer, while the wait_event must remain + at its current position to prevent blocking computation unnecessarily. + + Args: + graph: The backward graph to modify + """ + + # Find all GPU reload nodes + reload_nodes: list[fx.Node] = [ + node + for node in graph.nodes + if GPU_RELOAD_PREFIX in node.name and node.op == "call_function" + ] + if not reload_nodes: + return + + # Get default stream id and offload stream id + current_stream_id: int = get_current_stream( + reload_nodes[0].args[0].meta["original_device"] # type: ignore[assignment] + ) + reload_stream_id: int = new_stream() + + for reload_node in reload_nodes: + event_id: int = new_event() + + with graph.inserting_before(reload_node): + # Fork to reload stream + graph.call_function( + torch.ops.streams.fork.default, + args=(current_stream_id, reload_stream_id), + name=f"stream_in_{reload_node.name}", + ) + # Wait for default stream to prevent premature reloading + graph.call_function( + torch.ops.streams.wait_stream.default, + args=(reload_stream_id, current_stream_id), + ) + with graph.inserting_after(reload_node): + # Record event on reload stream after device_put + record_event_node = graph.call_function( + torch.ops.streams.record_event.default, + args=(event_id, reload_stream_id), + ) + with graph.inserting_after(record_event_node): + # Join back to default stream + join_node = graph.call_function( + torch.ops.streams.join.default, + args=(reload_stream_id, current_stream_id), + name=f"stream_out_{reload_node.name}", + ) + with graph.inserting_after(join_node): + # Wait for the event on default stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(event_id, current_stream_id), + ) + + +def put_offload_nodes_on_separate_stream( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: + """ + Add stream and event related operations around offload nodes. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + """ + + add_forward_offload_stream_ops(fwd_module.graph) + add_backward_reload_stream_ops(bwd_module.graph) + + +def _validate_pattern_nodes( + fork_node: fx.Node, + wait_stream_node: fx.Node, + record_event_node: fx.Node, + join_node: fx.Node, + wait_event_node: fx.Node, +) -> None: + """ + Validate that the pattern nodes match the expected structure. + + Raises ValueError if any node doesn't match expectations. + """ + + if not ( + fork_node.op == "call_function" + and fork_node.target == torch.ops.streams.fork.default + ): + raise ValueError("Expected fork node two nodes before device_put node") + + if not ( + wait_stream_node.op == "call_function" + and wait_stream_node.target == torch.ops.streams.wait_stream.default + ): + raise ValueError("Expected wait_stream node one node before device_put node") + + if not ( + record_event_node.op == "call_function" + and record_event_node.target == torch.ops.streams.record_event.default + ): + raise ValueError("Expected record_event node one node after device_put node") + + if not ( + join_node.op == "call_function" + and join_node.target == torch.ops.streams.join.default + ): + raise ValueError("Expected join node two nodes after device_put node") + + if not ( + wait_event_node.op == "call_function" + and wait_event_node.target == torch.ops.streams.wait_event.default + ): + raise ValueError("Expected wait_event node three nodes after device_put node") + + +def _calculate_transfer_size(device_put_node: fx.Node) -> int: + """Calculate the size in bytes of data being transferred.""" + + return _size_of(device_put_node.args[0]) # pyrefly: ignore [bad-argument-type] + + +def _estimate_transfer_time_in_ms(transfer_size_bytes: int) -> float: + """ + Estimate transfer time in milliseconds based on size and bandwidth. + NOTE: potentially could be standardized in node estimator class + """ + + return transfer_size_bytes / (1024**3) * 1_000 / inductor_config.cpu_gpu_bw + + +def identify_reload_patterns( + graph: fx.Graph, nodes_list: list[fx.Node], node_to_idx: dict[fx.Node, int] +) -> dict[fx.Node, ReloadNodeInfo]: + """ + Identify backward reload patterns in the graph. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This uses position-based matching since these nodes are inserted together in + add_backward_reload_stream_ops() in a specific order. Since stream operations + do not have data dependencies between them, they are unsuitable for subgroup + pattern matching type of checks. + + Returns a dict mapping device_put node to ReloadNodeInfo containing: + - reload_group_nodes: fork → wait_stream → device_put → record_event → join + - wait_event_node: the wait_event node + - transfer_size_bytes: size of data being transferred + - transfer_time_ms: estimated transfer time in milliseconds + """ + patterns: dict[fx.Node, ReloadNodeInfo] = {} + + # Find all GPU reload device_put nodes whose inputs are placeholder nodes + reload_nodes: list[fx.Node] = [ + node + for node in graph.find_nodes( + op="call_function", target=torch.ops.prims.device_put.default + ) + if GPU_RELOAD_PREFIX in node.name + and ( + node.args + and isinstance(node.args[0], fx.Node) + and node.args[0].op == "placeholder" + ) + ] + + # Extract patterns for each reload device_put node + for reload_node in reload_nodes: + reload_node_idx: int = node_to_idx[reload_node] + + fork_node: fx.Node = nodes_list[reload_node_idx - 2] + wait_stream_node: fx.Node = nodes_list[reload_node_idx - 1] + record_event_node: fx.Node = nodes_list[reload_node_idx + 1] + join_node: fx.Node = nodes_list[reload_node_idx + 2] + wait_event_node: fx.Node = nodes_list[reload_node_idx + 3] + + # Validate the nodes are what we expect + _validate_pattern_nodes( + fork_node, + wait_stream_node, + record_event_node, + join_node, + wait_event_node, + ) + + # Calculate transfer size and time + transfer_size_bytes: int = _calculate_transfer_size(reload_node) + transfer_time_ms: float = _estimate_transfer_time_in_ms(transfer_size_bytes) + + patterns[reload_node] = ReloadNodeInfo( + reload_group_nodes=[ + fork_node, + wait_stream_node, + reload_node, + record_event_node, + join_node, + ], + wait_event_node=wait_event_node, + transfer_size_bytes=transfer_size_bytes, + transfer_time_ms=transfer_time_ms, + ) + + return patterns + + +def reorder_for_prefetch( + nodes_list: list[fx.Node], + reload_patterns: dict[fx.Node, ReloadNodeInfo], +) -> None: + """ + Reorder nodes to prefetch reload operations by directly manipulating the graph. + + This follows the algorithm as follows: + - Go through nodes in reverse order + - When encountering a reload pattern, add it to a queue with its transfer time + - When encountering a compute node, use its runtime to satisfy overlap requirements + - Place reload patterns when their overlap requirement is satisfied + - When encountering placeholder nodes, flush queue as reloads cannot move before inputs + """ + + # Build a set of all nodes in reload groups for quick lookup + reload_group_nodes_set: set[fx.Node] = set() + for pattern in reload_patterns.values(): + reload_group_nodes_set.update(pattern.reload_group_nodes) + + # Queue to hold reload group nodes waiting to be placed (FIFO) + reload_queue: list[ReloadQueueEntry] = [] + + # Loop through nodes in reverse + for node in reversed(nodes_list): + if node.op == "output": + continue + elif node.op == "placeholder": + # Flush queue - place all remaining reloads after the last placeholder + while reload_queue: + entry: ReloadQueueEntry = reload_queue.pop(0) + for reload_group_node in reversed(entry.pattern.reload_group_nodes): + node.append(reload_group_node) + break + elif node in reload_patterns: + pattern: ReloadNodeInfo = reload_patterns[node] + reload_queue.append( + ReloadQueueEntry( + pattern=pattern, remaining_time_ms=pattern.transfer_time_ms + ) + ) + elif node in reload_group_nodes_set: + continue + else: + if not reload_queue: + continue + compute_runtime_ms: float = ( + benchmark_node(node) if is_compute_node(node) else 0 + ) + reload_queue[0].remaining_time_ms -= compute_runtime_ms + + # Pop and place reload if its remaining time is satisfied (<= 0) + if reload_queue[0].remaining_time_ms <= 0: + entry: ReloadQueueEntry = reload_queue.pop(0) + for reload_group_node in entry.pattern.reload_group_nodes: + node.prepend(reload_group_node) + + +def activation_offload_sink_wait(fwd_module: fx.GraphModule) -> None: + """ + Sink wait_event operations for offload completion to the end of the graph. + + This function identifies wait_event nodes for offload completion and moves them + to the end of the graph, allowing computation to overlap with offload operations. + + Args: + fwd_module: Forward module graph + """ + graph: fx.Graph = fwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) + node_to_idx: dict[fx.Node, int] = {node: idx for idx, node in enumerate(nodes_list)} + + # Find all CPU offload device_put nodes + offload_nodes: list[fx.Node] = [ + node + for node in graph.find_nodes( + op="call_function", target=torch.ops.prims.device_put.default + ) + if CPU_OFFLOAD_PREFIX in node.name + ] + + # Collect all wait_event nodes that need to be moved + wait_nodes_to_sink: list[fx.Node] = [] + for offload_node in offload_nodes: + offload_idx: int = node_to_idx[offload_node] + wait_event_node: fx.Node = nodes_list[offload_idx + 3] + + # Validate it's actually a wait_event node + if not ( + wait_event_node.op == "call_function" + and wait_event_node.target == torch.ops.streams.wait_event.default + ): + raise ValueError( + f"Expected wait_event node three positions after {offload_node.name}" + ) + + wait_nodes_to_sink.append(wait_event_node) + + # Find the output node, and move all wait_event nodes to just before the output node + output_node: fx.Node = graph.find_nodes(op="output")[0] + for wait_node in wait_nodes_to_sink: + output_node.prepend(wait_node) + + +def activation_reload_prefetch(bwd_module: fx.GraphModule) -> None: + """ + Prefetch backward reload operations by moving them earlier in the graph + to overlap communication with computation. + + This function identifies backward reload patterns (fork → wait_stream → device_put → + record_event → join) and moves them earlier in the execution order to overlap + the data transfer with computation, while keeping the wait_event at its original + position. + + Args: + bwd_module: Backward module graph + """ + graph: fx.Graph = bwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) + node_to_idx: dict[fx.Node, int] = {node: idx for idx, node in enumerate(nodes_list)} + + # Step 1: Identify reload patterns + reload_patterns: dict[fx.Node, ReloadNodeInfo] = identify_reload_patterns( + graph, nodes_list, node_to_idx + ) + + # Step 2: Reorder nodes by directly manipulating the graph + reorder_for_prefetch(nodes_list, reload_patterns) + + +def enable_activation_offloading( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + num_fwd_outputs: int, + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> None: + """ + Main entry point for activation offloading. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + num_fwd_outputs: Number of forward outputs + """ + + # Step 1: Decide which nodes to offload and mark them + should_perform_offloading: bool = choose_offload_sets( + fwd_module, + num_fwd_outputs, + static_lifetime_input_nodes, + ) + if not should_perform_offloading: + return + + # Step 2: Add offload and reload nodes to the graphs + offload_chosen_sets(fwd_module, bwd_module) + + # Step 3: Put offload nodes on separate stream if configured + if config.activation_offload_separate_stream: + put_offload_nodes_on_separate_stream(fwd_module, bwd_module) + if config.activation_offload_sink_wait: + activation_offload_sink_wait(fwd_module) + if config.activation_reload_prefetch: + activation_reload_prefetch(bwd_module) + + fwd_module.graph.lint() + bwd_module.graph.lint() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd1cb735040eb329b8b97f80661ff758c32107ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/aot_autograd_result.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/aot_autograd_result.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e098f16221e63d904f35b1abbb2d5d104ebf2e44 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/aot_autograd_result.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6918400da1b01e3951976d16743fd476cb48575 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dc4a5e1a532d4948fcc05e7be8228e977b0cfc9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/descriptors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/descriptors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b15e01148adbead2aca585161f10c6ed3aa04ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/descriptors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/frontend_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/frontend_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8818aa127f45c94e2cc8c83f65d66a3919385361 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/frontend_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75ffdf0f51a0ee30087ab9c7a992a474e20fe332 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/fx_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/fx_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8755660a7cc4ce85cd12dbdb63086b3b3128c51 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/fx_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0c603e6775136ff51d5dd16fa258def8b7396c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture_wrappers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture_wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76ef20ca7e1dc18908ac9983a50a55a57bdb5bb1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture_wrappers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_compile.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0583c5dc3d631cff4d8b605dbcf040995194427 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_compile.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/indexed_dict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/indexed_dict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be1b0911af93d0f8aaa364df8867770a0201f68a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/indexed_dict.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea2ef9e427ce6191774c5c6ea0c2731edcd8663 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf4206cd8826688a3aabf937177bfb385a79f884 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..198f27b718d15985d45fd010e9a828a22b22de3a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4400370b34158a32a452c724cb3ac9be34c0200 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/streams.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d8a3233095fcf818ca48e2912767aa11937b991 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/streams.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..672a37c9ac4f21a542feb961e56936e4d80e88a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88b11a335dd7dfbdbf8e3a0559c0f71034e4d700 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d6cafc9ac1da2f597f0fd5f891bc9847aa79f5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..11cef0f9205a511605162042b0c016041b5e8413 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -0,0 +1,873 @@ +# mypy: allow-untyped-defs +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. + +In particular, the analysis here constructs view and mutation metadata from running +a functionalized version of the graph under compilation. +""" + +import collections +import contextlib +import logging +from collections.abc import Callable +from typing import Optional + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._guards import detect_fake_mode +from torch._logging import getArtifactLogger +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch._subclasses.meta_utils import safe_is_leaf +from torch.fx.experimental.symbolic_shapes import is_concrete_int +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + +from .descriptors import ( + AOTInput, + AOTOutput, + InputMutationAOTOutput, + IntermediateBaseAOTOutput, + PlainAOTOutput, + TangentAOTInput, +) +from .functional_utils import ( + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + MetadataKey, + to_fun, + ViewMetaSequence, + was_inductor_storage_resized, +) +from .schemas import ( + InputAliasInfo, + MemoryFormatMeta, + MutationType, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .subclass_utils import create_subclass_meta +from .utils import _get_autocast_states, KNOWN_TYPES, simple_wraps, strict_zip + + +zip = strict_zip + +log = logging.getLogger(__name__) +static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs") + + +# Note [Tangents memory format] +# We assume tangents memory format to be similar to corresponding output's memory_format. +# The idea is that we are technically making a guess about the strides of our tangents, +# while we trace out the joint. +# If runtime specified tangents will not have the same memory format as predicted traced tangents, +# we coerce them at runtime to traced tangents memory format. + + +# Coercing and collecting traced tangents memory format in one recursive traversal +# mypy: ignore-errors +def coerce_tangent_and_suggest_memory_format(x: Tensor): + updated = False + if not isinstance(x, Tensor): + return x, None, updated + + out = x.detach() + + is_subclass = is_traceable_wrapper_subclass(out) + + memory_format = MemoryFormatMeta.from_tensor(out) + + # pyrefly: ignore [missing-attribute] + if memory_format.memory_format is not None: + was = out + # pyrefly: ignore [bad-argument-type] + out = out.contiguous(memory_format=memory_format.memory_format) + updated = was is not out + + # For subclass we keep memory format of outer strides at the beginning of the list + out_memory_format = [memory_format] if is_subclass else memory_format + + # Note [Tangents memory format, Part 2] + # In the same way that "what strides do we assigns to our tangents" is a question + # that we can not answer (and therefore have to guess) as we trace the backward ahead-of-time, + # The same applies to any tensor subclass metadata, when we have tangents that are subclasses. + # To handle this situation, we have two new methods that a tensor subclass can implement: + # (1) __coerce_tangent_metadata__(self) + # Given a subclass with "non-standard" metadata, turn it into a new subclass with "normal" metadata. + # The main example here is a DTensor with the "_Partial" placement. + # If we have a forward output with a _Partial placement, and corresponding tangent + # with a Replicate/Shard placement, we have no way to convert the tangent "back" to a _Partial placement. + # This method lets us avoid the problem entirely by allowing subclasses to ensure that we can never + # have a tangent with "problematic" metadata, that we cannot convert to. + # (1) __coerce_same_metadata_as_tangent__(self, metadata) + # Given a subclass, and a target differing metadata, + # convert self to have the same metadata as the target. + # With DTensor being the main example, we can use this to convert a DTensor with a Replicate() + # placement into one with a Shard() placement, in the case that we "guessed wrong", + # and traced tangents with a Shard() placement at compile time. + # + if is_subclass and hasattr(out, "__coerce_tangent_metadata__"): + out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined] + + if is_subclass: + # pyrefly: ignore [missing-attribute] + attrs = out.__tensor_flatten__()[0] + + for attr in attrs: + elem = getattr(out, attr) + ( + new_elem, + new_elem_memory_format, + elem_updated, + ) = coerce_tangent_and_suggest_memory_format(elem) + # pyrefly: ignore [missing-attribute] + out_memory_format.append(new_elem_memory_format) + if elem_updated: + setattr(out, attr, new_elem) + + return out, out_memory_format, updated + + +# This is a version of functionalization that is specifically designed +# for the AOTAutograd use case. +# +# Unlike functorch's variant, this doesn't use the functorch level system, +# instead it directly uses PyTorch's conventional dispatcher to hit the +# functionalization key. In particular, this means that FunctionalTensorWrapper +# can have autograd data stored directly on it. +# +# In typical AOTAutograd usage, the dispatch key order will look like: +# +# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor +# outer tensor inner tensor +# +# Returns: +# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and +# The list of outputs from the forward, but **only** the outputs that we need +# to pass in as tangents into the backward. +# Specifically, aliased outputs from the forward get regenerated, and don't participate +# in the compiled backward function. +def run_functionalized_fw_and_collect_metadata( + f, + *, + flat_args_descs: list[AOTInput], + keep_input_mutations: bool, + # TODO: refactor to kill this flag + is_train: bool = False, + # Note: this is guaranteed to be set when running under dynamo + static_input_indices: Optional[list[int]] = None, + pre_dispatch: bool = False, +) -> Callable[..., ViewAndMutationMeta]: + memo: dict[Tensor, Tensor] = {} + + def _to_fun(t): + if isinstance(t, Tensor): + if t in memo: + return memo[t] + r = to_fun(t) + memo[t] = r + return r + else: + return t + + @simple_wraps(f) + def inner(*flat_args): + # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. + assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) + + input_info: list[InputAliasInfo] = [] + output_info: list[OutputAliasInfo] = [] + + prior_grad_enabled = torch.is_grad_enabled() + prior_autocast_states = _get_autocast_states() + + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + # It doesn't matter if we run this under predispatch or not because it is + # only for figuring out metadata + mode = FunctionalTensorMode(_allow_token_discovery=True) + suppress_pending = contextlib.nullcontext() + fake_mode = detect_fake_mode() + if fake_mode and (shape_env := fake_mode.shape_env): + suppress_pending = shape_env.ignore_fresh_unbacked_symbols() + with disable_above, mode, suppress_pending: + # precondition: The passed in function already handles unflattening inputs + flattening outputs + flat_f_args = pytree.tree_map(_to_fun, flat_args) + flat_f_args_descs = flat_args_descs + flat_f_outs = f(*flat_f_args) + + # Assert that f does NOT have an AOTOutputs in it, easy mistake to + # make! You need to drop the second output before calling this + # function + assert not pytree.tree_any( + lambda x: isinstance(x, AOTOutput), flat_f_outs + ), ( + f"{f} returned AOTOutput when it shouldn't. Did you remember to wrap the " + "function with without_output_descs before passing it here?" + ) + + # NB: this is just to setup the input descriptors, we will + # recreate these descriptors (with the same convention!) when we + # actually do the trace + flat_f_outs_descs = [PlainAOTOutput(i) for i in range(len(flat_f_outs))] + + # We didn't do any tracing, so we don't need to process the + # unbacked symbols, they will just disappear into the ether. + # Also, prevent memoization from applying. + if fake_mode: + fake_mode.epoch += 1 + fake_mode.reset_nt_tensor_id_counter() + + if prior_autocast_states != _get_autocast_states(): + raise RuntimeError( + "AOTAutograd does not support tracing graphs that mutate the autocast state. " + "Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, " + "which will unwind all of their mutations to autocast state before the graph exits. " + "If you encounter this error while using torch.compile, please file a bug." + ) + + # Inspect the state of the input tensor functional wrapper to detect input mutation info + # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version + for arg, f_arg in zip(flat_args, flat_f_args): + # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in + # strides between the functionalized arg inner tensors and non-functionalized arg inner + # tensors. This is a problem as the inner tensor stride change may not be reflected + # correctly in the outer tensor, so disallow this for now. + mutates_data = has_data_mutation(f_arg) + mutates_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=False + ) + if mutates_metadata and is_traceable_wrapper_subclass(arg): + raise RuntimeError( + "Metadata mutations are currently not allowed on tensor subclasses" + ) + mutates_storage_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=True + ) + mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd( + f_arg + ) + mutations_under_no_grad_or_inference_mode = ( + mutates_data + and are_all_mutations_under_no_grad_or_inference_mode(f_arg) + ) + mutation_inductor_storage_resize = was_inductor_storage_resized(f_arg) + + if mutates_storage_metadata: + mutates_data = False + + requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad + + input_info.append( + InputAliasInfo( + is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg), + mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=mutations_hidden_from_autograd, + mutates_storage_metadata=mutates_storage_metadata, + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + mutation_inductor_storage_resize=mutation_inductor_storage_resize, + requires_grad=requires_grad, + keep_input_mutations=keep_input_mutations, + ) + ) + + # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate, + # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view + # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad + # on the base tensor, but we are obligated to properly set requires-gradness on the real output. + + inp_storage_refs = { + StorageWeakRef(inpt.untyped_storage()): idx + for idx, inpt in enumerate(flat_f_args) + if isinstance(inpt, Tensor) + } + + # We need inp tensor id's to be able to tell if an outputs **are** inputs. + inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)} + # We need output tensor id's to tell if any output._base` attributes **are** other outputs. + # (This is also a dict because we need to know that output's index, so we can regenerate + # the alias from it). + out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)} + + # Keep track of which outputs alias other outputs + out_tensor_alias_counts: collections.defaultdict = collections.defaultdict(int) + # This tells us, for a given group of outputs that alias each other, + # whether they e.g. all came from an unbind call + num_aliased_tensors_that_are_multi_output_views: collections.defaultdict = ( + collections.defaultdict(int) + ) + + out_storage_to_metadata_key_to_tensors: collections.defaultdict[ + Optional[StorageWeakRef], + collections.defaultdict[MetadataKey, set[torch.Tensor]], + ] = collections.defaultdict(lambda: collections.defaultdict(set)) + + curr_storage = None + for o in flat_f_outs: + if isinstance(o, torch.Tensor): + curr_storage = StorageWeakRef(o.untyped_storage()) + out_tensor_alias_counts[curr_storage] += 1 + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # This is an optimization on top of the "alias of intermediates" logic, + # which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!] + # + # Before describing the optimization: this is important for AOTAutograd to have good + # perf around, multi-output views. HOWEVER: + # - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case, + # around using pre-dispatch tracing to partition out a graph so we can faithfully replay all + # views without having to regenerate them at runtime. + # - It's loosely described in this doc (more details will be added soon): + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit + # - Once that change lands, we should just rip out this "optimization", since: + # (1) It will be fully unnecessary + # (2) Although it is only a few lines of code, it is a bit difficult to reason about + # its correctness with the autograd engine in all cases. + # + # + # What is this optimization? Consider the below case: + # def f(x): + # intermediate = x.mul(2) + # # x and intermediate here require grad + # o1, o2, ... o10 = intermediate.unbind(-1) + # return intermediate, o1, o2, ... o10 + # Now, the "intermediate base" handling in AOTAutograd implies that we must do the following: + # (1) return "intermediate as an extra output of the compiled graph + # (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function. + # The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know + # that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function, + # this information will be hidden. + # In particular, mutating one alias might require autograd to update autograd metadata on the other aliases + # (like their grad_fn, for example, when the autograd engine needs to do view-replay). + # + # However, intermediate_base logic can be bad for backward performance (we sometimes generate + # as_strided calls during the intermediate base logic, which can have a slow backward formula). + # Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd? + # + # For a set of outputs of the graph that alias each other, o_1...o_k, consider: + # (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0) + # (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate), + # **at most** 1 can escape from the graph (e.g. there is not some other graph input/output + # o_other, that aliases these outputs) + # (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad. + # This condition is important because it's what causes slowness in the intermediate_base + # codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and + # aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn. + # "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward. + # In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta + # of the other aliases? + # + # Claim: No! Consider a few example (which I'm pretty sure cover all cases of mutation w.r.t. autograd): + # (a) What happens if we mutate any of o_1 through o_k directly? + # Autograd raises an error: + # "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is + # the output of a function that returns multiple views. Such functions do not allow the output + # views to be modified inplace. You should replace the inplace operation by an out-of-place one." + # (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)? + # Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views. + # (c) What if we mutate o_k under no_grad? + # Autograd raises the same error + # (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)? + # Autograd allows this, *but* autograd updates all alias's grad_fn's to be error functions when accessed. + # Autograd raises the same error + # (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view? + # We promised that there is at most **one** such alias, e.g. intermediate in the example above. + # You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k + # to be error fn's. + # Since intermediate was the *only* non-multi-output-alias, there are no other aliases + # of `intermediate` around that were produced by the compiled fn and have a valid grad_fn. + # + # Coming back to this optimization: + # Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias + # without causing an error in eager mode, we will simple hide the aliasing from autograd during torch.compile + # if all of the above conditions are met. + # This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on + # in eager but fail to during torch.compile, but it has the benefit that this code has much better performance. + # NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here: + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit, + # then this optimization will probably matter less and might be ok to remove. + is_cur_tensor_multi_out_view = isinstance( + o, FunctionalTensor + ) and torch._functionalize_is_multi_output_view( # type: ignore[attr-defined] + o.elem + ) + if is_cur_tensor_multi_out_view: + num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1 + if o.requires_grad: + out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ].add(o) + + # maps the id of an intermediate base to its index in the output of the compiled forward + intermediate_base_tensor_id_to_output_idx: dict[int, int] = {} + intermediate_bases: list[torch.Tensor] = [] + intermediate_bases_descs: list[AOTInput] = [] + # Why Do We Care If Storage Changed? + # It's important to understand the implications of storage changes in complex scenarios. Take this example: + # + # def f(x): + # x_storage = x.untyped_storage() + # non_leaf_tensor = torch.ones(4, requires_grad=True).clone() + # + # # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation + # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): + # x.set_(non_leaf_tensor.untyped_storage()) + # + # out = x.view(-1) + # + # # Restoring x to its original storage, again simulating .data = operation + # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): + # x.set_(x_storage) + # + # return out + # + # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing. + # However, due to how set_() and more specificlaly, set is functionalized, is defined to preserve eager semantics, + # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'. + # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated, + # which could lead to issues later in the code. + for o, desc in zip(flat_f_outs, flat_f_outs_descs): + functional_tensor_storage_changed = isinstance( + o, FunctionalTensor + ) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined] + o.elem + ) + curr_storage = ( + None + if not isinstance(o, torch.Tensor) + else StorageWeakRef(o.untyped_storage()) + ) + outs_with_identical_metadata_that_require_grad = ( + [] + if not isinstance(o, Tensor) + else [ + curr + for curr in out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ] + if o is not curr + ] + ) + + # See Note [Accessing .grad_fn on FunctionalTensor] + # In-place operations on views will trigger a lazy rebase of the autograd graph; + # this runs during access to the .grad_fn. The rebase logic will invoke view ops + # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure + # these op calls succeed. + grad_fn = None + if isinstance(o, Tensor): + with FunctionalTensorMode(): + grad_fn = o.grad_fn + + is_result_of_custom_autograd_fn = False + # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) + # autograd fns + if type(grad_fn).__name__ == "CppFunction": + is_result_of_custom_autograd_fn = True + if isinstance(grad_fn, torch.autograd.function.BackwardCFunction): + is_result_of_custom_autograd_fn = True + + if not isinstance(o, Tensor): + output_type = OutputType.non_alias + base_idx = None + elif ( + curr_storage in inp_storage_refs + and grad_fn is not None + and is_result_of_custom_autograd_fn + ): + output_type = OutputType.custom_function_view + base_idx = None + elif ( + curr_storage in inp_storage_refs + and not functional_tensor_storage_changed + ): + # pyrefly: ignore [index-error] + base_idx = inp_storage_refs[curr_storage] + is_input_tensor = id(o) in inp_tensor_ids + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + if ( + grad_fn is not None + and num_aliased_outs_that_are_not_multi_output_views == 0 + ): + # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # In particular, given: + # def f(x): + # return list(x.unbind(0)) + # The main reason we ordinarily try to regenerate these output aliases outside of the + # compiled autograd.Function is because if any of the outputs are later mutated, + # autograd needs to perform view-replay to regenerate them. + # However, autograd does not allow users to mutate multi-output views + # in any way that can change the autograd metadata of other aliases. + # So we hide this aliasing from autograd here. + log.debug( + "Encountered AOTAutograd case: differentiable outputs that \ +alias each other from a multi-output view call" + ) + output_type = OutputType.non_alias + elif is_input_tensor: + output_type = OutputType.is_input + else: + output_type = OutputType.alias_of_input + elif functional_tensor_storage_changed and id(o) in inp_tensor_ids: + # When there is a set_() on an input, we cannot rely on checking storages + # to detect if we are returning an input (since the inputs storage is different) + assert curr_storage is not None + base_idx = inp_storage_refs[curr_storage] + output_type = OutputType.is_input + + # We only need to handle the intermediate base case when both + # the intermediate base and the output require gradients. + # See Note [AOT Autograd: outputs aliasing inputs or intermediates!] + elif o._base is not None and o.requires_grad and o._base.requires_grad: + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + if ( + out_tensor_alias_counts[curr_storage] == 1 + or num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + # Note [Intermediate Bases Optimization] + # Normally if we have an output that aliases an intermediate, + # we need to add the extra "intermediate base" logic further down + # to prevent autograd from yelling at us if the user later tries to + # mutate that output. + # However, the common case here is if we have an output that aliases an intermediate, + # but doesn't alias any other outputs. + # In that case, autograd shouldn't have to worry about the aliasing at all + # (if that output is mutated, there are no other live aliases for autograd to worry about). + # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs. + # So as an optimization, we won't do intermediate base handling in this case. + # Instead, we'll hide the aliasing from autograd using aten._unsafe_view(). + if ( + out_tensor_alias_counts[curr_storage] != 1 + and num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + log.debug( + "Encountered AOTAutograd case: differentiable outputs that alias each other \ +from a multi-output view call" + ) + output_type = OutputType.unsafe_view_alias + base_idx = None + else: + # First, check if o's ._base is an existing output + maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None) + if maybe_existing_out_idx is not None: + # Special case where the output is an alias of a graph intermediate, but that intermediate + # is itself also a user output. + output_type = ( + OutputType.alias_of_intermediate_base_is_user_output + ) + base_idx = maybe_existing_out_idx + else: + # Next, check if o's ._base is an intermediate base that we already returned + maybe_existing_base_output_idx = ( + intermediate_base_tensor_id_to_output_idx.get( + id(o._base), None + ) + ) + if maybe_existing_base_output_idx is not None: + output_type = OutputType.alias_of_intermediate + base_idx = maybe_existing_base_output_idx + else: + # Otherwise, take o._base and explicitly return it as an output in the compiled graph + new_out_idx = len(intermediate_bases) + base_idx = new_out_idx + # Indicate to the logic later on (when we trace the joint) + # that this particular output should get it's ._base appended to the forward graph outputs + output_type = ( + OutputType.alias_of_intermediate_save_as_output + ) + intermediate_base_tensor_id_to_output_idx[id(o._base)] = ( + new_out_idx + ) + intermediate_bases.append(o._base) + # NB: The desc we picked here is guaranteed to be + # synchronized with the one in + # graph_capture_wrappers.py because we + # SPECIFICALLY notated this output as + # alias_of_intermediate_save_as_output + intermediate_bases_descs.append( + TangentAOTInput(IntermediateBaseAOTOutput(desc)) + ) + elif ( + # See https://github.com/pytorch/pytorch/issues/100348 for this case. + # This protects against the specific case where a user fn returns (output, output.detach()) + out_tensor_alias_counts[curr_storage] > 1 + and len(outs_with_identical_metadata_that_require_grad) > 0 + and not o.requires_grad + ): + # In theory we could use any of these tensors to regenerate the aliased outputs from, + # since they all alias each other and have identical metadata + out_alias = outs_with_identical_metadata_that_require_grad[0] + existing_out_idx = out_tensor_ids[id(out_alias)] + output_type = OutputType.alias_of_intermediate_base_is_user_output + base_idx = existing_out_idx + else: + output_type = OutputType.non_alias + base_idx = None + + if isinstance(o, torch.Tensor): + dynamic_dims = { + i for i, s in enumerate(o.shape) if not is_concrete_int(s) + } + else: + dynamic_dims = None + + # Save the current FunctionalTensor output. + # + # This will be used at runtime for reconstructing output views from + # their respective base tensors. + # + # The FunctionalTensor will be saved if one of the 2 conditions below + # is true: + view_meta_sequence = None + if ( + # 1. If the output_type is either of: + # (i) alias_of_intermediate; + # (ii) alias_of_intermediate_save_as_output; or + # (iii) alias_of_intermediate_base_is_user_output. + # + # No need to worry about in-place view operations here, since + # this functionalization step elimitates mutations. + # + # i.e. we have access to the actual base tensor, before the + # in-place operation was applied. + output_type + in ( + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + OutputType.alias_of_intermediate_base_is_user_output, + ) + ) or ( + # 2. If the output_type is alias_of_input, and no in-place view + # operationthe was run on the input (base tensor). + # + # In this case, we need to check for metadata mutation because + # the runtime explicitly reconstructs the inputs, before actually + # reconstructing the outputs. Due to in-place view operations, the + # fully reconstructed input may not be this output base tensor + # anymore. + output_type == OutputType.alias_of_input + and base_idx is not None + and not input_info[base_idx].mutates_metadata + ): + if isinstance(o, FunctionalTensor): + view_meta_sequence = ViewMetaSequence(o) + + out_info = OutputAliasInfo( + output_type=output_type, + raw_type=type(o), + base_idx=base_idx, + dynamic_dims=dynamic_dims, + requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, + view_meta_sequence=view_meta_sequence, + ) + output_info.append(out_info) + + # See Note [AOT Autograd: Views to avoid tangents aliasing inputs] + def view_avoid_dupes_with_primals(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + return transform_subclass( + t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t) + ) + if isinstance(t, Tensor): + return t.view(t.shape) + return t + + # This analysis function returns *only* the outputs that are meant to be tangents to the backwards. + # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates) + # are *regenerated* later, and not used directly in the autograd graph + def _plain_fake_tensor_like_subclass(x): + # pyrefly: ignore [bad-context-manager] + with detect_fake_mode(): + return torch.empty( + x.shape, dtype=x.dtype, device=x.device, layout=x.layout + ) + + def _is_subclass_mutated_input_tangent_always_subclass(inp): + return ( + isinstance(inp, torch.nested._internal.nested_tensor.NestedTensor) + or torch._functorch.config.disable_guess_zero_tangent_for_mutated_input_subclass + ) + + f_input_tangents_pairs = [ + # Note: [AOTAutograd Tangent Subclassness for mutated inputs] + # Generally when creating tangents to trace with, we assume that tangents will have + # the same subclass-ness as their forward outs + # however: for tangents that correspond to input mutations, in practice it is more likely + # that these tangents will be plain tensors of zeros at runtime, so we tweak our guess + # to assume that these tangents should always be plaint tensors. + # Example: + # def f(x): + # x.mul_(2) + # return x + 1 + # out = f(x) + # out.sum().backward() + # In the above code, we will have a tangent "x_updated_tangent", + # which will be a plain tensor of zeros, *unless* x is used in some compute after executing f + # + # However, there are exceptions to this logic. If a view is created from mutated input and is used in backward, + # The tangent for this subclass input will be a subclass tensor. + # Example: + # def f(a, b): + # a.mul_(2) + # b.mul_(3) + # return b.view(b.shape), a + b + # a_out, b_out = f(..., Subclass) + # (a * b).sum().backward() + # + # We can not deduce it easily now, so introducing a debug config to be able to turn off this for specific cases. + # NJT guarantees to have its tangent as NJT, because it has dedicated integration in Autograd + # See torch/csrc/autograd/python_function.cpp, use_zeros_like. + ( + ( + _plain_fake_tensor_like_subclass(inp) + if is_traceable_wrapper_subclass(inp) + and not _is_subclass_mutated_input_tangent_always_subclass(inp) + else inp + ), + TangentAOTInput(InputMutationAOTOutput(inp_desc)), + ) + for inp, inp_desc, info in zip(flat_f_args, flat_f_args_descs, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + and info.mutates_data + and info.requires_grad + ] + f_input_tangents, f_input_tangents_descs = ( + [x[0] for x in f_input_tangents_pairs], + [x[1] for x in f_input_tangents_pairs], + ) + + f_output_tangents_pairs = [ + (o, TangentAOTInput(desc)) + for o, info, desc in zip(flat_f_outs, output_info, flat_f_outs_descs) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + f_output_tangents, f_output_tangents_descs = ( + [x[0] for x in f_output_tangents_pairs], + [x[1] for x in f_output_tangents_pairs], + ) + + # intermediate bases are also included in the backward graph + f_tangents = f_input_tangents + f_output_tangents + intermediate_bases + f_tangents_descs = ( + f_input_tangents_descs + f_output_tangents_descs + intermediate_bases_descs + ) + + # TODO: I'm pretty sure you don't need a tree_map here + traced_tangents = pytree.tree_map(from_fun, f_tangents) + traced_tangents = pytree.tree_map( + view_avoid_dupes_with_primals, traced_tangents + ) + traced_tangents = [ + coerce_tangent_and_suggest_memory_format(tt)[0] + for i, tt in enumerate(traced_tangents) + ] + # NB: update this if the maps above ever change structure. + # Also, it might be helpful to add coercion information to the tangent desc! + traced_tangents_descs = f_tangents_descs + + nonlocal static_input_indices + static_input_indices = static_input_indices or [] + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + passed_indices = set(static_input_indices) + static_input_indices = [ + i + for i, arg in enumerate(flat_args) + if (isinstance(arg, torch.nn.Parameter) or i in passed_indices) + ] + + static_input_logger.debug( + "static input indices metadata analysis: %s", static_input_indices + ) + + f_mutated_inputs = [ + inp + for inp, info in zip(flat_f_args, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + f_metadata_mutated_inputs = [ + inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata + ] + # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be. + # When handling subclasses, we need info about **all** outputs of compiled forward graph, + # so we know precisely which graph outputs to wrap back into tensor subclasses + # Ideally we would refactor this so not have an is_train flag, and have the separate + # inference and training paths decide which inputs/output to ask for subclass info on. + # However, we currently stash indexing information on each SubclassMeta about its order + # in the graph outputs list. + f_fw_graph_outs = list(flat_f_outs) + if is_train or not keep_input_mutations: + f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs + else: + # even when "keep_input_mutations" is True, + # we never keep metadata-only mutations in the fw graph + f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs + if is_train: + f_fw_graph_outs = f_fw_graph_outs + intermediate_bases + fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs) + + grad_enabled_mutation = None + if torch.is_grad_enabled() != prior_grad_enabled: + grad_enabled_mutation = torch.is_grad_enabled() + torch.set_grad_enabled( + prior_grad_enabled + ) # Restore the prior state after tracing it + log.debug( + ( + "grad_mode mutation encountered in graph. " + "Will emit mutation epilogue, to set grad_mode=%s" + ), + grad_enabled_mutation, + ) + + metadata = ViewAndMutationMeta( + input_info=input_info, + output_info=output_info, + num_intermediate_bases=len(intermediate_bases), + keep_input_mutations=keep_input_mutations, + traced_tangents=traced_tangents, + traced_tangents_descs=traced_tangents_descs, + subclass_inp_meta=create_subclass_meta(flat_args), + subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), + subclass_tangent_meta=create_subclass_meta( + traced_tangents, count_symints=False, with_memory_format=True + ), + is_train=is_train, + grad_enabled_mutation=grad_enabled_mutation, + static_input_indices=static_input_indices, + tokens=mode._tokens, + ) + return metadata + + return inner diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/descriptors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/descriptors.py new file mode 100644 index 0000000000000000000000000000000000000000..3d480cdf6f9ac66c12c394b0c43fe6e1aacc06c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/descriptors.py @@ -0,0 +1,749 @@ +""" +AOTAutograd descriptors are a path-like data structure (similar to pytree +paths and sources) that describe the semantic meaning of an input/output to FX +graphs. Although you may know the input/output meaning at the top level of +the original function you traced, because we have many graph capture wrappers +that change the calling convention, it can be difficult to tell how these +correspond to the actual FX graph you get back, to say nothing about the extra +arguments/outputs for tangents, gradients, etc. Descriptors describe the meaning +of arguments. + +Examples +-------- + +Before we talk about the precise semantics, it's helpful to look at some +examples to get some intuition for the meaning of descriptors. Here are some +input descriptors you might find on the joint FX graph: + +* PlainAOTInput(idx=0) - the first input from the original callable, as is + +* ParamAOTInput(target="mod.weight") - the parameter with FQN mod.weight + +* TangentAOTInput(output=PlainAOTOutput(idx=1)) - the input tangent + corresponding to the gradients for the second output in the forward graph + +* ViewBaseAOTInput(base_of=PlainAOTInput(idx=0)) - it turned out the first + input was actually a (differentiable) view of a tensor which aliased with + another input tensor. We replaced this input with a single input for the + base of all of these inputs, replacing the original inputs (one of which is + mentioned in base_of). We would generate a GradAOTOutput for *this* input + (and not the original PlainAOTInputs!) If you have a joint graph where a + view base like this is undesirable, you can eliminate this by cloning + the views outside of the compiled region (assuming you aren't mutating this + tensor). + +* SubclassGetAttrAOTInput(base=AOTInput(idx=0), attr="inner") - this tensor + corresponds to the "inner" tensor from the tensor subclass that is at the + first index. In general, joint graphs from AOTAutograd never take tensor + subclasses as inputs; they are always unpacked into their constituent plain + tensor pieces; use the descriptors to identify the parts of the tensor that + are related. Note that this can be nested (if you have nested tensor + subclasses!) + +Here are some output descriptors you might find on the Joint FX graph: + +* PlainAOTOutput(idx=0) - the first output from the original forward function, + as is + +* GradAOTOutput(grad_of=PlainAOTInput(idx=1)) - the computed gradient for the + second input to the graph, an output of the backward graph + +* InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=0)) - when the first + input is mutated, the new value to be copied into the first input of the + graph. Sometimes, these outputs can be elided and the ``copy_`` is done directly + in the graph (controlled by keep_input_mutations), but if the input + mutation must be differentiated through we always generate an output like this + +* IntermediateBaseAOTOutput(base_of=PlainAOTOutput(idx=0)) - if we return + multiple outputs which alias each other, we instead replace them with a single + output tensor representing the base of all the aliases. This output indicates + it is the base for /one/ of those original outputs. If this is undesirable in + the joint graph, clone all outputs before returning from the graph. + +* SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), idx="inner") - this + tensor correspondings to the inner tensor of the first original output which + is a tensor subclass. This and other subclass components of that output will + get repacked into a tensor subclass. + +High level semantics +-------------------- + +OK, let's formally define a descriptor. Intuitively, suppose we have:: + + def wrapped_graph(*args): + ret = graph(*in_transform(args)) + return out_transform(ret) + +Then the descriptor for input[i] to graph describes a function fin_i such that:: + + fin_i(args) == in_transform(args)[i] + +and the descriptor for output[j] from graph describes a function fout_j such that:: + + fout_j(out_transform(ret)) == ret[j] + +AKA input descriptors tell you how to get from outer inputs to inner inputs, +while output descriptors tell you how to get from outer outputs to inner +outputs (inverse data flow!) + +We haven't said anything about what these transformations actually do. There +are three major transformations AOTAutograd does (performed in this order): + +* View/mutation handling +* Autograd +* Subclasses + +So intuitively, descriptors are built like this: + +1. **PlainAOTInput, PlainAOTOutput.** + + We start off descriptors describing the exact inputs/outputs of the + original flattened user function. This user function is assumed to already + be flattened; you would chain on pytree KeyPaths to further describe where + in the pytree each input/output lived if you needed to deal with + unflattened functions: this can be done from userland on top of + descriptors, so the main descriptors mechanism doesn't handle it. + +2. **SyntheticBaseAOTInput, ViewBaseAOTInput, MetadataMutationAOTOutput, + InputMutationAOTOutput, IntermediateBaseAOTOutput** + + We deal with mutations and aliasing by removing duplicate PlainAOTInputs + and introduce some new artificial inputs/outputs. These inputs do not + have a straightforward correspondence to the original user inputs, but if + you are implementing a pass that doesn't care about the exact semantics of + inputs, you should handle all of these uniformly in the same way as regular + inputs. + +3. **TangentAOTInput, GradAOTOutput** + + We deal with autograd by introducing a tangent input for every + differentiable AOTOutput (including the new ones introduced above), and a + gradient output for every differentiable AOTInput (also including new ones + introduced above.) The arguments to these AOTInput/AOTOutput can ONLY be + the ones we already have above (from steps 1-2). As AOTAutograd does not + currently support double backwards, you never have tangents of grads or + vice versa (but in the future we could!) + +4. **SubclassGetAttrAOTInput, SubclassGetAttrAOTOutput, et al.** + + We deal with subclasses by introducing flattened inputs/outputs (including + potentially symbolic sizes/strides) for every AOTInput/AOTOutput that was a + subclass. As above, the arguments to these AOTInput/AOTOutput can ONLY be + the ones we have above (from steps 1-3). Recursive subclasses are + supported, so these descriptors can nest with each other (so descriptors + from step 4 are fair game as well.) + +5. **ForwardTokenAOTInput, ForwardTokenAOTOutput, BackwardTokenAOTInput, BackwardTokenAOTOutput.** + + Some extra token inputs/outputs get added, these are synthetic and are just here to + prevent DCE/reordering. + +The important thing about the pipeline is that descriptors can ONLY be +created from top-to-bottom. So for example, you can have:: + + SubclassGetAttrAOTInput(TangentAOTInput(PlainAOTOutput(...))) # OK + +As you can see that PlainAOTOutput -> TangentAOTInput -> +SubclassGetAttrAOTInput is consistent with the pipeline ordering), but you can +NEVER have:: + + TangentAOTInput(SubclassGetAttrAOTOutput(PlainAOTOutput(...)) # BAD + +This is inconsistent; we always do autograd BEFORE we process subclasses! + +Similarly, for example, this is illegal:: + + GradAOTOutput(SubclassGetAttrAOTInput(PlainAOTInput(...))) # BAD + +It is illegal because subclasses are handled *after* create joint during +wrapper construction. Instead, you would have:: + + SubclassGetAttrAOTOutput(GradAOTOutput(PlainAOTInput(...))) # OK + +This intuitively captures the fact that we always to autograd directly on the +subclass, rather than after desugaring the subclass into its inner tensors. + +Descriptor index +---------------- + +Here is a list of all AOTInput/AOTOutput, organized by how likely you need to +handle them: + +* AOTInput + + * Important: + + * PlainAOTInput (the primals!) + * ParamAOTInput + * TangentAOTInput + * SubclassGetAttrAOTInput et al. (if you use subclasses) + + * View related (can be eliminated by cloning inputs to graph; if you don't + eliminate them, make sure to handle pairing them with GradAOTOutput): + + * ViewBaseAOTInput + * SyntheticBaseAOTInput + + * Non-tensor, mostly just ignore them: + + * DummyAOTInput + * PhiloxForwardSeedAOTInput + * PhiloxForwardBaseOffsetAOTInput + * PhiloxBackwardSeedAOTInput + * PhiloxBackwardBaseOffsetAOTInput + * ForwardTokenAOTInput + * BackwardTokenAOTInput + +* AOTOutput + + * Important: + + * PlainAOTOutput + * GradAOTOutput + * SubclassGetAttrAOTOutput et al. (if you use subclasses) + + * More obscure (if not eliminated, make sure you handle pairing them with + TangentAOTInput): + + * InputMutationAOTOutput (can be eliminated if mutations are non-differentiable) + * IntermediateBaseAOTOutput (can be eliminated by cloning outputs of graph) + * MetadataMutationAOTOutput (uhh, just don't mutate metadata?) + + * Non-tensor, mostly just ignore them: + + * PhiloxUpdatedForwardOffsetAOTOutput + * PhiloxUpdatedBackwardOffsetAOTOutput + * ForwardTokenAOTOutput + * BackwardTokenAOTOutput + * DummyAOTOutput + +For convenience, we also have DifferentiableAOTInput and +DifferentiableAOTOutput to help you classify which inputs/outputs can be +wrapped by GradAOTOutput/TangentAOTInput (respectively), which are essentially +all tensor AOTInput/AOTOutput excluding the subclass descriptors. + +Implementation details +---------------------- + +The stylized view above is good for understanding how to interpret +descriptors, but the way that descriptors are generated in code is a bit more +complicated. Specifically, AOTAutograd is structured as a series of wrappers +on the original user function, which are composed together to form the final +function to trace. As a result of this, AOTAutograd ends up first building +the full AOTInputs for a function to be traced (as it builds the wrappers and +modifies the flat arguments to be compatible with the new input signature of +the wrapper), and then in reverse builds up the AOTOutput as it is tracing. + +There is one major exception to this general idea of "build AOTInput first", +and then "build AOTOutput second": when we create TangentAOTInput, we need to +reference AOTOutputs (which output we are the tangents of) which we generally +haven't created yet. There's two ways we deal with this: + +- After the precompile steps (dedup and synthetic base handling), we do an + initial pass to collect forward metadata that produces the initial set of + PlainAOTOutputs which we use to create the tangent inputs. + +- We also sometimes just violate causality and predict that an AOTOutput will + be created in a particular way at some later point in time when we build an + AOTInput. + +As of July 2025, here is an exhaustive description of how inputs/outputs +traverse the wrappers from AOTAutograd, and what descriptors can be introduced +at these phases. + +:: + + Build wrappers (FLOWS DOWN) Run trace (FLOWS UP) + ------------------------------------------------------------------------------------------------- + Begin PlainAOTInput (n/a) + ParamAOTInput + + Precompile dedupe (remove dupes) (nothing) + + Precompile synthetic base SyntheticBaseAOTInput MetadataMutationAOTOutput + ViewBaseAOTInput + + Forward metadata trace PlainAOTOutput (n/a) + MetadataMutationAOTOutput + + Prepare for autograd (nothing) InputMutationAOTOutput + IntermediateBaseAOTOutput + + Create joint TangentAOTInput GradAOTOutput + w/ InputMutationAOTOutput + w/ IntermediateBaseAOTOutput + + Precompile subclass SubclassGetAttrAOTInput et al. SubclassGetAttrAOTOutput et al. + + Effect tokens ForwardTokenAOTInput ForwardTokenAOTOutput + BackwardTokenAOTInput BackwardTokenAOTOutput + + End (n/a) PlainAOTOutput + +It can be helpful to separately write down the input flow and the output flow +for ease of understanding the data flow: + +* Input desc propagation (happens as we build wrappers) + + * [IN] Begin with original calling convention (PlainAOTInput, ParamAOTInput) + * [IN] Precompile dedupe: (removes duplicate AOTInputs) + * [IN] Precompile synthetic base: SyntheticBaseAOTInput, ViewBaseAOTInput + * Forward metadata trace (mini output desc propagation) + + * [OUT] Original output convention: PlainAOTOutput + * [OUT] Precompile synthetic base: MetadataMutationAOTOutput + + * [IN] Prepare for autograd: (nothing) + * [IN] Create joint: TangentAOTInput (potentially w/ + IntermediateBaseAOTOutput, InputMutationAOTOutput) + * [IN] Precompile subclass: SubclassGetAttrAOTInput et al. + * [IN] Effect tokens: ForwardTokenAOTInput, BackwardTokenAOTInput + (Note: BackwardTokenAOTInput is technically generated not by a wrapper but + actually done by token_discovery which implicitly adds extra arguments + to the FX trace on-the-fly.) + +* Trigger a trace with the modified inputs on the wrapper +* Output desc propagation (happens as we unwind from the user function call in trace) + + * [OUT] Begin with original calling convention: PlainAOTOutput + * [OUT] Effect tokens: ForwardTokenAOTOutput, BackwardTokenAOTOutput + * [OUT] Precompile subclass: SubclassGetAttrAOTOutput et al. + * [OUT] Create joint: GradAOTOutput + * [OUT] Prepare for autograd: InputMutationAOTOutput, IntermediateBaseAOTOutput + * [OUT] Precompile synthetic base: MetadataMutationAOTOutput + * [OUT] Precompile dedupe: (nothing) +""" + +import dataclasses + + +# TODO: the is_* predicates are a little suspicious because (1) they're not +# used by anything and (2) they always report False even when a parameter got +# swizzled into a view base or deduped with a non-parameter. It is pretty +# difficult to exercise these cases but it's not clear if you will write code +# that works correctly in those cases. + + +@dataclasses.dataclass(frozen=True) +class AOTInput: + """Describes where an input from an AOTAutograd produced FX graph comes from""" + + def expr(self) -> str: + raise NotImplementedError("Subclasses must implement expr()") + + def is_param(self) -> bool: + """True if this input is a parameter or derived from a parameter (e.g., subclass attr)""" + return False + + def is_buffer(self) -> bool: + """True if this input is a buffer or derived from a buffer (e.g., subclass attr)""" + return False + + def is_tangent(self) -> bool: + """True if this input is a tangent or derived from a tangent (e.g., subclass attr)""" + return False + + +# Note: Currently, our typing discipline for differentiable versus not is not +# very good, so feel free to rely on runtime tests instead. + + +@dataclasses.dataclass(frozen=True) +class DifferentiableAOTInput(AOTInput): + """A subclass that classifies AOTInput that can be wrapped by GradAOTOutput""" + + +@dataclasses.dataclass(frozen=True) +class AOTOutput: + """Describes where an output from an AOTAutograd produced FX graph will + eventually be bundled into the final output""" + + def expr(self) -> str: + raise NotImplementedError("Subclasses must implement expr()") + + def is_grad(self) -> bool: + """True if this output is a grad or derived from a grad (e.g., subclass attr)""" + return False + + +@dataclasses.dataclass(frozen=True) +class DifferentiableAOTOutput(AOTOutput): + """A subclass that classifies AOTOutput that can be wrapped by TangentAOTInput""" + + +# ------------ + +# AOTInput + +# ------------ + + +@dataclasses.dataclass(frozen=True) +class ParamAOTInput(DifferentiableAOTInput): + """The input is a parameter, whose FQN is target""" + + target: str + + def expr(self) -> str: + return f"self.get_parameter({self.target!r})" + + def is_param(self) -> bool: + return True + + def is_buffer(self) -> bool: + return False + + +@dataclasses.dataclass(frozen=True) +class BufferAOTInput(DifferentiableAOTInput): + """The input is a buffer, whose FQN is target""" + + target: str + + def expr(self) -> str: + return f"self.get_buffer({self.target!r})" + + def is_param(self) -> bool: + return False + + def is_buffer(self) -> bool: + return True + + +@dataclasses.dataclass(frozen=True) +class DummyAOTInput(AOTInput): + """In some circumstances, we want to call into a function that expects AOTInput, but + we don't actually care about that logic (most typically, because some code is being used + for both compile-time and run-time; AOTInput processing is not needed in this situation. + Pass a dummy in this situation; but it is better to just have a version of the function + that doesn't have this at all.""" + + idx: int + + def expr(self) -> str: + return f"__dummy{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class PlainAOTInput(DifferentiableAOTInput): + """The input is a plain input, corresponding to a particular positional index. + + Note that AOTInput is always relative to a function with a *flat* calling convention, + e.g., as accepted by `aot_module_simplified`. There are some AOTAutograd APIs that + flatten pytrees, and we don't record PyTree key paths from the flattening (but we + could and should!) + """ + + idx: int + + def expr(self) -> str: + return f"args[{self.idx}]" + + +@dataclasses.dataclass(frozen=True) +class SubclassGetAttrAOTInput(AOTInput): + """Subclass inputs get unpacked into their constituent pieces before going into an FX + graph. This tells you which particular attribute of the subclass this particular + input corresponds to (of the 'base' originally subclass argument.) + """ + + base: AOTInput + attr: str + + def expr(self) -> str: + return f"{self.base.expr()}.{self.attr}" + + def is_param(self) -> bool: + return self.base.is_param() + + def is_buffer(self) -> bool: + return self.base.is_buffer() + + def is_tangent(self) -> bool: + return self.base.is_tangent() + + +@dataclasses.dataclass(frozen=True) +class SubclassSizeAOTInput(AOTInput): + """Which subclass this particular outer size SymInt input (at dim idx) came from.""" + + base: AOTInput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.size({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class SubclassStrideAOTInput(AOTInput): + """Which subclass this particular outer stride SymInt input (at dim idx) came from.""" + + base: AOTInput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.stride({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class ViewBaseAOTInput(DifferentiableAOTInput): + """ + When multiple differentiable inputs are views of the same input, AOTAutograd will replace all of these + views with a single input representing the base. If this is undesirable, you can clone the views + example inputs before passing them into AOTAutograd. + + TODO: In principle we could report ALL of the inputs who this is a base of. + """ + + base_of: AOTInput + + def expr(self) -> str: + return f"{self.base_of.expr()}._base" + + +@dataclasses.dataclass(frozen=True) +class SyntheticBaseAOTInput(DifferentiableAOTInput): + """This is similar to ViewBaseAOTInput, but this happens when none of the views were differentiable, so + we weren't able to get our hands on the true original view and constructed a synthetic one instead + for the sake of autograd. + """ + + base_of: AOTInput + + def expr(self) -> str: + return f"__make_synthetic_base({self.base_of.expr()})" + + +@dataclasses.dataclass(frozen=True) +class PhiloxForwardSeedAOTInput(AOTInput): + """The seed for functionalized Philox RNG calls, specifically for forward graph.""" + + def expr(self) -> str: + return "__philox_forward_seed" + + +@dataclasses.dataclass(frozen=True) +class PhiloxForwardBaseOffsetAOTInput(AOTInput): + """The offset for functionalized Philox RNG calls, specifically for forward graph.""" + + def expr(self) -> str: + return "__philox_forward_base_offset" + + +@dataclasses.dataclass(frozen=True) +class PhiloxBackwardSeedAOTInput(AOTInput): + """The seed for functionalized Philox RNG calls, specifically for backward graph.""" + + def expr(self) -> str: + return "__philox_backward_seed" + + +@dataclasses.dataclass(frozen=True) +class PhiloxBackwardBaseOffsetAOTInput(AOTInput): + """The offset for functionalized Philox RNG calls, specifically for backward graph.""" + + def expr(self) -> str: + return "__philox_backward_base_offset" + + +@dataclasses.dataclass(frozen=True) +class ForwardTokenAOTInput(AOTInput): + """The world token which is threaded through side-effectful operations""" + + idx: int + + def expr(self) -> str: + return f"__forward_token{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class BackwardTokenAOTInput(AOTInput): + """The world token which is threaded through side-effectful operations, for backwards""" + + idx: int + + def expr(self) -> str: + return f"__backward_token{self.idx}" + + +# Technically the "output" here is redundant, tangents always correspond to +# outputs +# NB: this is marked differentiable as it /would/ be differentiable if we +# support double backwards, but we never generate this today because we +# don't support double backwards. +@dataclasses.dataclass(frozen=True) +class TangentAOTInput(DifferentiableAOTInput): + """An input to the joint graph representing the tangent of an output.""" + + output: DifferentiableAOTOutput + + def __post_init__(self) -> None: + assert isinstance(self.output, DifferentiableAOTOutput) + + def expr(self) -> str: + return f"__output_tangent({self.output.expr()})" + + def is_tangent(self) -> bool: + return True + + +# ------------ + +# AOTOutput + +# ------------ + + +@dataclasses.dataclass(frozen=True) +class PlainAOTOutput(DifferentiableAOTOutput): + """A plain tensor output at position idx of the output tuple""" + + idx: int + + def expr(self) -> str: + return f"output[{self.idx}]" + + +@dataclasses.dataclass(frozen=True) +class InputMutationAOTOutput(DifferentiableAOTOutput): + """The mutated value of an input tensor, returned so we can appropriately propagate autograd.""" + + mutated_input: AOTInput + + def expr(self) -> str: + return f"__input_mutation({self.mutated_input.expr()})" + + +@dataclasses.dataclass(frozen=True) +class IntermediateBaseAOTOutput(DifferentiableAOTOutput): + """An intermediate base of multiple outputs which alias each other. We only report ONE of + the outputs that contributed to this base""" + + base_of: "AOTOutput" + + def expr(self) -> str: + return f"__intermediate_base({self.base_of.expr()})" + + +# TODO: it's a little dodgy this is differentiable lol, but we do generate +# these BEFORE autograd is handled +@dataclasses.dataclass(frozen=True) +class MetadataMutationAOTOutput(DifferentiableAOTOutput): + idx: int + + def expr(self) -> str: + return f"__aliased_arg_with_metadata_mutation{self.idx}" + + +# NB: this is marked differentiable as it /would/ be differentiable if we +# support double backwards, but we never generate this today because we +# don't support double backwards. +@dataclasses.dataclass(frozen=True) +class GradAOTOutput(DifferentiableAOTOutput): + """An output representing the computed gradient for a differentiable input, in the joint graph""" + + grad_of: DifferentiableAOTInput + + def __post_init__(self) -> None: + assert isinstance(self.grad_of, DifferentiableAOTInput) + + def expr(self) -> str: + return f"__grad({self.grad_of.expr()})" + + def is_grad(self) -> bool: + return True + + +@dataclasses.dataclass(frozen=True) +class PhiloxUpdatedForwardOffsetAOTOutput(AOTOutput): + """The final offset from the functionalized RNG calls, forward only""" + + def expr(self) -> str: + return "__philox_updated_forward_offset" + + +@dataclasses.dataclass(frozen=True) +class PhiloxUpdatedBackwardOffsetAOTOutput(AOTOutput): + """The final offset from the functionalized RNG calls, backward only""" + + def expr(self) -> str: + return "__philox_updated_backward_offset" + + +@dataclasses.dataclass(frozen=True) +class ForwardTokenAOTOutput(AOTOutput): + """The world token output for side-effectful calls, returned so we cannot DCE it, forward only""" + + idx: int + + def expr(self) -> str: + return f"__forward_token{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class BackwardTokenAOTOutput(AOTOutput): + """The world token output for side-effectful calls, returned so we cannot DCE it, backward only""" + + idx: int + + def expr(self) -> str: + return f"__backward_token{self.idx}" + + +# These are seemingly symmetric with their AOTInput counterparts. The way to +# think about it is that a subclass could be an input or an output, and they +# get exploded into plain tensors on the way in and out. So we need +# descriptors for both. +@dataclasses.dataclass(frozen=True) +class SubclassGetAttrAOTOutput(AOTOutput): + """This output will be bundled into a subclass at this location""" + + base: AOTOutput + attr: str + + def expr(self) -> str: + return f"{self.base.expr()}.{self.attr}" + + def is_grad(self) -> bool: + return self.base.is_grad() + + +@dataclasses.dataclass(frozen=True) +class SubclassSizeAOTOutput(AOTOutput): + """This output size will be bundled into a subclass at this location""" + + base: AOTOutput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.size({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class SubclassStrideAOTOutput(AOTOutput): + """This output stride will be bundled into a subclass at this location""" + + base: AOTOutput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.stride({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class DummyAOTOutput(AOTOutput): + """For cases when you don't actually care about descriptor propagation, do not use under normal + circumstances.""" + + idx: int + + def expr(self) -> str: + return f"__dummy{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class SavedForBackwardsAOTOutput(AOTOutput): + idx: int + + def expr(self) -> str: + return f"__saved_for_backwards_{self.idx}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/functional_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/functional_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5af4fc9ee11955b4e6151f9602793c9076c48387 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/functional_utils.py @@ -0,0 +1,548 @@ +# mypy: allow-untyped-defs +""" +This file contains utilities related to functionalization in AOTAutograd: +1. converting to/from functional tensors +2. detecting Tensor mutations - both metadata and Tensor value +3. regenerating/replaying views from their base +4. checking if a graph is functional i.e. whether it contains any mutation ops +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor +from torch._C import _functionalization +from torch._logging import getArtifactLogger +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq, SymIntEqByExpr +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + + +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") + + +def to_fun(t): + if isinstance(t, Tensor): + if is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + else: + return FunctionalTensor.to_functional(t) + else: + return t + + +def sync_functional_tensor(t): + if is_traceable_wrapper_subclass(t): + attrs, _ctx = t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + sync_functional_tensor(getattr(t, attr)) + else: + torch._sync(t) + + +# When subclasses are involved, t here will usually look something like: +# SubclassA(SubclassB(FunctionalTensor(_to_fun_tensor(FakeTensor)))) +def from_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) # type: ignore[attr-defined] + return t + sync_functional_tensor(t) + return torch._from_functional_tensor(t.elem) + + +def is_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + t_attrs, _ = t.__tensor_flatten__() # type: ignore[attr-defined] + t_inners = [getattr(t, attr) for attr in t_attrs] + any_fun = any(is_fun(x) for x in t_inners) + all_fun = all(is_fun(x) for x in t_inners) + assert any_fun == all_fun + return any_fun + + return isinstance(t, FunctionalTensor) + + +# t here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +def has_data_mutation(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + return any(has_data_mutation(getattr(t, attr)) for attr in attrs) + else: + if isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_has_data_mutation(t.elem) # type: ignore[attr-defined] + return False + + +def are_all_mutations_hidden_from_autograd(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd. + return all( + are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs + ) + elif isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem) + else: + return False + + +def are_all_mutations_under_no_grad_or_inference_mode(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + return all( + are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr)) + for attr in attrs + ) + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode( + t.elem + ) + + +def was_inductor_storage_resized(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + if any(was_inductor_storage_resized(getattr(t, attr)) for attr in attrs): + raise RuntimeError( + f"storage resizing is not supported on tensor subclass: {type(t)}" + ) + elif not isinstance(t, torch.Tensor): + return False + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_was_inductor_storage_resized(t.elem) + + +# f_arg here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +# Assumption: arg promises to be the "original" tensor wrapped by f_arg +# Note: "storage mutations" coming from set_() are a type of metadata mutation. So: +# - check_only_storage_mutation=True: only return true if there was a storage mutation +# - check_only_storage_mutation=Flse: return true if there was any metadata mutation (including a storage mutation) +def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool): + if is_traceable_wrapper_subclass(f_arg): + attrs, _ = f_arg.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + f_inner_ts = [getattr(f_arg, attr) for attr in attrs] + inner_ts = [getattr(arg, attr) for attr in attrs] + return any( + has_metadata_mutation( + f_inner_t, + inner_t, + check_only_storage_mutation=check_only_storage_mutation, + ) + for f_inner_t, inner_t in zip(f_inner_ts, inner_ts) + ) + else: + if not isinstance(f_arg, torch.Tensor): + assert not isinstance(arg, torch.Tensor) + return False + assert isinstance(f_arg, FunctionalTensor) + assert isinstance(arg, FakeTensor) + + arg_after = torch._from_functional_tensor(f_arg.elem) + # This is true if the current tensor experienced at least one set_() call + maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem) # type: ignore[attr-defined] + # However, multiple set_() calls can cancel out. So we also check whether the + # storage of the tensor has changed. + # Note: if an input experienced two set_() calls that cancel out, **and** + # it experiences an data mutation, we pessimistically think that the set_() + # call is necessary here. We could in theory fix this, but this will + # hopefully never happen in user code, and is not needed for fsdp. + if is_sparse_any(arg): + # TODO:add sparse tensors support to functionalization + same_storages = False + else: + same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef( + arg_after.untyped_storage() + ) + has_storage_metadata_mutation = maybe_storage_changed and not same_storages + if check_only_storage_mutation: + return has_storage_metadata_mutation + + # storage metadata mutation is a type of metadata mutation, so return true if we saw one + if has_storage_metadata_mutation: + return True + + maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem) # type: ignore[attr-defined] + # This is true if the current tensor experienced at least one metadata mutation. + # So if false, we know there was no metadata mutation + if not maybe_metadata_mutated: + return False + + # However, multi metadata mutations can cancel out. + # So we also check if the concrete sizes/strides on the tensor have changed. + same_sizes = arg.shape == arg_after.shape + same_strides = arg.stride() == arg_after.stride() + same_offsets = arg.storage_offset() == arg_after.storage_offset() + has_metadata_mutation_ = maybe_metadata_mutated and not ( + same_sizes and same_strides and same_offsets + ) + # We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call. + return has_metadata_mutation_ + + +def gen_alias_from_base( + aliased_base_tensor, + target_meta_tensor, + target_requires_grad, + target_view_meta_sequence: ViewMetaSequence | None = None, + *, + replay_views: bool, +): + # Patch the correct requires_grad field of the output tensor, depending on whether: + # (i) the reconstructed output (out) was came from a tensor that requires grad or not; + # and (ii) the concrete returned output does require grad or not. + def patch_requires_grad(out): + if aliased_base_tensor.requires_grad and not target_requires_grad: + out = out.detach() + elif not aliased_base_tensor.requires_grad and target_requires_grad: + out.requires_grad_(True) + return out + + # If provided, use the target functional tensor for replaying the views. + # + # In summary, we use the fact that FunctionalTensorWrapper saves the view + # functions applied to itself (collected during functionalization) so as + # to replay them (view functions) on the aliased_base_tensor. + if ( + replay_views + and target_view_meta_sequence is not None + and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence) + ): + out = _functionalization.apply_view_meta_sequence( + aliased_base_tensor, target_view_meta_sequence.sequence + ) + # If re-applying the ViewMeta sequence succeeded, there should be no more + # problems going forward. We just check we got to the target shape and + # patch requires_grad flag. + assert out.shape == target_meta_tensor.shape, ( + "incorrect out shape after application of ViewMeta sequence: " + f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)" + ) + return patch_requires_grad(out) + + # Try to do view-replay if possible. + # fall back to .as_strided() if we can't. + if target_meta_tensor._base is not None: + # The base that we want to replay our view off of might have a different shape than the view's original base. + b = target_meta_tensor._base + abt = aliased_base_tensor + # Don't unnecessarily call as_strided if nothing changed; as_strided's + # backward is poorly implemented and slow + if abt is not b and ( + abt.size() != b.size() + or abt.stride() != b.stride() + or abt.storage_offset() != b.storage_offset() + ): + reshaped_base_tensor = aliased_base_tensor.as_strided( + b.size(), b.stride(), b.storage_offset() + ) + else: + reshaped_base_tensor = aliased_base_tensor + out = target_meta_tensor._view_func(reshaped_base_tensor) + # This shape mismatch can happen due to a bug in inplace/view handling in autograd. + # Try putting a breakpoint here and running + # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types` + # Also, https://github.com/pytorch/pytorch/issues/49825 + # + # As a stopgap, we'll fall back to as_strided. + if out is not None and out.shape == target_meta_tensor.shape: + return patch_requires_grad(out) + + size = target_meta_tensor.size() + stride = target_meta_tensor.stride() + storage_offset = target_meta_tensor.storage_offset() + if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex(): + aliased_out = torch.view_as_real(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex(): + aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + else: + aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset) + # For outputs aliasing inputs, we need to check if the requires-gradness has changed. + aliased_out = patch_requires_grad(aliased_out) + # For outputs aliasing inputs, we need to check if the dtype has changed. + # as_strided() is the "most generic" view, but it does not cover cross-dtype views + if aliased_out.dtype != target_meta_tensor.dtype: + aliased_out = aliased_out.view(target_meta_tensor.dtype) + return aliased_out + + +def has_same_metadata(t1, t2): + return ( + guard_or_false(sym_eq(t1.size(), t2.size())) + and guard_or_false(t1.layout == t2.layout) + and ( + is_sparse_any(t1) + or ( + guard_or_false(sym_eq(t1.stride(), t2.stride())) + and guard_or_false(t1.storage_offset() == t2.storage_offset()) + ) + ) + and t1.is_conj() == t2.is_conj() + and t1.is_neg() == t2.is_neg() + ) + + +@dataclass(frozen=True) +class MetadataKey: + """ + This should be equal whenever has_same_metadata would return True + """ + + size: tuple[SymIntEqByExpr, ...] + layout: torch.layout + is_sparse: bool + # these are empty when is_sparse + stride: tuple[SymIntEqByExpr, ...] | None + storage_offset: SymIntEqByExpr | None + is_conj: bool + is_neg: bool + + @staticmethod + def make(t): + is_sparse = is_sparse_any(t) + return MetadataKey( + size=tuple(SymIntEqByExpr(s) for s in t.size()), + layout=t.layout, + is_sparse=is_sparse, + stride=None if is_sparse else tuple(SymIntEqByExpr(s) for s in t.stride()), + storage_offset=None if is_sparse else SymIntEqByExpr(t.storage_offset()), + is_conj=t.is_conj(), + is_neg=t.is_neg(), + ) + + +# ViewMeta sequence wrapper for equality comparisons. +# +# Even though we can compare each ViewMeta instance, we compare the resulting +# tensor metadata, instead. That's because the creation of synthetic bases + the +# re-generation of input views might end-up creating a different sequence of +# ViewMeta that is semantically equivalent. i.e. gets to a tensor with the same +# metadata. +# +# Therefore, we store what the end result should look like as serializable +# metadata. +# +# When logging, this class should look like: +# +# ViewMetaSequence(view, select_int, slice_Tensor) +# +# i.e. a parenthesized list of view operations within that ViewMeta sequence. +class ViewMetaSequence: + def __init__(self, tensor: FunctionalTensor) -> None: + assert torch._is_functional_tensor(tensor.elem) + self.sequence = _functionalization.get_view_meta_sequence(tensor.elem) + self.metadata = MetadataKey.make(tensor) + + def __repr__(self) -> str: + suffix = len("_ViewMeta") + types = ", ".join(type(vm).__name__[:-suffix] for vm in self.sequence) + return f"ViewMetaSequence({types})" + + def __eq__(self, other: object) -> bool: + # If other is None, then it probably means that we weren't able to recreate + # the ViewMeta sequence. One example is when we update the view metadata by + # calling: create_synthetic_base_metadata. + if other is None: + return True + + # Comparison against any other type is not implemented. + if not isinstance(other, ViewMetaSequence): + return NotImplemented + + return self.metadata == other.metadata + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. +# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed +# +# Normally it would be enough just to check if arg is new_arg, which is normally enough for functionalization +# to confirm that inputs were not mutated when running the user's model with functionalization on. +# But when we have subclass inputs, we can't rely on that: +# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs +# a brand new subclass instance: we are calling __tensor_unflatten__, and going +# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor)) +def was_tensor_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. +# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed, +# but shares storage with the old input +def was_tensor_metadata_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg and StorageWeakRef( + arg.untyped_storage() + ) == StorageWeakRef(new_arg.untyped_storage()) + + +# Returns the number of detected copy_ +def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: + allowed_mutation_ops = [ + torch.ops.aten.copy_.default, + torch.ops.aten.set_.source_Tensor, + ] + if hasattr(torch.ops.fsdp, "copy_"): + allowed_mutation_ops.append(torch.ops.fsdp.copy_.default) + + placeholders = set() + mutation_count = 0 + # NB: It would also be nice to verify that the mutations all happen at the + # end, but we also do some administrative views after mutations so this + # isn't actually true. (TODO: Could this cause problems for Inductor?) + error = None + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target in allowed_mutation_ops: + # Can only copy_/set_ into an input + # this is mostly a hack to avoid failing XLA tests. + # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 + if "set_buffer_donor_" not in str(n.args[0]): + if n.args[0] not in placeholders: + error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + mutation_count += 1 + else: + if n.target._schema.is_mutable: + error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return error, mutation_count + + +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + error, mutation_count = _is_functional_graph(fx_g) + assert error is None, error + return mutation_count + + +def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None: + placeholders = set() + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target is torch.ops.aten.copy_.default: + # Can only copy_ into an input, and can only do so once + if "set_buffer_donor_" not in str(n.args[0]): + assert n.args[0] in placeholders, ( + f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + ) + placeholders.remove(n.args[0]) + copy_from_node = n.args[1] + # Pre-condition: every node has a "stack_trace" field in its meta, + # but copy_() nodes do not (since we manually added them during functionalization). + # Instead, we manually propagate here. + if "stack_trace" in copy_from_node.meta: + n.meta["stack_trace"] = copy_from_node.meta["stack_trace"] + + +def _check_if_mutation_can_be_in_graph( + keep_input_mutations: bool, + mutates_data, + mutates_metadata, + mutations_hidden_from_autograd, + mutations_under_no_grad_or_inference_mode, + mutates_storage_metadata, + mutation_inductor_storage_resize, + requires_grad, +): + if keep_input_mutations: + in_graph = ( + mutates_data or mutates_storage_metadata or mutation_inductor_storage_resize + ) and ( + (not mutates_metadata and not requires_grad) + or mutations_hidden_from_autograd + or mutations_under_no_grad_or_inference_mode + ) + else: + in_graph = False + # See Note [set_() Input Mutations in AOTAutograd] + # If there was a `set_()`, we require that all mutations were under no_grad, + # so we can (safely) emit the set_() in the graph at runtime + # resize_() gets the same treatment + if mutation_inductor_storage_resize or mutates_storage_metadata: + op_name = "resize_" if mutation_inductor_storage_resize else "set_" + assert in_graph, f"""\ +Encountered a {op_name} on a graph input, but the input has other mutations that we cannot +keep in the graph. This is not supported today. Current state: + keep_input_mutations={keep_input_mutations} + mutates_data={mutates_data} + mutates_metadata={mutates_metadata} + mutations_hidden_from_autograd={mutations_hidden_from_autograd} + mutations_under_no_grad_or_inference_mode={mutations_under_no_grad_or_inference_mode} + mutation_inductor_storage_resize={mutation_inductor_storage_resize} + requires_grad={requires_grad}""" + return in_graph diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..7dceaee3dacb23e9fa7d83e8b200628d2d1a71e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py @@ -0,0 +1,506 @@ +# mypy: allow-untyped-defs +""" +This module dispatches the graphs to either the forward-only or joint compilation +pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. +""" + +import contextlib +import dataclasses +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torchgen.utils import dataclass_repr + +from .. import config +from .descriptors import AOTInput, BackwardTokenAOTInput +from .functional_utils import ( + assert_functional_graph, + propagate_input_mutation_stacktraces, +) +from .graph_capture_wrappers import ( + aot_dispatch_subclass, + create_functionalized_fn, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, + handle_effect_tokens_fn, +) +from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta +from .streams import assign_backward_streams, insert_backward_syncs, sync_deallocations +from .utils import ( + call_and_expect_output_descs, + copy_fwd_metadata_to_bw_nodes, + fn_wrappers, + register_buffer_assignment_hook, + root_module_when_exporting_non_strict, + simple_wraps, + unlift_tokens, +) + + +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + + +def _create_graph( + f, + args: list[torch.Tensor], + args_descs: Optional[ + list[AOTInput] + ] = None, # keep compat with old clients; maybe we should split into two impls + *, + aot_config: AOTConfig, +) -> torch.fx.GraphModule: + # FunctionalTensorMode must be enabled here. + # See Note [Accessing .grad_fn on FunctionalTensor] + out_descs = None + + if args_descs is None: + inner_f = f + else: + + @simple_wraps(f) + def inner_f(*args): + nonlocal out_descs + assert out_descs is None + out, out_descs = call_and_expect_output_descs(f, args) + return out + + if aot_config.disable_functionalization: + ctx = contextlib.nullcontext() + else: + ctx = FunctionalTensorMode( # type: ignore[assignment] + pre_dispatch=aot_config.pre_dispatch, + export=aot_config.is_export, + # Allow token discovery for joint fn tracing as tokens can be used in backward. + _allow_token_discovery=True, + ) + + with ( + enable_python_dispatcher(), + ctx, + ): + fx_g = make_fx( + inner_f, + decomposition_table=aot_config.decompositions, + record_module_stack=True, + pre_dispatch=aot_config.pre_dispatch, + )(*args) + + if args_descs is not None: + flat_args_descs, _ = pytree.tree_flatten(args_descs) + flat_out_descs, _ = pytree.tree_flatten(out_descs) + + # Unfortunately, flat_args_descs is not guaranteed to match the + # number of actual arguments that show up on the FX graph. + # Specifically, allow_token_discovery=True means that we will + # silently add extra token arguments to the backwards graph. + # + # Although there are a few ways to detect what these tokens are, + # we are going to settle for something dodgy but simple to + # implement: match tangents_token placeholders specifically, + # as these are the only placeholders that are created by token + # discovery (NB: there is NO other code that treats this name + # as load bearing, so this is a bit naughty!) + # + # I originally wanted to detect tokens in exactly the same way + # that they are detected at normal runtime, but to be honest + # the normal runtime detection is pretty strange: it seems the + # backward tokens are not reliably at the end of the argument list + # but *precede* the RNG arguments (I don't understand why this is + # the case). And in unlift_tokens, token arguments are detected + # by seeing if they feed into an effects call! Dastardly. Why + # didn't we just introduce a new type. + + i = 0 + j = 0 + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if n.name.startswith("tangents_token"): + n.meta["desc"] = BackwardTokenAOTInput(j) + j += 1 + else: + assert i < len(flat_args_descs), ( + (fn_wrappers(inner_f)), + [n for n in fx_g.graph.nodes if n.op == "placeholder"], + flat_args_descs, + ) + n.meta["desc"] = flat_args_descs[i] + i += 1 + elif n.op == "output": + n.meta["desc"] = flat_out_descs + + return fx_g + + +# TODO: Refactor the following code so detach() persists item_memo +def _detach_and_copy_item_memo(t): + detached_t = t.detach() + if hasattr(t, "item_memo"): + detached_t.item_memo = t.item_memo + return detached_t + + +def aot_dispatch_base_graph( + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[torch.fx.GraphModule, list[FxValue], list[AOTInput], Optional[SubclassMeta]]: + # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case. + # The cases that aot_dispatch_base doesn't need to handle include: + # - outputs that are aliases of graph intermediates + # - outputs that are aliases of graph inputs + # While cases that it does need to handle include: + # - input mutations (including when inputs are aliases of each other) + # - input metadata mutations + fn_to_trace = fn_input_mutations_to_outputs( + flat_fn, + flat_args_descs, + fw_metadata, + keep_data_input_mutations=aot_config.keep_inference_input_mutations, + ) + + if aot_config.disable_functionalization: + updated_flat_args, updated_flat_args_descs = ( + flat_args, + flat_args_descs, + ) + else: + fn_to_trace, updated_flat_args, updated_flat_args_descs = ( + create_functionalized_fn( + fn_to_trace, + flat_args, + flat_args_descs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=False, + ) + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + maybe_subclass_meta, + ) = aot_dispatch_subclass( + fn_to_trace, + updated_flat_args, + updated_flat_args_descs, + is_joint_structure=False, + meta=fw_metadata, + fw_only=flat_fn, + ) + + if not aot_config.disable_functionalization: + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + ) = handle_effect_tokens_fn( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + meta=fw_metadata, + trace_joint=False, + ) + + aot_graphs_log.debug( + "aot_config id: %s, fw_metadata=%s,subclass_metadata=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(maybe_subclass_meta), + ) + + # We track buffer assignments when exporting in non-strict mode. + # (In contrast, strict mode errors on any attribute assignment.) + mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn) + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # For any buffer that is assigned, we want to associate it to the final proxy node + # that it is assigned to. This node can then be added as a buffer mutation output. + assigned_buffers: dict[str, str] = {} + hook = register_buffer_assignment_hook( + mod_when_exporting_non_strict, assigned_buffers + ) + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, + _detach_and_copy_item_memo, + updated_flat_args_subclasses_desugared, + ) + else: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared + ) + saved_updated_flat_args_subclasses_desugared_descs = ( + updated_flat_args_subclasses_desugared_descs + ) + + fw_module = _create_graph( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + aot_config=aot_config, + ) + + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # We update metadata to consider any assigned buffers as buffer mutations. + i = len(dict(mod_when_exporting_non_strict.named_parameters())) + for name, _ in mod_when_exporting_non_strict.named_buffers(): + if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data: # type: ignore[possibly-undefined] + fw_metadata.input_info[i] = dataclasses.replace( + fw_metadata.input_info[i], mutates_data=True + ) + fw_metadata.num_mutated_inp_runtime_indices += 1 + i += 1 + + # We add nodes corresponding to buffer assignments as output nodes in the graph. + add_nodes = [] + output_node = list(fw_module.graph.nodes)[-1] + for name in assigned_buffers.values(): # type: ignore[possibly-undefined] + for node in fw_module.graph.nodes: + if node.name == name: + add_nodes.append(node) + node.users[output_node] = None + output_node.args = ((*add_nodes, *output_node.args[0]),) + + hook.remove() # type: ignore[possibly-undefined] + + # As long as we opted to remove input mutations, then + # there should be *NO* mutating ops in the graph at this point. + if not aot_config.disable_functionalization: + copy_count = assert_functional_graph(fw_module.graph) + fw_module.graph.eliminate_dead_code() + fw_module.recompile() + copy_count2 = assert_functional_graph(fw_module.graph) + propagate_input_mutation_stacktraces(fw_module.graph) + assert copy_count == copy_count2 + else: + fw_module.graph.eliminate_dead_code() + + # See Note [Side-Effectful Tokens in AOTAutograd] + num_tokens = len(fw_metadata.tokens) + if num_tokens != 0 and config.unlift_effect_tokens: + unlift_tokens(fw_module, fw_metadata, aot_config) + saved_updated_flat_args_subclasses_desugared = ( + saved_updated_flat_args_subclasses_desugared[num_tokens:] + ) + saved_updated_flat_args_subclasses_desugared_descs = ( + saved_updated_flat_args_subclasses_desugared_descs[num_tokens:] + ) + + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_inference_graph", + payload_fn=lambda: fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ), + ) + + # TODO: should factor this into a separate function for export that always only returns just the graph. + if aot_config.is_export: + assert maybe_subclass_meta is None, ( + "aot_export_module does not support tensor subclass inputs for now." + ) + return ( + fw_module, + saved_updated_flat_args_subclasses_desugared, + saved_updated_flat_args_subclasses_desugared_descs, + maybe_subclass_meta, + ) + + +# Has the precondition that there +# are no duplicate arguments in flat_args (e.g., the same Tensor +# object never shows up twice. However, two tensor inputs MAY alias +# the same storage, so long as they have separate TensorImpls.) +def aot_dispatch_autograd_graph( + flat_fn: TraceFn, + flat_args: list[Any], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[ + torch.fx.GraphModule, + tuple[list[Any], list[Any]], + tuple[list[AOTInput], list[AOTInput]], + Optional[SubclassMeta], +]: + # NB: flat_fn here is the original user function (as far as + # aot_module_simplified is concerned) + + # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward. + # It includes outputs of the original forward, *and* any updated inputs due to input mutations. + # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations. + joint_inputs = (flat_args, fw_metadata.traced_tangents) + joint_inputs_descs = (flat_args_descs, fw_metadata.traced_tangents_descs) + + fn_prepared_for_autograd = fn_prepped_for_autograd( + flat_fn, + flat_args_descs, + fw_metadata, + aot_config, + ) + joint_fn_to_trace = create_joint( + fn_prepared_for_autograd, flat_args_descs, aot_config=aot_config + ) + joint_fn_handle = joint_fn_to_trace.handle + + if aot_config.disable_functionalization: + updated_joint_inputs, updated_joint_inputs_descs = ( + joint_inputs, + joint_inputs_descs, + ) + else: + joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = ( + create_functionalized_fn( + joint_fn_to_trace, + joint_inputs, + joint_inputs_descs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=True, + joint_fn_handle=joint_fn_handle, + ) + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + subclass_tracing_info = aot_dispatch_subclass( + joint_fn_to_trace, + updated_joint_inputs, + updated_joint_inputs_descs, + is_joint_structure=True, + meta=fw_metadata, + fw_only=flat_fn, + ) + + joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn + updated_joint_inputs = subclass_tracing_info.plain_tensor_args + updated_joint_inputs_descs = subclass_tracing_info.plain_tensor_args_descs + + if not aot_config.disable_functionalization: + (joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = ( + handle_effect_tokens_fn( + joint_fn_to_trace, + updated_joint_inputs, + updated_joint_inputs_descs, + meta=fw_metadata, + trace_joint=True, + ) + ) + + # When we call _create_graph, this may mutate the metadata of joint + # inputs. But callers are expecting to get the original joint inputs. So + # we make aliases of all the inputs to make sure we have a copy that + # doesn't get modified. + # + # This destroys requires_grad/grad_fn information. However, backends + # beneath AOTAutograd are indifferent to this information, so it doesn't + # matter. + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, _detach_and_copy_item_memo, updated_joint_inputs + ) + else: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_joint_inputs + ) + maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta + + fx_g = _create_graph( + joint_fn_to_trace, + updated_joint_inputs, + updated_joint_inputs_descs, + aot_config=aot_config, + ) + + # Redundant with the check above, but worth having in case tracing introduced + # a fake tensor. Unlikely. + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g) + + # Have to copy before eliminate_dead_code otherwise the + # fw node match might be erased + copy_fwd_metadata_to_bw_nodes(fx_g) + + # After copying metadata, assign streams to gradient accumulation nodes + assign_backward_streams(fx_g) + + # Insert syncs for newly assigned backward streams + insert_backward_syncs(fx_g) + + # Sync deallocations for tensors where the stream w/ their last usage + # is distinct from their allocation strea + sync_deallocations(fx_g) + + fx_g.graph.eliminate_dead_code() + if not aot_config.disable_functionalization: + # There should be *NO* mutating ops in the graph at this point. + assert_functional_graph(fx_g.graph) + + fx_g.recompile() + + # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect + # when we need to manually detach() some inputs in the forward. + # Higher order ops might eventually need to do the same. + if aot_config.is_export: + assert maybe_subclass_meta is None, ( + "aot_export_module does not support tensor subclass inputs for now." + ) + return ( + fx_g, + saved_updated_joint_inputs, + updated_joint_inputs_descs, + maybe_subclass_meta, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef84cb488604c1c55b36890f270f3255a8ee138 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -0,0 +1,1395 @@ +# mypy: allow-untyped-defs +""" +This module is responsible for transforming functions to be traced into a form +that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis) +to handle. + +It does so by: +1. functionalization (including RNG functionalzation) +2. creating a joint graph when required +3. transforming mutations into extra outputs +4. dispatching subclasses +""" + +import warnings +from collections.abc import Callable +from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import Any, Optional, TypeVar, Union +from unittest.mock import patch + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker +from torch._guards import detect_fake_mode +from torch._prims_common import CUDARngStateHelper +from torch.fx.experimental.proxy_tensor import ( + _proxy_tensor_disable_update_tensor_tracker, + get_proxy_mode, + maybe_disable_thunkify, + maybe_enable_thunkify, +) +from torch.fx.experimental.symbolic_shapes import ( + guard_or_true, + PropagateUnbackedSymInts, + sym_eq, +) +from torch.nn.utils import stateless +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._pytree import TreeSpec + +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .descriptors import ( + AOTInput, + AOTOutput, + BackwardTokenAOTOutput, + ForwardTokenAOTInput, + ForwardTokenAOTOutput, + GradAOTOutput, + InputMutationAOTOutput, + IntermediateBaseAOTOutput, + PhiloxBackwardBaseOffsetAOTInput, + PhiloxBackwardSeedAOTInput, + PhiloxForwardBaseOffsetAOTInput, + PhiloxForwardSeedAOTInput, + PhiloxUpdatedBackwardOffsetAOTOutput, + PhiloxUpdatedForwardOffsetAOTOutput, +) +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, + was_inductor_storage_resized, +) +from .logging_utils import setup_stacktrace_preservation_hooks +from .schemas import ( + AOTConfig, + FxValue, + JointTraceFn, + MutationType, + OutputType, + PreppedForAutogradTraceFn, + SubclassMeta, + SubclassTracingInfo, + TraceFn, + ViewAndMutationMeta, +) +from .subclass_utils import ( + create_subclass_meta, + remap_unwrapped_subclass_arg_indices, + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from .utils import ( + call_and_expect_output_descs, + maybe_to_fresh_input, + simple_wraps, + without_output_descs, +) + + +# This function returns a new function that returns mutated inputs as outputs. +# if keep_data_input_mutations is set, then we assume that data-only mutations +# will be left in the graph, and we only return metadata-mutated inputs as outputs. +def fn_input_mutations_to_outputs( + fn: Callable, + args_descs: list[AOTInput], + meta: ViewAndMutationMeta, + keep_data_input_mutations: bool, +) -> Any: + @simple_wraps(fn) + def inner_fn(*args): + outs, outs_descs = call_and_expect_output_descs(fn, args) + assert len(meta.output_info) == len(outs) + # The compiled fw will return mutated input tensors, *including* metadata-only mutation. + # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. + # (because data-only input mutations are handled directly in the compiled graph) + mutated_input_pairs = [ + (x, InputMutationAOTOutput(src)) + for (i, (x, src)) in enumerate(zip(args, args_descs)) + if i in meta.mutated_inp_runtime_indices + ] + if mutated_input_pairs: + mutated_inputs_to_return, mutated_inputs_to_return_descs = zip( + *mutated_input_pairs + ) + else: + mutated_inputs_to_return, mutated_inputs_to_return_descs = (), () + return ( + (*mutated_inputs_to_return, *outs), + (*mutated_inputs_to_return_descs, *outs_descs), + ) + + return inner_fn + + +@contextmanager +def disable_autocast(): + with ExitStack() as stack: + autocast_enabled_devices = torch._C._autocast_supported_devices() + for device_type in autocast_enabled_devices: + if hasattr(torch, device_type): + stack.enter_context(torch.amp.autocast(device_type, enabled=False)) + yield + + +# This function takes in a fn with external aliasing and mutation, +# and returns a new fn with no external aliasing and mutation, +# as needed for autograd. +# The main transformations are: +# - Return mutated inputs as extra outputs +# - Clone mutated inputs that require gradients, +# because autograd will require us to pass the pre-mutated inputs into autograd.grad +# - Return intermediate bases of outputs as additional outputs, +# needed to appease autograd.Function +# The new function returns: +# (1) The updated outputs +# (2) A boolean mask of len(new_fn_outputs), +# that can be used to tell autograd.grad which outputs should get tangents +# if we trace the backward. +def fn_prepped_for_autograd( + fn: TraceFn, + args_descs: list[AOTInput], + meta: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> PreppedForAutogradTraceFn: + @simple_wraps(fn) + def inner_fn(*args): + args_maybe_cloned = [ + maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args) + ] + + outs, outs_descs = call_and_expect_output_descs(fn, args_maybe_cloned) + assert isinstance(outs, (tuple, list)) + outs = list(outs) + assert len(meta.output_info) == len(outs) + + mutated_input_pairs = [ + (x, InputMutationAOTOutput(src)) + for (i, (x, src)) in enumerate(zip(args_maybe_cloned, args_descs)) + if i in meta.mutated_inp_runtime_indices + ] + if mutated_input_pairs: + mutated_inputs_to_return, mutated_inputs_to_return_descs = zip( + *mutated_input_pairs + ) + else: + mutated_inputs_to_return, mutated_inputs_to_return_descs = (), () + + intermediate_bases = [] + intermediate_bases_descs = [] + for o, info, o_desc in zip(outs, meta.output_info, outs_descs): + if info.output_type == OutputType.alias_of_intermediate_save_as_output: + assert isinstance(o, torch.Tensor), ( + f"Expected tensor for intermediate base, got {type(o)}" + ) + intermediate_bases.append(o._base) + intermediate_bases_descs.append(IntermediateBaseAOTOutput(o_desc)) + + assert meta.num_intermediate_bases == len(intermediate_bases) + + # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) + fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases + fw_outs_to_return_descs = ( + *mutated_inputs_to_return_descs, + *outs_descs, + *intermediate_bases_descs, + ) + + # Also return a boolean mask specifying which outputs to this function will be used as tangents + mutated_inputs_grad_mask = [ + meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data + and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad + for (i, x) in enumerate(mutated_inputs_to_return) + ] + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + output_grad_mask = [ + meta.output_info[i].output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.output_info[i].raw_type, Tensor) + and meta.output_info[i].requires_grad + for (i, x) in enumerate(outs) + ] + + intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))] + + out_grad_mask = ( + mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask + ) + assert len(out_grad_mask) == len(fw_outs_to_return) + + # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) + # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + # This is annoying: our joint function needs to be aware of functionalization + # (syncing mutated inputs before calling autograd.grad()) + # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner. + if not aot_config.disable_functionalization: + for arg in args_maybe_cloned: + if not isinstance(arg, Tensor): + continue + sync_functional_tensor(arg) + + return (fw_outs_to_return, out_grad_mask), ( + fw_outs_to_return_descs, + out_grad_mask, + ) + + return inner_fn + + +@dataclass +class JointFnHandle: + post_forward: Optional[Callable] = None + + +# Given a fn, computes the joint. +# NOTE: fn is expects the following behavior: +# (1) fn() needs to return a tuple of (outs, mask), +# where `mask` tells us which outputs are meant to have tangents. +# we don't know this info automatically, because we don't actually want to blindly +# compute tangents for every output that requires grad. +# Specifically, outputs that alias inputs won't participate in the backward and get tangents. +# (2) fn() cannot mutate any inputs that require gradient. +# otherwise, when we compute autograd.grad(), we will not take those input mutations into account +# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) +def create_joint( + fn: Any, # PreppedForAutogradTraceFn + primals_descs: Optional[list[AOTInput]] = None, + *, + aot_config: AOTConfig, +) -> Any: # JointTraceFn + joint_fn_handle = JointFnHandle() + + # post_forward + # NB: this type is inaccurate when primals_descs is None + @simple_wraps(fn) + def inner_fn( + primals: list[FxValue], tangents: list[FxValue] + ) -> tuple[ + tuple[list[FxValue], list[Optional[Tensor]]], + tuple[list[AOTOutput], list[Optional[AOTOutput]]], + ]: + outs_descs = None + if primals_descs is None: + outs, tangent_mask = fn(*primals) + assert not pytree.tree_any(lambda x: isinstance(x, AOTOutput), tangent_mask) + else: + (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( + fn, primals + ) + mode = get_proxy_mode() + assert mode is not None, "Expected non-None proxy mode" + for node in mode.tracer.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" + + # TODO: I think this hook can also be eliminated now + if joint_fn_handle and joint_fn_handle.post_forward: + joint_fn_handle.post_forward(primals) + + assert len(tangent_mask) == len(outs) + outs_to_grad = [ + o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent + ] + assert len(outs_to_grad) == len(tangents) + + # Get the inputs that need gradients + grad_primals: list[torch.Tensor] = [] + inputs_needs_grads = [] + # Note that we're not using primals here, + # being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + if isinstance(p, Tensor) and p.requires_grad: + inputs_needs_grads.append(True) + assert isinstance(p, torch.Tensor) # Help mypy understand the type + grad_primals.append(p) + else: + inputs_needs_grads.append(False) + + # Get the outputs that need gradients + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + # The guard_or_true also sketchy; if unbacked + # symints are involved, we're just going to assume that the + # decomps setup the base shape correctly + + # Return out if the result of out.shape==tangent.shape is unknown or known to be true. + # otherwise if its a known false return out.view(tangent.shape). + # tangent should also be a tensor since it corresponds to a tensor output + assert isinstance(tangent, torch.Tensor), ( + f"Expected tensor tangent, got {type(tangent)}" + ) + needed_outs.append( + out + if guard_or_true(sym_eq(out.shape, tangent.shape)) + else out.view(tangent.shape) + ) + needed_tangents.append(tangent) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + if config.functionalize_rng_ops: + PhiloxStateTracker.mark_beginning_of_backward() + backward_out: tuple[Tensor, ...] = () + # Call the backwards pass + if grad_primals: + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + if functional_tensor_mode is not None: + # Side-Effect Tokens: + # We want to have independent chains of tokens for forward and backward. + # functional_tensor_mode._tokens is used by both. + # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output, + # to return them as joint graph outputs. + # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward. + # Joint graph tracing allows tokens discovery, + # So all the tokens in backward will be created and added as a graph inputs during tracing. + functional_tensor_mode._tokens_forward_output = ( + functional_tensor_mode._tokens + ) + functional_tensor_mode._tokens = {} + + with ( + set_partitioner_tag_is_backward(), + fx_traceback.preserve_node_meta(), + ExitStack() as stack, + ): + backward_pass_autocast = torch._functorch.config.backward_pass_autocast + if backward_pass_autocast == "same_as_forward": + # Use the ambient autocast mode(s) + pass + elif backward_pass_autocast == "off": + stack.enter_context(disable_autocast()) + else: + # Disable autocast, then enable anything in `backward_pass_autocast`. + stack.enter_context(disable_autocast()) + assert isinstance(backward_pass_autocast, list) + for kwargs in backward_pass_autocast: + assert isinstance(kwargs, dict) + stack.enter_context(torch.amp.autocast(**kwargs)) + + # for full graph export, we always export a joint graph where we assume no tangents are needed. + if aot_config.no_tangents: + assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1 + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + allow_unused=True, + ) + else: + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + final_outs = ( + outs, + [next(backward_out_iter) if i else None for i in inputs_needs_grads], + ) + if primals_descs is None: + return final_outs # type: ignore[return-value] + assert outs_descs is not None + return final_outs, ( + outs_descs, + [ + # TODO: ideally we do know this is DifferentiableAOTInput + # but this is quite an involved refactor + GradAOTOutput(desc) if i else None # type: ignore[arg-type] + for i, desc in zip(inputs_needs_grads, primals_descs) + ], + ) + + @simple_wraps(inner_fn) + def inner_fn_with_anomaly( + primals: list[FxValue], tangents: list[FxValue] + ) -> tuple[ + tuple[list[FxValue], list[Optional[Tensor]]], + tuple[list[AOTOutput], list[Optional[AOTOutput]]], + ]: + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.") + with torch.autograd.detect_anomaly(check_nan=False): + return inner_fn(primals, tangents) + + def joint_helper(primals, tangents): + return inner_fn_with_anomaly(primals, tangents) + + joint_helper.handle = joint_fn_handle # type: ignore[attr-defined] + + return joint_helper + + +def create_functionalized_rng_ops_wrapper( + func, args, args_descs, trace_joint=True +) -> Any: + # Functionalization of rng ops changes the calling convention of the joint graph. + # It goes from (primals, tangents) to (seed, offset, primals, tangents) + # At runtime, we pass on the current seed and offset. This is hidden from + # the user. + fake_mode_det = detect_fake_mode() + fake_mode: AbstractContextManager[Any] = nullcontext() + if fake_mode_det is not None: + fake_mode = fake_mode_det + + def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"): + out = PhiloxStateTracker.get_state_as_tensor() + return out + + def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"): + PhiloxStateTracker.set_state_from_tensor(x) + + def append_rng_offsets(outs, outs_descs): + if trace_joint: + # outs signature before: Tuple(fwd_outputs), Tuple(bwd_outputs) + # outs signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_offset, new_bwd_rng_offset) + return ( + ( + (*outs[0], PhiloxStateTracker.get_updated_fwd_offset()), + (*outs[1], PhiloxStateTracker.get_updated_bwd_offset()), + ), + ( + (*outs_descs[0], PhiloxUpdatedForwardOffsetAOTOutput()), + (*outs_descs[1], PhiloxUpdatedBackwardOffsetAOTOutput()), + ), + ) + else: + # outs signature before: Tuple(fwd_outputs) + # outs signature after: Tuple(fwd_outputs, new_fwd_rng_offset) + return ( + (*outs, PhiloxStateTracker.get_updated_fwd_offset()), + (*outs_descs, PhiloxUpdatedForwardOffsetAOTOutput()), + ) + + def traced_joint( + primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset + ): + with ( + patch("torch.cuda.get_rng_state", override_get_rng_state), + patch("torch.cuda.set_rng_state", override_set_rng_state), + ): + return append_rng_offsets(*func(primals, tangents)) + + def traced_forward(*primals_fwd_seed_fwd_base_offset): + # The signature is (*primals, seed, offset) + with ( + patch("torch.cuda.get_rng_state", override_get_rng_state), + patch("torch.cuda.set_rng_state", override_set_rng_state), + ): + return append_rng_offsets(*func(*primals_fwd_seed_fwd_base_offset[:-2])) + + if trace_joint: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward") + return ( + traced_joint, + ( + *args, + fwd_seed, + fwd_base_offset, + bwd_seed, + bwd_base_offset, + ), + ( + *args_descs, + PhiloxForwardSeedAOTInput(), + PhiloxForwardBaseOffsetAOTInput(), + PhiloxBackwardSeedAOTInput(), + PhiloxBackwardBaseOffsetAOTInput(), + ), + ) + else: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + return ( + traced_forward, + (*args, fwd_seed, fwd_base_offset), + ( + *args_descs, + PhiloxForwardSeedAOTInput(), + PhiloxForwardBaseOffsetAOTInput(), + ), + ) + + +@contextmanager +def set_partitioner_tag(tag: str): + meta_key = "partitioner_tag" + assert fx_traceback.has_preserved_node_meta() + + original_val = fx_traceback.current_meta.get(meta_key, None) + fx_traceback.current_meta[meta_key] = tag + try: + yield + finally: + fx_traceback.current_meta[meta_key] = original_val + + +def set_partitioner_tag_is_backward(): + return set_partitioner_tag("is_backward") + + +def set_partitioner_tag_must_be_in_backward(): + return set_partitioner_tag("must_be_in_backward") + + +def set_partitioner_tag_must_be_in_forward(): + return set_partitioner_tag("must_be_in_forward") + + +@dataclass +class MutationCounters: + mc_data: int + mc_storage: int + mc_inductor_storage_resized: int + + +T = TypeVar("T") + + +def sc_visit( + t, fn: Callable[[Tensor], T], reduce_fn: Callable[[T, T], T], accum_init: T +) -> T: + if not is_traceable_wrapper_subclass(t): + return fn(t) + + accum = accum_init + + def visit(e): + if not is_traceable_wrapper_subclass(e): + nonlocal accum + accum = reduce_fn(accum, fn(e)) + return + + for a in e.__tensor_flatten__()[0]: + visit(getattr(e, a)) + + visit(t) + return accum + + +def _get_mutation_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_mutation_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_storage_changed_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_storage_changed_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_inductor_storage_resized_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_inductor_storage_resized_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_mutation_counters(t) -> MutationCounters: + return MutationCounters( + _get_mutation_counter(t), + _get_storage_changed_counter(t), + _get_inductor_storage_resized_counter(t), + ) + + +def apply_in_graph_mutations( + input_info, + inpt_old, + inpt_new, + f_inpt, + input_idx, + mcs: Optional[MutationCounters] = None, + applied_mcs: Optional[MutationCounters] = None, +): + assert input_info.mutation_type == MutationType.MUTATED_IN_GRAPH + # See Note [set_() Input Mutations in AOTAutograd] + # all mutations on the input must be under no_grad, so it is safe to put in the graph + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + if input_info.mutates_storage_metadata: + if mcs is None or mcs.mc_storage > applied_mcs.mc_storage: # type: ignore[union-attr] + with torch.no_grad(): + inpt_old.set_(inpt_new) + + # Note [Ordering of resize_() and set_()] + # Importantly: the common usage in FSDP is that we have a dummy parameter + # that sees a set_() and **Then** a resize_(). + # We must put those mutations into the graph in the same order, + # Since running them in the opposite order will have different behavior. + # We fully ban resize_() followed by set_() for now, although in principal + # we could support this + if input_info.mutation_inductor_storage_resize: + if ( + mcs is None + or mcs.mc_inductor_storage_resized > applied_mcs.mc_inductor_storage_resized # type: ignore[union-attr] + ): + # resizing is not supported on subclasses (we error earlier if this happens) + from torch._subclasses.functional_tensor import FunctionalTensor + + assert isinstance(f_inpt, FunctionalTensor) + old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + f_inpt.elem, before=True + ) + new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + f_inpt.elem, before=False + ) + if old_storage_size != new_storage_size: + assert old_storage_size == 0 or new_storage_size == 0, f"""\ + Encosize during tracing on input {input_idx}. Old nbytes={old_storage_size}, new nbytes={new_storage_size} + We oresizing on graph inputs as long as the input either starts or ends with a storage size of 0 + (thee for FSDP)""" + torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size) + if new_storage_size == 0: + # Even if we marked the input as having a data mutation (thus needing a copy_()), + # We should **ignore** it if our input has no storage + # (this can happen if, e.g. we temporarily resize our input, copy data into it, + # and resize it back down to zero) + return + + # Optimization: if the copy_() is a no-op then don't include it in the graph. + # In theory inductor could optimize this away, however in fsdp, we end up with + # param.copy_(param), where param is a zero-storage-size tensor, + # and running this op in eager mode (using the aot_eager backend) will result in a segfault. + # So we may as well optimize it away here. + if inpt_old is inpt_new: + # (This check needs to be done after putting resize_() in the graph, + # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + return + # We found an input that had a (data-only) mutation. + # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. + + if not input_info.mutates_data: + return + + if mcs is not None and mcs.mc_data <= applied_mcs.mc_data: # type: ignore[union-attr] + return + + if input_info.mutations_hidden_from_autograd: + # Hidden from autograd = run under no_grad, **and** don't bump VC + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: + inpt_old.copy_(inpt_new) + elif input_info.mutations_under_no_grad_or_inference_mode: + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + + with torch.no_grad(): + inpt_old.copy_(inpt_new) + else: + inpt_old.copy_(inpt_new) + + +# This creates the final function that we want to trace using make_fx(), +# in both aot_dispatch_autograd and aot_dispatch_base. +# Preconditions: +# - fn corresponds to the user's fw function +# - fn arguments have been flattened, duplicate arguments have been handled +# - In the returned function, the "primals" arguments *includes* synthetic bases. +# This function does the work of functionalizing the input function, +# and performing copy_() calls at the end of the function if `keep_input_mutations` is set. +# The function returned has signature that is either: +# (1) "traced_fn(primals: List[Any])" if trace_joint is False +# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True +# Returns a new (functionalized) function, and updated arguments to call it with. +def create_functionalized_fn( + fn, + args, + args_descs, + *, + meta: ViewAndMutationMeta, + aot_config: AOTConfig, + trace_joint: bool, + joint_fn_handle: Optional[JointFnHandle] = None, +) -> Any: + primals_after_forward = None + f_args_after_forward = None + f_args_mutation_counters_after_forward: Optional[list[MutationCounters]] = None + inputs_mutated_in_graph = [ + info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info + ] + has_input_mutated_in_graph = any(inputs_mutated_in_graph) + + @simple_wraps(fn) + def _functionalized_f_helper( + *args: list[FxValue], + ) -> tuple[tuple[list[FxValue], list[Tensor]], list[Optional[AOTOutput]]]: + with maybe_enable_thunkify(): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # The functionalization code here can potentially trigger traces + # into the graph, but we'd prefer to NOT do this, because if we + # trace them now, we will end up with FX nodes that don't have + # module stack annotations, which makes unflattener unhappy. + # Wrap inputs into functional wrappers + f_args = pytree.tree_map(to_fun, args) + + if trace_joint and has_input_mutated_in_graph and joint_fn_handle: + # TODO(ivankobzarev): Support fw and bw mutations for subclasses + def _post_forward(primals): + nonlocal primals_after_forward + primals_after_forward = pytree.tree_map(from_fun, primals) + nonlocal f_args_after_forward + f_args_after_forward = f_args[0] + nonlocal f_args_mutation_counters_after_forward + f_args_mutation_counters_after_forward = [ + MutationCounters(-1, -1, -1) + if not inputs_mutated_in_graph[i] + else _get_mutation_counters(f_arg) + for i, f_arg in enumerate(f_args_after_forward) + ] + + joint_fn_handle.post_forward = _post_forward + + # Run the joint + f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args) + + if trace_joint: + # We support a limited amount of mutation of graph inputs during the backward pass. + # (This is used e.g. by Float8, which needs to update buffers during the backward pass) + # Here, we perform extra checks for primals that were mutated in the **backward** + # We're doing the checks here instead of doing them with the rest of the input mutation handling because: + # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened + # during the forward, because the handling is different: some input mutations from the forward + # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same + # types of mutations in the backward we would need a bw-only runtime epilogue. + # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in + # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would + # require an extra round of tracing though, so it's more efficient to do in-line here. + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw) + primals_before = args[0] + primals_after = pytree.tree_map(from_fun, f_args[0]) + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip(f_args[0], primals_before, primals_after, meta.input_info) + ): + # Store information about mutations in joint(for backward analysis) + joint_mutates_data = has_data_mutation(f_inpt) + + joint_mutates_metadata = has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) + + # Ban metadata mutations on fw inputs during the bw + if not inpt_info.mutates_metadata: + assert not joint_mutates_metadata, ( + "Found a graph input that had its metadata mutated in the backward. This is not supported" + ) + + # Ban storage resizing on fw inputs during the bw + if not inpt_info.mutation_inductor_storage_resize: + assert not was_inductor_storage_resized(f_inpt), ( + "Found a graph input that had storage resizing in the backward. This is not supported" + ) + + # Allow data mutations on fw inputs during the bw, but only if they do not require grad + # So we can guarantee that we can keep the mutations in the graph + if ( + joint_mutates_data + and not inpt_info.mutates_data + and not inpt_info.mutates_storage_metadata + ): + # Not banning here mutations on inpt_info.requires_grad - + # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) + # Add node meta for copy_ for partitioner that this node should be in backward graph. + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_backward(), + ): + # before and after should be tensors if we're calling copy_ on them + assert isinstance(before, torch.Tensor) and isinstance( + after, torch.Tensor + ) + before.copy_(after) + meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( + idx + ) + # Now that we covered mutations to *forward* inputs during the backward, + # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). + # Today, we will just error in all cases of this happening unless someone needs us to support it. + tangents_before = args[1] + tangents_after = pytree.tree_map(from_fun, f_args[1]) + for f_inpt, before, after in zip( + f_args[1], tangents_before, tangents_after + ): + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ), ( + "Found an input to the backward that had metadata mutated during the backward pass. This is not supported" + ) + if has_data_mutation(f_inpt): + can_be_in_graph = _check_if_mutation_can_be_in_graph( + keep_input_mutations=True, + mutates_data=True, + mutates_metadata=False, + mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd( + f_inpt + ), + mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode( + f_inpt + ), + mutates_storage_metadata=False, + mutation_inductor_storage_resize=was_inductor_storage_resized( + f_inpt + ), + requires_grad=f_inpt.requires_grad, + ) + assert can_be_in_graph, ( + "a backward input that had data mutated in an autograd-aware way. This is not supported" + ) + # Perform the input mutation + with torch.fx.traceback.preserve_node_meta(): + # before and after should be tensors if we're calling copy_ on them + assert isinstance(before, torch.Tensor) and isinstance( + after, torch.Tensor + ) + before.copy_(after) + + if aot_config.keep_inference_input_mutations: + # Note: This is a bit annoying. There's a layering issue here, where: + # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. + # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. + # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, + # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). + # This makes it pretty difficult for this logic to operate on synthetic bases. + # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual + # (unpacked) input aliases, instead of the synthetic base. + # Example case where (3) could be important: + # + # def f(x, y): + # x.mul_(2) + # y.mul_(3) + # return x, y + # a = torch.ones(1'000'000) + # x, y = out(a[0:9], a[1:10]) + # + # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing + # a giant "updated synthetic base" and copying into a's entire storage. + # + # For now, we are pessimistically not performing the optimization from (3); + # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. + # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry + # about synthetic bases. + + # Apply in graph forward mutations only in joint case. + # Note: Mutations of primals in forward AND backward. + # If we have mutations of the same input in forward and in backward, + # we can not fuse them into one copy_ node. As in this case partitioner will put it + # either in forward or in backward. This will lead to incorrect state + # after forward and before backward. + # We have to emit two copy_ nodes, marking with additional meta each node, + # if it must be in forward or backward. + # We memorize mutation counter of the inputs after forward. + # Based on this after joint graph we check if backward also mutated input or not. + # We emit copy_ only in the end of joint tracing, to provide invariant for joint + # graph passes, that our graph is functional, except only some number of copy_ nodes + # in the end. + mcs_applied: list[MutationCounters] = [MutationCounters(0, 0, 0)] * len( + meta.input_info + ) + if f_args_mutation_counters_after_forward is not None: + primals_before = args[0] + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip( + f_args_after_forward, # type: ignore[arg-type] + primals_before, # type: ignore[arg-type] + primals_after_forward, # type: ignore[arg-type] + meta.input_info, + ) + ): + if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH: + continue + + mcs_after_forward = f_args_mutation_counters_after_forward[idx] + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_forward(), + _proxy_tensor_disable_update_tensor_tracker(), + ): + apply_in_graph_mutations( + inpt_info, + before, + after, + f_inpt, + idx, + mcs_after_forward, + mcs_applied[idx], + ) + mcs_applied[idx] = mcs_after_forward + + for idx, (inpt_old, f_inpt) in enumerate( + zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) # type: ignore[arg-type] + ): + if not isinstance(f_inpt, torch.Tensor): + continue + assert is_fun(f_inpt) + inpt_new = from_fun(f_inpt) + if ( + meta.input_info[idx].mutation_type + != MutationType.MUTATED_IN_GRAPH + ): + continue + mcs: Optional[MutationCounters] = None + if f_args_mutation_counters_after_forward is not None: + # This could happen for subclasses tracing + # Subclasses support for mutations in fw and bw is TBD. + mcs = _get_mutation_counters(f_inpt) + if mcs == mcs_applied[idx]: + # No mutation in backward; mutation was already applied. + continue + + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_backward(), + ): + apply_in_graph_mutations( + meta.input_info[idx], + inpt_old, + inpt_new, + f_inpt, + idx, + mcs, + mcs_applied[idx], + ) + + # When an output tensor is a functionalized mutated input, and we + # were able to move the mutation in to the graph then we can return + # the mutated input directly. This prevents duplicating the + # tensors contents. + flat_outs, outs_spec = pytree.tree_flatten(f_outs) + flat_outs = [from_fun(o) for o in flat_outs] + num_outs = len(meta.output_info) + + for i in range(num_outs): + info = meta.output_info[i] + if info.output_type != OutputType.is_input: + continue + + assert info.base_idx is not None + if ( + meta.input_info[info.base_idx].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + fw_args = args[0] if trace_joint else args + flat_outs[i] = fw_args[info.base_idx] + return pytree.tree_unflatten(flat_outs, outs_spec), f_outs_descs + + return pytree.tree_map(from_fun, f_outs), f_outs_descs + + # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" + # and "tangents" as its input names (which are special-cased by the partitioner) + # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export + def joint_helper(primals, tangents): + return _functionalized_f_helper(primals, tangents) + + helper = joint_helper if trace_joint else _functionalized_f_helper + if config.functionalize_rng_ops: + # Setup the wrapper for functionalization of rng ops + helper, args, args_descs = create_functionalized_rng_ops_wrapper( + helper, args, args_descs, trace_joint + ) + + return helper, args, args_descs + + +def handle_effect_tokens_fn( + fn, + args, + args_descs: list[AOTInput], + *, + meta: ViewAndMutationMeta, + trace_joint: bool, +) -> Any: + num_tokens = len(meta.tokens) + + @simple_wraps(fn) + def inner_fn(*args): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert isinstance(args, tuple) and isinstance(args[0], (list, tuple)) + tokens = args[0][:num_tokens] + assert all(token.numel() == 0 for token in tokens) + args = (args[0][num_tokens:], *args[1:]) + else: + tokens = args[:num_tokens] + assert all(token.numel() == 0 for token in tokens) + args = args[num_tokens:] + + # Populate the current FunctionalTensorMode with the tokens per + # operator. See Note [FunctionalTensorMode is Stateful] + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + assert functional_tensor_mode is not None + f_tokens = pytree.tree_map(to_fun, tokens) + for i, k in enumerate(meta.tokens.keys()): + functional_tensor_mode._tokens[k] = f_tokens[i] + + # Run the joint + outs, outs_descs = call_and_expect_output_descs(fn, args) + + # Return both the tokens and the outputs + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert len(outs) == 2 + assert len(functional_tensor_mode._tokens_forward_output) == num_tokens + fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values() + + bwd_out_tokens = functional_tensor_mode._tokens.values() + + f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens] + f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens] + f_fwd_out_tokens_descs = [ + ForwardTokenAOTOutput(i) for i in range(len(fwd_out_tokens)) + ] + f_bwd_out_tokens_descs = [ + BackwardTokenAOTOutput(i) for i in range(len(bwd_out_tokens)) + ] + + meta.num_backward_tokens = len(bwd_out_tokens) + return ( + ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens)), + ( + (*f_fwd_out_tokens_descs, *outs_descs[0]), + (*outs_descs[1], *f_bwd_out_tokens_descs), + ), + ) + + out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()] + # TODO: can probably do a little more resolution here + out_tokens_descs = [ + ForwardTokenAOTOutput(i) + for i in range(len(functional_tensor_mode._tokens.values())) + ] + return ((*out_tokens, *outs), (*out_tokens_descs, *outs_descs)) + + # Additionally pass in tokens as inputs + # See Note [Side-Effectful Tokens in AOTAutograd] + additional_fwd_token_inputs = [torch.tensor([])] * num_tokens + additional_fwd_token_inputs_descs = [ + ForwardTokenAOTInput(i) for i in range(num_tokens) + ] + + if trace_joint: + args = ([*additional_fwd_token_inputs, *args[0]], *args[1:]) + args_descs = ( # type: ignore[assignment] + [*additional_fwd_token_inputs_descs, *args_descs[0]], # type: ignore[misc] + *args_descs[1:], + ) + else: + args = [*additional_fwd_token_inputs, *args] + args_descs = [*additional_fwd_token_inputs_descs, *args_descs] + return inner_fn, args, args_descs + + +# Given a function operating on Subclass -> Subclass, returns an function that operates on Tensor -> Tensor +# Also returns: +# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated) +# - the updated ViewAndMutationMeta for this dense -> dense function. +# The other important arguments are: +# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function. +# when is_joint_structure=False, this is just the forward function. +# - fw_only: this is *always* the forward-only function. +# Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions. +# In particular, we need this to tell the partitioner how many dense forward outputs there are. +def aot_dispatch_subclass( + flat_fn_maybe_joint: Union[JointTraceFn, TraceFn], + args: Union[list[FxValue], tuple[list[FxValue], list[FxValue]]], + args_descs: Union[list[AOTInput], tuple[list[AOTInput], list[AOTInput]]], + *, + is_joint_structure: bool, + meta: ViewAndMutationMeta, + fw_only: Callable, +) -> SubclassTracingInfo: + # Skip logic if we don't need to trace through any subclasses + req_subclass_dispatch = requires_subclass_dispatch(args, meta) + if not req_subclass_dispatch: + return SubclassTracingInfo( + plain_tensor_trace_fn=flat_fn_maybe_joint, + plain_tensor_args=args, + plain_tensor_args_descs=args_descs, + maybe_subclass_meta=None, + ) + + # TODO: add subclass guards (later PR). + + # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs). + # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint, + # so we set it later, while we're tracing the joint (see inner_fn() below). + # Another option would be to run our run_functionalized_fw_and_collect_metadata() function + # directly on the joint, but this would hurt compile time (adding yet another pass through the joint). + subclass_meta = SubclassMeta() + + # NB: doesn't take descs, this is going from the NEW flat_args to the + # subclasses, we don't need to do bookkeeping here + def inner_fn(fn, args, *, use_trace_joint: bool): + # Step 1: wrap tensor inputs into subclasses if necessary + all_args = wrap_tensor_subclasses_maybe_joint( + args, is_joint_structure=use_trace_joint, meta=meta + ) + + # Step 2: call the inner function, with our (maybe subclass) inputs + wrapped_outs, wrapped_outs_descs = call_and_expect_output_descs(fn, all_args) + + if use_trace_joint: + # See Note: [Computing Subclass Metadata about grad_inputs] + # We also stash subclass info on our grad_inputs, if we're tracing the joint. + nonlocal subclass_meta + assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2, ( + wrapped_outs, + wrapped_outs_descs, + ) + # Don't need fw outs since we already have subclass metadata on them + grad_inputs = wrapped_outs[1] + subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs) + + # Add extra symints as outputs to the forward/backward graphs + # ignore nested ints here + forward_outs, forward_outs_descs = unwrap_tensor_subclasses( + wrapped_outs[0], wrapped_outs_descs[0], append_symints=True + ) + # ignore nested ints here + backward_outs, backward_outs_descs = unwrap_tensor_subclasses( + wrapped_outs[1], wrapped_outs_descs[1], append_symints=True + ) + return ( + (forward_outs, backward_outs), + (forward_outs_descs, backward_outs_descs), + ) + + # Step 3: Unwrap any subclass outputs back into dense tensors + return unwrap_tensor_subclasses( + wrapped_outs, wrapped_outs_descs, append_symints=True + ) + + def joint_fn( + primals: list[FxValue], tangents: list[FxValue] + ) -> tuple[ + tuple[list[FxValue], list[FxValue]], tuple[list[AOTOutput], list[AOTOutput]] + ]: + with maybe_enable_thunkify(): + return inner_fn( + flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True + ) + + def fw_fn(*primals: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: + with maybe_enable_thunkify(): + return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False) + + def metadata_fn(*primals: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: + @simple_wraps(fw_only) + def inner_fw_only(*args): + return call_and_expect_output_descs(fw_only, args) + + return inner_fn(inner_fw_only, primals, use_trace_joint=False) + + if is_joint_structure: + # Add extra symints (size/strides) as input to the forward graph + primals_unwrapped_pair = unwrap_tensor_subclasses( + args[0], # type: ignore[arg-type] + args_descs[0], # type: ignore[arg-type] + append_symints=True, + ) + # We pass append_symints=False here because the partitioner will + # capture and add any extra argument + tangents_unwrapped_pair = unwrap_tensor_subclasses( + args[1], # type: ignore[arg-type] + args_descs[1], # type: ignore[arg-type] + append_symints=False, + ) + + args_unwrapped = (primals_unwrapped_pair[0], tangents_unwrapped_pair[0]) + args_descs_unwrapped = (primals_unwrapped_pair[1], tangents_unwrapped_pair[1]) + remapped_static_indices = remap_unwrapped_subclass_arg_indices( + args[0], meta.static_input_indices + ) + else: + args_unwrapped, args_descs_unwrapped = unwrap_tensor_subclasses( # type: ignore[assignment] + args, # type: ignore[arg-type] + args_descs, # type: ignore[arg-type] + append_symints=True, + ) + remapped_static_indices = remap_unwrapped_subclass_arg_indices( + args, meta.static_input_indices + ) + + if is_joint_structure: + primals_unwrapped = args_unwrapped[0] # type: ignore[assignment] + primals_unwrapped_descs = args_descs_unwrapped[0] # type: ignore[assignment] + fn_to_trace = joint_fn # type: ignore[assignment] + else: + primals_unwrapped = args_unwrapped # type: ignore[assignment] + primals_unwrapped_descs = args_descs_unwrapped # type: ignore[assignment] + fn_to_trace = fw_fn # type: ignore[assignment] + + # Note: [Partitioner handling for Subclasses, Part 1] + # The way the partitioner works is that: + # (1) we pass is a single graph containing the joint fw/bw, + # where the # of graph outputs corresponds to # fw_outputs + # grad_inputs + # (2) The partitioner accepts an arguments, num_fwd_outputs, + # and assumes that the first "num_fwd_outputs" graph outputs correspond + # to outputs of the forward graph. + # How do tensor subclasses enter the picture? + # the num_fwd_outputs in the final graph is actually non-trivial to compute, + # because it can be influenced by input mutations and intermediate bases. + # So we compute it by inspecting the current ViewAndMutationMeta object. + # However, the original ViewAndMutationMeta that we computed was created + # on the subclass -> subclass graph, + # which can have a different number of outputs than the dense -> dense graph. + # That's why we created a fresh metadata object on the dense -> dense function here, + # and plumb it back up to the partitioner. + # See Note: [Partitioner handling for Subclasses, Part 2] for more info. + meta_updated = run_functionalized_fw_and_collect_metadata( + without_output_descs(metadata_fn), + # pyrefly: ignore [bad-argument-type] + flat_args_descs=primals_unwrapped_descs, + static_input_indices=remapped_static_indices, + keep_input_mutations=meta.keep_input_mutations, + is_train=meta.is_train, + # pyrefly: ignore [not-iterable] + )(*primals_unwrapped) + + subclass_meta.fw_metadata = meta_updated + + return SubclassTracingInfo( + plain_tensor_trace_fn=fn_to_trace, + plain_tensor_args=args_unwrapped, + plain_tensor_args_descs=args_descs_unwrapped, + maybe_subclass_meta=subclass_meta, + ) + + +def create_functional_call( + mod, params_spec, params_len, store_orig_mod=False, strict_out_tuple=True +): + # Redundant with dynamo, but worth having in case this gets invoked elsewhere. + # https://github.com/pytorch/pytorch/issues/103569 + + @simple_wraps(mod) + def functional_call(*args, **kwargs): + flat_params = args[:params_len] + if isinstance(params_spec, TreeSpec): + params = pytree.tree_unflatten(flat_params, params_spec) + else: + assert isinstance(params_spec, list) + params = dict(zip(params_spec, flat_params)) + with ( + stateless._reparametrize_module(mod, params), + maybe_disable_thunkify(), + ): + if isinstance(mod, torch.fx.GraphModule): + if kwargs: + # Handle **kwargs. FX only natively supports positional + # arguments (through placeholders). + arg_list = list(args[params_len:]) + arg_list.extend(list(kwargs.values())) + args = tuple(arg_list) + else: + args = args[params_len:] + + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Anomaly Detection has been enabled." + ) + with torch.autograd.detect_anomaly(check_nan=False): + fake_mode = detect_fake_mode() + assert fake_mode is not None + fake_mode.epoch += 1 + out = PropagateUnbackedSymInts(mod).run(*args) + else: + out = mod(*args[params_len:], **kwargs) + + if strict_out_tuple and not isinstance(out, (tuple, list)): + raise RuntimeError( + "Graph output must be a (). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs or use aot_module instead." + ) + return out + + # Note [Preserving the nn module stack metadata during export non-strict mode] + # This path is currently only used by the non-strict export flow, + # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph. + # Instead, we stash the original user nn module here, and rely on `make_fx` to grab + # this stashed module and use it to track nn module stack metadata + if store_orig_mod and not hasattr(functional_call, "_orig_mod"): + functional_call._orig_mod = mod # type: ignore[attr-defined] + + return functional_call diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b1939a741e57daee2dd0fde613730743225ddb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py @@ -0,0 +1,2338 @@ +# mypy: allow-untyped-defs +""" +Functions in this module do most of the "work" of AOTAutograd. +An aot_dispatch_* function: +- Takes in the input flat_fn, flat_args, and some metadata +- Runs a set of pre compile wrappers (e.g. argument deduping) +- Runs the actual compiler +- Wraps the returned callable in a set of post compile wrappers +- Returns the wrapped callable and metadata. +""" + +import copy +import dataclasses +import itertools +import logging +import operator +import time +import traceback +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from typing import Any, Optional, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from collections.abc import Sequence + +import threading +from contextlib import contextmanager + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo.utils import ( + CompileEventLogger, + detect_fake_mode, + dynamo_timed, + lazy_format_graph_code, +) +from torch._guards import CompileContext, TracingContext +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses import FakeTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node +from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals, guard_or_true +from torch.fx.graph_module import GraphModule +from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars +from torch.multiprocessing.reductions import StorageWeakRef +from torch.types import py_sym_types +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torchgen.utils import dataclass_repr + +from .. import config +from .aot_autograd_result import GenericAOTAutogradResult, serialize_graph_module +from .autograd_cache import ( + AOTAutogradCache, + should_bundle_autograd_cache, + should_use_remote_autograd_cache, +) +from .descriptors import AOTOutput, PlainAOTOutput +from .graph_capture import aot_dispatch_autograd_graph, aot_dispatch_base_graph +from .logging_utils import track_graph_compiling +from .runtime_wrappers import ( + AOTDedupeWrapper, + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + AOTSyntheticBaseWrapper, + AutogradLazyBackwardCompileInfo, + CompilerWrapper, + DebugAssertWrapper, + EffectTokensWrapper, + FakifiedOutWrapper, + FunctionalizedRngRuntimeWrapper, + make_runtime_safe, + post_compile, + pre_compile, + RuntimeWrapper, + SerializableCompiledFunction, +) +from .schemas import ( + AOTConfig, + AOTGraphCapture, + AOTState, + FlatFn, + FxValue, + MutationType, + SubclassMeta, + ViewAndMutationMeta, +) +from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta +from .utils import ( + contain_metadata_mutation_ops, + get_cuda_generator_meta_val, + make_boxed_func, + simple_wraps, + strict_zip, + unlift_tokens, +) + + +_thread_local = threading.local() + + +@contextmanager +def maybe_skip_decompose(aot_config: AOTConfig): + old_decomp = aot_config.decompositions + try: + if config.selective_decompose: + aot_config.decompositions = {} + yield + finally: + aot_config.decompositions = old_decomp + + +# Saved tensor hooks context +# Compiled saved tensor hooks are convenient way to inline some logic in the graphs +# for saved nodes from forward to backward. (E.g. activations quantization) +# In base implementation user does not have any additional information about saved value +# in the hook, except FakeTensor shape, dtype, device etc. +# _get_saved_tensor_hook_context gives additional graph information about that saved value, +# that can be used to make a decisions which pack/unpack to apply for particular saved value. +# This allows user to reuse saved tensors hooks api to apply selective pack/unpack in +# graph aware way. +# Alternative to this will be making user to write a custom pass that mucks with forward outputs, +# backward input metadata, which requires significantly more effort. +# +# As for now in context we expose forward graph, backward graph and current saved node, +# which contains node.meta with additional information about that fx.Node. +# Warning: This API may change without backward compatibility. +@contextmanager +def _saved_tensor_hook_context(state: dict[str, Any]): + previous_state = getattr(_thread_local, "state", None) + try: + _thread_local.state = state + yield + finally: + # Clean up: restore previous state or remove attribute + if previous_state is not None: + _thread_local.state = previous_state + else: + if hasattr(_thread_local, "state"): + delattr(_thread_local, "state") + + +def _get_saved_tensor_hook_context() -> dict[str, Any] | None: + return getattr(_thread_local, "state", None) + + +zip = strict_zip + +log = logging.getLogger(__name__) +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + +aten = torch.ops.aten + +# Returns a Callable and a ViewAndMutationMeta. +# Currently, only export needs the ViewAndMutationMeta after this function. +# TODO: Refactor this +DispatchReturn = tuple[Callable, ViewAndMutationMeta] + + +def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper]: + """ + Wrappers that run on every dispatch function + """ + return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)] + + +def aot_stage1_graph_capture( + aot_state: AOTState, + orig_flat_fn: FlatFn, +) -> AOTGraphCapture: + # NB: flat_fn at this point coincides with the initial info from forward + # metadata collection returning a list[Tensor]. We are now going to + # augment the output to return a tuple[list[Tensor], list[AOTOutput]] and + # then preserve this convention through the rest of the passes. + + # TODO: We could test for consistency with fw_metadata, but this is not a + # big deal + @simple_wraps(orig_flat_fn) + def orig_flat_fn2(*args: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: + out = orig_flat_fn(*args) + out_descs: list[AOTOutput] = type(out)( # type: ignore[assignment] + PlainAOTOutput(i) # type: ignore[misc] + for i in range(len(out)) # type: ignore[misc] + ) + return out, out_descs + + aot_config = aot_state.aot_config + + wrappers = _create_wrappers_for_dispatch(aot_state.needs_autograd) + flat_fn, aot_state.flat_args, aot_state.flat_args_descs, aot_state.fw_metadata = ( + pre_compile( + wrappers, + orig_flat_fn2, + aot_state.flat_args, + aot_state.flat_args_descs, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) + ) + + # NB: This is currently only used for backwards, where fwd/bwd + # deterministic TLS can be different + aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + + with maybe_skip_decompose(aot_config): + # if config.selective_decompose, skip decomposition and apply selective_decompose + # after we get the joint graph. See [Note: Selective Decomposition] for details. + if aot_state.needs_autograd and not aot_config.pre_dispatch: + # FYI: this being moved to trigger in export is new, seems fine! + with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): + ( + graph, + updated_flat_args, + updated_flat_args_descs, + maybe_subclass_meta, + ) = aot_dispatch_autograd_graph( + flat_fn, + aot_state.flat_args, + aot_state.flat_args_descs, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) + else: + graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = ( + aot_dispatch_base_graph( + flat_fn, + aot_state.flat_args, + aot_state.flat_args_descs, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) + ) + # Apply AC rematerialization to forward+loss+bwd graph + if torch._functorch.config.remat_using_tags_for_fwd_loss_bwd_graph: + from torch._functorch._activation_checkpointing.remat_using_tags_for_fwd_loss_bwd_graph_pass import ( + remat_using_tags_for_fwd_loss_bwd_graph, + ) + + graph = remat_using_tags_for_fwd_loss_bwd_graph(graph) + + if config.selective_decompose: + from torch.fx.experimental.proxy_tensor import selective_decompose + from torch.fx.passes.regional_inductor import _needs_inductor_compile + + graph = selective_decompose( + graph, + *updated_flat_args, + decomposition=aot_config.decompositions, + should_decompose=_needs_inductor_compile, + trace_joint_graph=aot_state.needs_autograd and not aot_config.pre_dispatch, + ) + + return AOTGraphCapture( + wrappers=wrappers, + graph_module=graph, + updated_flat_args=updated_flat_args, + updated_flat_args_descs=updated_flat_args_descs, + maybe_subclass_meta=maybe_subclass_meta, + ) + + +def aot_stage2_export( + aot_state: AOTState, aot_graph_capture: AOTGraphCapture +) -> DispatchReturn: + graph = aot_graph_capture.graph_module + aot_config = aot_state.aot_config + wrappers = aot_graph_capture.wrappers + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="export") + + # NB: the wrappers that run in pre_compile for export are + # either a no-op, because they're not needed, or will raise a runtime error, + # since they don't support export. + # We still run these wrappers to make sure that they're not needed pre compile, + # but we technically don't need to run them post compile at all here. + compiled_fn, aot_state.fw_metadata = post_compile( + wrappers, graph, aot_config, runtime_metadata=aot_state.fw_metadata + ) + + # Therefore, since no wrapperes run, we don't get back a callable - we get back the raw fx graph + # (either a joint or an inference-only graph) + assert isinstance(compiled_fn, torch.fx.GraphModule) + return compiled_fn, aot_state.fw_metadata + + +def sanitize_aot_config(input: AOTConfig) -> AOTConfig: + return AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + inference_compiler=None, + num_params_buffers=input.num_params_buffers, + aot_id=input.aot_id, + keep_inference_input_mutations=input.keep_inference_input_mutations, + is_export=input.is_export, + no_tangents=input.no_tangents, + aot_autograd_arg_pos_to_source=input.aot_autograd_arg_pos_to_source, + dynamic_shapes=input.dynamic_shapes, + enable_log=input.enable_log, + static_input_indices=input.static_input_indices, + pre_dispatch=input.pre_dispatch, + cache_info=None, + precompile_backend_id=input.precompile_backend_id, + ) + + +def _get_inner_meta( + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, +) -> ViewAndMutationMeta: + """ + Util to get view and mutation metadata. + """ + return ( + fw_metadata if maybe_subclass_meta is None else maybe_subclass_meta.fw_metadata + ) + + +def _apply_tensorify_python_scalars(module: torch.fx.GraphModule) -> None: + """ + Util to apply tensorify_python_scalars. + """ + # TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes. + fake_mode = detect_fake_mode() + if fake_mode is not None and fake_mode.shape_env is not None: + tensorify_python_scalars(module, fake_mode.shape_env, fake_mode) + + +def aot_stage2_compile( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, + partition_fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + inference_compiler: Optional[Callable] = None, +) -> DispatchReturn: + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + # Update the AOTState with the provided compilers + aot_state.aot_config.partition_fn = partition_fn + aot_state.aot_config.fw_compiler = fw_compiler + aot_state.aot_config.bw_compiler = bw_compiler + aot_state.aot_config.inference_compiler = inference_compiler + + if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch: + return aot_stage2_autograd(aot_state, aot_graph_capture) + else: + return aot_stage2_inference(aot_state, aot_graph_capture) + + +def _log_inference_graph( + fw_module: torch.fx.GraphModule, + aot_config: AOTConfig, +) -> Optional[str]: + """ + Log the inference graph to the structured logger. + Return a str representation of the graph. + """ + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(), + ) + + # Save the forward_graph_str right after aot_dispatch_base_graph, + # to save in the cache + aot_forward_graph_str = None + if aot_config.cache_info is not None: + aot_forward_graph_str = fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + fast_sympy_print=True, + expanded_def=True, + ) + + return aot_forward_graph_str + + +def _aot_stage2b_inference_compile( + fw_module: torch.fx.GraphModule, + updated_flat_args: list[Any], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config, +) -> Callable: + return _aot_stage2b_compile_forward_or_inference( + fw_module, + updated_flat_args, # type: ignore[arg-type] + maybe_subclass_meta, + fw_metadata, + aot_config, + is_inference=True, + )[1] + + +def aot_stage2_inference( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, +) -> DispatchReturn: + """ + Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler. + """ + + aot_config = aot_state.aot_config + fw_metadata = aot_state.fw_metadata + fw_module = aot_graph_capture.graph_module + wrappers = aot_graph_capture.wrappers + updated_flat_args = aot_graph_capture.updated_flat_args + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="inference") + aot_forward_graph_str = _log_inference_graph(fw_module, aot_config) + + assert isinstance(fw_module, GraphModule) + _apply_tensorify_python_scalars(fw_module) + + compiled_fw = _aot_stage2b_inference_compile( + fw_module, + updated_flat_args, # type: ignore[arg-type] + maybe_subclass_meta, + fw_metadata, + aot_config, + ) + + entry = _cache_inference_info( + aot_config, + fw_metadata, + maybe_subclass_meta, + compiled_fw, + aot_forward_graph_str, + wrappers, + ) + + return _aot_stage2c_make_inference_function( + aot_config, + fw_metadata, + compiled_fw, + wrappers, + entry, + ) + + +def _cache_inference_info( + aot_config, + fw_metadata, + maybe_subclass_meta, + compiled_fw, + aot_forward_graph_str, + wrappers, +): + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + cache_info = aot_config.cache_info + + def should_save_cache(): + if should_bundle_autograd_cache(): + return True + else: + return hasattr(compiled_fw, "_fx_graph_cache_key") + + entry: Optional[GenericAOTAutogradResult] = None + if cache_info is not None and should_save_cache(): + time_taken_ns = time.time_ns() - cache_info.start_time_ns + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + entry = AOTAutogradCache.make_entry( + compiled_fw_func=compiled_fw, # type: ignore[arg-type] + compiled_bw_func=None, + aot_joint_graph_str=None, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=None, + runtime_metadata=fw_metadata, + dispatch_wrappers=wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + indices_of_inps_to_detach=[], + forward_time_taken_ns=time_taken_ns, + backward_time_taken_ns=0, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=None, + num_symints_saved_for_bw=None, + serialized_bw_module=None, + ) + AOTAutogradCache.save( + cache_info.cache_key, + entry, + remote=should_use_remote_autograd_cache(), + ) + + return entry + + +def _aot_stage2c_make_inference_function( + aot_config, + fw_metadata, + compiled_fw, + wrappers, + entry, +): + if entry is not None: + compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry) + + disable_amp = torch._C._is_any_autocast_enabled() + compiled_fn = RuntimeWrapper( + indices_of_inps_to_detach=[], + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fn = post_compile( + wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata + ) + return compiled_fn + + +def collect_fw_donated_buffer_idxs( + fw_ins: list[Optional[FakeTensor]], + user_fw_outs: list[Optional[FakeTensor]], + bw_outs: list[Optional[FakeTensor]], + saved_tensors: list[FakeTensor], +) -> list[int]: + """ + Checks if the saved tensors are donated buffers, which means a saved tensor is not + an alias of any tensors in fw_ins, user_fw_outs, and bw_outs. + """ + + storage_refs = set() + + for t in itertools.chain(fw_ins, user_fw_outs, bw_outs): + # Only access storage if a tensor has storage (not sparse) + if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t): + storage_refs.add(StorageWeakRef(t.untyped_storage())) + + num_saved_tensor = len(saved_tensors) + donated_buffer_idxs = [] + for i in range(num_saved_tensor): + t = saved_tensors[i] + if ( + t is not None + and not is_sparse_any(t) + and StorageWeakRef(t.untyped_storage()) not in storage_refs + ): + donated_buffer_idxs.append(i) + + return donated_buffer_idxs + + +def collect_bw_donated_buffer_idxs( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, +) -> list[int]: + """ + Collects backward donated buffer indexes from fw_module and bw_module. + """ + + # [Note: Metadata mutation in proxy tracing] + # node.meta["val"] is a snapshot of the tensor value when tracing a graph, + # instead of the final state after the graph has run. node.meta["val"] is + # not updated even if later there is a metadata mutation op. + # See: https://github.com/pytorch/pytorch/pull/141308#issuecomment-2495798947 + # + # Currently, metadata mutation op happens only for sacrificial parameter + # specifically the `set_` op. This motivates banning metadata mutation from + # proxy tracing. + # + # Since node.meta["val"] is used to detect donated buffer, we return an empty + # list if there exists metadata mutation op. + if contain_metadata_mutation_ops(fw_module) or contain_metadata_mutation_ops( + bw_module + ): + return [] + + fw_ins = fw_module.graph.find_nodes(op="placeholder") + bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0] + fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0] + + fw_ins = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_ins + ] + fw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_outs + ] + bw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in bw_outs + ] + + user_fw_outs = fw_outs[: fw_metadata.num_forward] + saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice] + + fw_donated_buffer = collect_fw_donated_buffer_idxs( + fw_ins, + user_fw_outs, + bw_outs, + # pyrefly: ignore [bad-argument-type] + saved_tensors, + ) + + assert fw_metadata.num_symints_saved_for_bw is not None + return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer] + + +@dataclasses.dataclass +class InvokeSubgraphHopGraphs: + """ + A data structure to hold all the information needed to partition the + `joint_hop_gm` and joint graph and the restitch the `new_fw_hop_gm` and + `new_bw_hop_gm` into the bigger `joint_gm`. + """ + + # To avoid re-partitioning subgraphs + partitioning_done: bool = False + old_num_fw_outputs: Optional[int] = None + old_num_fw_inputs: Optional[int] = None + + new_fw_hop_gm: Optional[torch.fx.GraphModule] = None + new_bw_hop_gm: Optional[torch.fx.GraphModule] = None + new_num_sym_nodes: Optional[int] = None + new_num_saved_nodes: Optional[int] = None + + +def prepare_for_partitioner(mod, num_primals, num_fw_outputs): + # min-cut partitioner requires the placeholders to have primals and + # tangents string in the node.name. The signature of the joint graph is + # (*primals, *tangents) + + # We also have to update the output signature which is right now + # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the + # partitioner to work. + new_graph = torch.fx.Graph() + env = {} + + primals_counter = itertools.count(0) + tangents_counter = itertools.count(0) + + for idx, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if idx < num_primals: + env[node] = new_graph.placeholder(f"primals_{next(primals_counter)}") + else: + env[node] = new_graph.placeholder(f"tangents_{next(tangents_counter)}") + env[node].meta = copy.copy(node.meta) + elif node.op == "output": + # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) + # The reason for having the reversed signature in the first + # place is to simplify step 3. + old_outputs = node.args[0] + new_outputs = ( + *old_outputs[-num_fw_outputs:], + *old_outputs[:-num_fw_outputs], + ) + new_outputs = [env[n] if n else None for n in new_outputs] + new_graph.output(tuple(new_outputs)) + else: + env[node] = new_graph.node_copy(node, lambda n: env[n]) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + + out = torch.fx.GraphModule(mod, new_graph) + return out + + +def run_joint_graph_passes_on_hops( + joint_gm: torch.fx.GraphModule, + joint_inputs: Any, + aot_config: AOTConfig, +) -> torch.fx.GraphModule: + """ + This pass runs the joint graph passes on the HOP graph. In torch.compile, we + typically have many passes which work on the joint graph and then end with a + partitioner. + + + The partitioner part is quite mechanical to handle. HOP have their own + forward and backward graph. The process can be broken into following steps + + 1) Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` + 2) Run joint graph passes on the `joint_hop_gm` to get `new_fw_hop_gm` and `new_bw_hop_gm` + 3) Stitch the `new_fw_hop_gm` and `new_bw_hop_gm` back into the `joint_gm`. + + The terminology used in the code is + `joint_graph/joint_gm` : Refers to the main graph. This may contain many HOPs which have their own `hop_graph` + `fw_hop_graph/fw_hop_gm` : Refers to the forward graph associated with a HOP. + `bw_hop_graph/bw_hop_gm` : Refers to the backward graph associated with a HOP. + `joint_hop_graph/joint_hop_gm` : Refers to the subgraph associated with the HOP like invoke_subgraph. + `new_fw_hop_graph/new_fw_hop_gm` : Refers to the forward graph after partitioning is applied to `joint_hop_gm`. + `new_bw_hop_graph/new_bw_hop_gm` : Refers to the backward graph after partitioning is applied to `joint_hop_gm`. + + NB: This pass works for invoke_subgraph today because we took extra care in + the Autograd.Dispatch key of invoke_subgraph to vastly simplify Step 1. + """ + from torch._higher_order_ops import invoke_subgraph + + def num_outputs(mod): + return len(mod.graph.find_nodes(op="output")[0].args[0]) + + def num_inputs(mod): + return len(mod.graph.find_nodes(op="placeholder")) + + new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( + lambda: InvokeSubgraphHopGraphs() + ) + + # Step 1 - Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` This is + # easy to do for `invoke_subgraph` HOP. During the Autograd dispatch key + # tracing, we have put the joint_hop_graph in the backward hop graph itself. + # So to recover the joint_hop_gm, we just have to look at the backward + # HOP graphs. + # So we will merge step 1 and step 2 in this next section + + # Save the fw and bwd hop nodes. We will later in-place modify the graph + # using these nodes. + fw_hop_nodes = [] + bw_hop_nodes = [] + for node in joint_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is invoke_subgraph + and isinstance(node.args[1], str) + ): + if node.args[1].startswith("fw"): + fw_hop_nodes.append(node) + elif node.args[1].startswith("bw"): + bw_hop_nodes.append(node) + + if not bw_hop_nodes: + return joint_gm + + assert len(fw_hop_nodes) == len(bw_hop_nodes) + + # Create a bw to hop node mapping. This helps us in identifying the bw and + # fw subgraph pairs without relying on the identifier. This is important + # because we can have different subgraphs for bwd for same subgraph in the + # fwd because of differing strides in the backward. + bw_to_fw_hop_node = dict(zip(list(reversed(bw_hop_nodes)), fw_hop_nodes)) + + for node in bw_hop_nodes: + identifier = node.args[1].removeprefix("bw") + + # If partitioning already done for this identifier, skip. This saves + # redundant joint graph passes for same subgraphs. + if new_hop_graphs[identifier].partitioning_done: + continue + + # Collect some information from the forward hop graph + fw_hop_node = bw_to_fw_hop_node[node] + fw_hop_gm = getattr(joint_gm, fw_hop_node.args[0].target) + assert isinstance(fw_hop_gm, torch.fx.GraphModule) + num_fw_inputs = num_inputs(fw_hop_gm) + num_fw_outputs = num_outputs(fw_hop_gm) + new_hop_graphs[identifier].old_num_fw_inputs = num_fw_inputs + new_hop_graphs[identifier].old_num_fw_outputs = num_fw_outputs + + # Step 1) - Get the `joint_hop_gm`. As mentioned earlier, the + # backward graph is the joint graph. + joint_hop_gm = getattr(joint_gm, node.args[0].target) + assert isinstance(joint_hop_gm, torch.fx.GraphModule) + + # Prepare the graph for the partitioner + joint_hop_gm = prepare_for_partitioner( + joint_hop_gm, num_fw_inputs, num_fw_outputs + ) + + # TODO: invoke_subgraph should track which of its inputs static indices + # so it can propagate them to the partitioner (and use in cudagraphs) + static_lifetime_input_indices: list[int] = [] + # Step 2) and 3) - Run joint graph passes and partitioner + new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn( + joint_hop_gm, + [], + num_fwd_outputs=num_fw_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + # Save the new forward and backward graph modules + new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm + new_hop_graphs[identifier].new_bw_hop_gm = new_bw_hop_gm + + # Save the number of symints and saved tensors + new_fw_out_nodes = new_fw_hop_gm.graph.find_nodes(op="output")[0].args[0] + extra_outputs = new_fw_out_nodes[num_fw_outputs:] + symint_outputs = [n for n in extra_outputs if is_sym_node(n)] + + new_hop_graphs[identifier].new_num_sym_nodes = len(symint_outputs) + new_hop_graphs[identifier].new_num_saved_nodes = len(extra_outputs) - len( + symint_outputs + ) + + new_hop_graphs[identifier].partitioning_done = True + + # Step 3) Restitch the new fw and bw graphs back into the main graph. + # + # This is a very mechanical process. There are a quite a few pieces that we + # need to connect together to make it work. Lets try to understand the + # problem statement first. + # + # For the forward graph, the signature of the old_fw_hop_gm is + # inputs - (*primals) + # outputs - (*fw_outs) + # Now the signature of the new_fw_hop_gm is + # inputs - (*primals) -- This is same + # outputs - (*fw_outs, *saved_tensors) - This is different + # At a high level, this is an easy transformation, in the new graph we just + # have to replace the old_fw_hop_gm with the new_fw_hop_gm. Everything else + # falls into place, because the input signature (i.e. args) is same. And + # even though output signature is different, fw_outs are still at the same + # indexes as before. So the forward of the `joint_gm` works nicely. + # + # Now, lets look at the backward hop graph. Old signature + # inputs - (*primals, *tangents) + # outputs - (*grad_outs, *fw_outs) + # New signature + # inputs - (*saved_tensors, *tangents) -- Different + # outputs - (*grad_outs) -- Different + # Here both input and output signature change. The output signature handling + # is quite easy because the grads_out are sitting at the right place, so we + # dont have to do anything. + # + # For the input signature, we have to collect the saved tensors from the + # corresponding forward graph output. We collect all saved_tensors when we + # see the forward graph, and save it into a map and then later use it during + # the backward. + + # The stack of fw_nodes for invoke_subgraph HOP. There is an implicit + # assumption about the graph structure, i.e., if we have hop1, hop2, hop3, + # ... in the forward part of the joint graph, we will have .., hop3, hop2, + # hop1 order for the backward. This structure allows us to just use a stack + # to collect all the information that we need to pass from the forward hop + # node to the corresponding backward node. + + already_added_new_hop_mods = set() + + def add_new_hop_gm(new_subgraph_mod, name): + new_subgraph_attr_name = f"partitioned_{name}" + if new_subgraph_attr_name in already_added_new_hop_mods: + return new_subgraph_attr_name + + joint_gm.register_module(new_subgraph_attr_name, new_subgraph_mod) + already_added_new_hop_mods.add(new_subgraph_attr_name) + return new_subgraph_attr_name + + def propagate_meta_info(new_hop_gm, new_call_function_node, old_call_function_node): + # Copy all the fields from the old call_function node. And then override + # the `val` meta field with the outputs of new_hop_gm. + new_call_function_node.meta = copy.copy(old_call_function_node.meta) + + output = new_hop_gm.graph.find_nodes(op="output")[0] + out_example_vals = [n.meta["val"] if n else None for n in output.args[0]] + new_call_function_node.meta["val"] = tuple(out_example_vals) + + for bw_node in reversed(bw_hop_nodes): + identifier = bw_node.args[1].removeprefix("bw") + + # Make changes to the corresponding fw and bw node pair simultaneously. + # The removes the need of any bookkeeping. + + # Fw node changes + # Insert the new_fw_hop_gm. This is straightforward. Get the + # new_fw_hop_gm, insert the hop_gm as a get_attr fw_node, and then + # add a call_function fw_node. Additionally, also use getitem + # call_functions to collect the saved_tensor nodes + + fw_node = bw_to_fw_hop_node[bw_node] + new_fw_hop_gm = new_hop_graphs[identifier].new_fw_hop_gm + assert new_fw_hop_gm is not None + + old_num_fw_outputs = new_hop_graphs[identifier].old_num_fw_outputs + new_num_sym_nodes = new_hop_graphs[identifier].new_num_sym_nodes + new_num_saved_nodes = new_hop_graphs[identifier].new_num_saved_nodes + assert old_num_fw_outputs is not None + assert new_num_sym_nodes is not None + assert new_num_saved_nodes is not None + total_outputs = old_num_fw_outputs + new_num_saved_nodes + new_num_sym_nodes + + extra_fw_outputs = [] + + # Insert the new_fw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(fw_node): + new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}") + new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name) + new_fw_mod_attr.meta = copy.copy(fw_node.args[0].meta) + + # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors) + with joint_gm.graph.inserting_after(new_fw_mod_attr): + new_fw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_fw_mod_attr, + new_fw_mod_attr_name, + *fw_node.args[2:], + ), + ) + propagate_meta_info(new_fw_hop_gm, new_fw_node, fw_node) + + # old_num_fw_outputs = (*fw_outs) + # new_num_fw_outputs = (*fw_outs, *saved_tensors, *sym_nodes) + with joint_gm.graph.inserting_after(new_fw_node): + for fw_out_idx in range(old_num_fw_outputs, total_outputs): + saved_tensor_node = joint_gm.graph.call_function( + the_function=operator.getitem, args=(new_fw_node, fw_out_idx) + ) + saved_tensor_node.meta = copy.copy(new_fw_node.meta) + saved_tensor_node.meta["val"] = new_fw_node.meta["val"][fw_out_idx] + extra_fw_outputs.append(saved_tensor_node) + + fw_node.replace_all_uses_with(new_fw_node) + joint_gm.graph.erase_node(fw_node) + + # Bw node changes + # Prepare the operands for the bwd graph + # Old bw graph signature : (*primals, *tangents) + # New signature will be : (*sym_nodes, *saved_tensors, *tangents) + # We have already collected the saved_tensors in the forward hop processing. + + # extra_fw_outputs are in the order (*saved_nodes, *sym_nodes). + # Partitioner has this quirk where the backward wants sym_nodes + # first. So extract the sym and saved nodes. + + new_bw_hop_gm = new_hop_graphs[identifier].new_bw_hop_gm + assert new_bw_hop_gm is not None + + saved_tensor_nodes = extra_fw_outputs[:new_num_saved_nodes] + sym_nodes = extra_fw_outputs[new_num_saved_nodes:] + + num_primals = new_hop_graphs[identifier].old_num_fw_inputs + assert num_primals is not None + tangents = list(bw_node.args[2 + num_primals :]) + operands = sym_nodes + saved_tensor_nodes + tangents + + # Insert the new_bw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(bw_node): + new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1]) + new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name) + new_bw_mod_attr.meta = copy.copy(bw_node.args[0].meta) + + with joint_gm.graph.inserting_after(new_bw_mod_attr): + new_bw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_bw_mod_attr, + new_bw_mod_attr_name, + *operands, + ), + ) + propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node) + # Since the partitioner is run after the graph passes, we have lost + # the eager information and cannot faithfully extract the eager + # inputs for the new partitioned backward graph. For the forward + # graph, it was fine because the input signature remains same. + new_bw_node.meta.pop("eager_input_vals", None) + + bw_node.replace_all_uses_with(new_bw_node) + joint_gm.graph.erase_node(bw_node) + + joint_gm.graph.eliminate_dead_code() + joint_gm.graph.lint() + joint_gm.recompile() + return joint_gm + + +def maybe_log_graph( + gm, + graph_name, + aot_config, + structured_log_prefix_fn, + out_structured_logs: Optional[list[str]] = None, +): + if not aot_config.enable_log: + return + aot_graphs_log.debug( + "%s", + lazy_format_graph_code( + f"{graph_name}", + gm, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + def gm_str_fn() -> str: + return gm.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + + if out_structured_logs is not None: + out_structured_logs.append(f"{structured_log_prefix_fn()}:{gm_str_fn()}") + else: + trace_structured( + f"{structured_log_prefix_fn()}", + payload_fn=lambda: gm_str_fn(), + ) + + +def create_wrap_fn(fn, args): + from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify + + from .functional_utils import from_fun, has_data_mutation, to_fun + + def assert_no_mutation(t): + assert not has_data_mutation(t), ( + "Saved tensors hooks with inputs mutations are not allowed" + ) + + @simple_wraps(fn) + def _wrapper(*args): + with maybe_enable_thunkify(): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + f_args = pytree.tree_map(to_fun, args) + f_outs = fn(*f_args) + pytree.tree_map(assert_no_mutation, f_args) + return pytree.tree_map(from_fun, f_outs) + + return _wrapper, args + + +def prepare_hook_gm(aot_config, fn, args): + from torch._functorch._aot_autograd.graph_capture import _create_graph + + fn, args = create_wrap_fn(fn, args) + gm = _create_graph(fn, args, aot_config=aot_config) + return gm + + +# Inline Autograd saved_tensors_hooks into epilogue of forward graph +# and prologue of backward graph. +# This changes forward graph outputs and inputs. +# Pack hook can return tensors, sym scalars, constants. +# All tensors to save for backward will be grouped together at front. +# Sym scalars grouped on another end. Constants are inlined in the graph. +def maybe_inline_graph_saved_tensors_hooks( + fw_module, # torch.fx.GraphModule + bw_module, # torch.fx.GraphModule + num_inner_fwd_outputs, + inner_meta, + aot_config, + static_input_indices, +): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + hooks = get_hooks() + if not are_inline_hooks(hooks): + return + + pack_hook_gm, unpack_hook_gm = hooks + + structured_logs: list[str] = [] + maybe_log_graph( + fw_module, + "Forward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_forward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_backward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + fw_g = fw_module.graph + bw_g = bw_module.graph + + fw_g_names = {node.name for node in fw_g.nodes} + bw_g_names = {node.name for node in bw_g.nodes} + + def _gen_unused_name(candidate: str): + c = candidate + i = 0 + while c in fw_g_names or c in bw_g_names: + c = f"{candidate}_{i}" + i = i + 1 + return c + + bw_g_inputs = bw_g.find_nodes(op="placeholder") + + fw_out_n = fw_g.output_node() + fw_outs = fw_out_n.args[0] # type: ignore[var-annotated] + fw_outs_inner_set = set(fw_outs[:num_inner_fwd_outputs]) + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + fw_outs_packed_tensors = [] # type: ignore[var-annotated] + fw_outs_packed_syms = [] # type: ignore[var-annotated] + + # The main use case for saved_tensors_hooks is activation quantization, + # for memory usage optimization. + # Desired behavior is to quantize saved activations to free the original saved tensor. + # Saved nodes may include forward inputs, outputs, parameters. + # They may be held by something else and will not be deallocated after quantization. + # Donated buffers are intermediates in the graph invisible for the user, + # this guarantees that they can be deallocated. + # Using this as a default behavior to select saved nodes to apply hooks. + # There is also a config to apply hooks for all saved nodes without any filtering. + # The plan is to propagate meta about the source of the saved node to the user hook function. + mode = torch._functorch.config.saved_tensors_hooks_filtering_mode + allow_set = None + exclude_set = None + + if mode == "donated": + # collect_bw_donated_buffer_idxs requires inner_meta to have num_symints_saved_for_bw + inner_meta.num_symints_saved_for_bw = len( + [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + ) + bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + fw_donated_idxs = [ + i - inner_meta.num_symints_saved_for_bw for i in bw_donated_idxs + ] + allow_set = {fw_outs_saved_for_bw[i].name for i in fw_donated_idxs} + elif mode == "no_static": + fw_g_inputs = fw_g.find_nodes(op="placeholder") + exclude_set = {fw_g_inputs[i].name for i in static_input_indices} + + if (allow_set is not None) and (not allow_set): + # This means we have empty whitelist, + # No donated (intermediate) saved. + # Do not do anything in this case + return + + if aot_config.enable_log: + structured_logs.append(f"fw_outs_saved_for_bw:{fw_outs_saved_for_bw}") + structured_logs.append(f"mode:{mode}") + structured_logs.append(f"allow_set:{allow_set}") + structured_logs.append(f"exclude_set:{exclude_set}") + + for saved in fw_outs_saved_for_bw: + if ((allow_set is not None) and (saved.name not in allow_set)) or ( + (exclude_set is not None) and (saved.name in exclude_set) + ): + if isinstance(saved.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(saved) + continue + + val = saved.meta["val"] + if not isinstance(val, torch.Tensor): + continue + + def _get_extra_info() -> dict[str, Any]: + return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved} + + with _saved_tensor_hook_context(_get_extra_info()): + pack_out_val = pack_hook_gm(val) + + requires_sc_handling = any( + is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) + ) + if requires_sc_handling: + raise NotImplementedError( + "Tensor subclasses in GraphModule saved tensors hooks are not supported" + "You can workaround it by manually returning subclass's inner tensors" + " in the pack hook, and reconstructing the subclass in the unpack hook" + ) + + with _saved_tensor_hook_context(_get_extra_info()): + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) + + # Install pack hook graph as eiplogue of fw_module. + # Saved tensor output becomes input of pack hook graph. + # Replace saved tensor output with pack hook graph output. + # Outputs symbolic scalars, tensors are accumulated separately. + # Then in forward outputs and backward inputs installed in order + # sym_scalars, packed_saved_tensors. + # Keeping all tensors together allows to preserve + # the same identification at runtime, + # updating only number of saved sym_scalars and tensors. + pack_g_inputs = pack_g.find_nodes(op="placeholder") + assert len(pack_g_inputs) == 1 + env = {pack_g_inputs[0]: saved} + fw_pack_out_args = None + with fw_g.inserting_before(fw_out_n): + for node in pack_g.nodes: + if node.op == "placeholder": + continue + new_n = fw_g.node_copy(node, lambda n: env[n]) + fw_g_names.add(new_n.name) + env[node] = new_n + # Output node is temporarily copied to have remapped arguments. + # Removed in the end. + if node.op == "output": + fw_pack_out_args = new_n.args[0] + fw_g.erase_node(new_n) + + env.clear() + assert fw_pack_out_args + fw_outs_bw_ins_node_names = [] + for out_idx, _n in enumerate(pytree.tree_leaves(fw_pack_out_args)): + if not isinstance(_n, torch.fx.Node): + fw_outs_bw_ins_node_names.append("") + continue + + # This happens when hook is noop and it is either user input or user output. + # Do not do anything with this node. + if _n.op == "placeholder" or _n in fw_outs_inner_set: + # This means the hook returned input primals unchanged + # Do not rename in this case. + n = _n + new_node_name = _n.name + fw_outs_bw_ins_node_names.append(new_node_name) + else: + # We can not specify desired name in node_copy. + # Copying node manually to set specific name, + # to have matching fw_outs, bw_inputs names. + new_node_name = _gen_unused_name(f"{saved.name}_hook_{out_idx}") + with fw_g.inserting_before(_n): + n = fw_g.create_node( + _n.op, + _n.target, + _n.args, + _n.kwargs, + name=new_node_name, + ) + assert n.name == new_node_name + fw_outs_bw_ins_node_names.append(new_node_name) + n.meta = copy.copy(_n.meta) + _n.replace_all_uses_with(n) + fw_g.erase_node(_n) + if isinstance(n.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(n) + elif is_sym_node(n): + fw_outs_packed_syms.append(n) + + # Install unpack hook graph as a prologue of backward graph + # Saved tensors inputs are replaced with packed tensors and packed sym scalars. + # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. + with _saved_tensor_hook_context(_get_extra_info()): + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) + + def find_saved_in_bw_inputs(bw_inputs): + for n in bw_inputs: + if n.name == saved.name: + return n + + bw_g_input = find_saved_in_bw_inputs(bw_g_inputs) + assert bw_g_input + original_bw_g_input_users = list(bw_g_input.users.keys()) + bw_g_input_used_directly = False + + # Replace backward graph saved tensor input with copy of pack graph outputs + # All non-Tensor, non-symscalars outputs are constanted. + + unpack_g_inputs = unpack_g.find_nodes(op="placeholder") + env = {} + for out_idx, (unp_in_n, out_n, val) in enumerate( + zip( + unpack_g_inputs, + pytree.tree_leaves(fw_pack_out_args), + pytree.tree_leaves(pack_out_val), + ) + ): + is_sym = isinstance(val, py_sym_types) + if isinstance(val, torch.Tensor) or is_sym: + # We want forward_outputs names to match backward_inputs, + # Potentially backward may already have "{saved.name}_hook_{idx}", + # In this case fx.Graph will add suffix. + new_node_name = fw_outs_bw_ins_node_names[out_idx] + if bw_g_input.name == new_node_name: + env[unp_in_n] = bw_g_input + bw_g_input_used_directly = True + else: + # Backward calling convention: ctx_symints,ctx_saved_tensors + # Inserting packed sym scalars before first saved tensor input. + # Inserting packed tensors before last saved tensor input. + # Saved tensor inputs between them will be removed. + with ( + bw_g.inserting_before(bw_g_inputs[0]) + if is_sym + else bw_g.inserting_before(bw_g_input) + ): + new_n = bw_g.placeholder(new_node_name) + assert new_n.name == new_node_name + new_n.meta = copy.copy(out_n.meta) + env[unp_in_n] = new_n + else: + # Inline values of non-Tensor, non-SymScalars + env[unp_in_n] = val + + # Inserting unpack hook after placeholders. + bw_unpack_out_n = None + with bw_g.inserting_before(bw_g_inputs[-1].next): + for node in unpack_g.nodes: + if node.op == "placeholder": + continue + new_n = bw_g.node_copy(node, lambda n: env[n]) + bw_g_names.add(new_n.name) + env[node] = new_n + # Temporary insert output, to have remapped by node_copy args. + # Removed in the end. + if node.op == "output": + bw_unpack_out_n = new_n + + assert bw_unpack_out_n + _leaves = pytree.tree_leaves(bw_unpack_out_n.args) + assert len(_leaves) == 1 + unpack_saved_tensor_n = _leaves[0] + + if not bw_g_input_used_directly: + bw_g_input.replace_all_uses_with(unpack_saved_tensor_n) + bw_g.erase_node(bw_g_input) + else: + # Keep usages of bw_g_input in inserted unpacked hook graph. + # Replace other usages of bw_g_input with unpack_saved_tensor_n. + for use_node in original_bw_g_input_users: + use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n) + bw_g.erase_node(bw_unpack_out_n) + + # Changing forward graph outputs, + # Inserting packed_tensors and packed_syms on the place of saved tensors. + # Packed sym_scalars are together with saved symints + symint_outs_saved_for_bw = [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + fw_new_outs = pytree.tree_leaves( + ( + fw_outs[:num_inner_fwd_outputs], + fw_outs_packed_tensors, + fw_outs_packed_syms, + symint_outs_saved_for_bw, + ) + ) + fw_out_n.args = (tuple(fw_new_outs),) + + # Assert that saved tensors and symints in forward outputs are aligned with backward inputs + _fw_n = num_inner_fwd_outputs + _fw_num_t = len(fw_outs_packed_tensors) + _fw_num_s = len(fw_outs_packed_syms) + len(symint_outs_saved_for_bw) + fw_outs_saved_tensors = fw_new_outs[_fw_n : _fw_n + _fw_num_t] + fw_outs_saved_syms = fw_new_outs[_fw_n + _fw_num_t :] + bw_new_ins = list(bw_g.find_nodes(op="placeholder")) + bw_ins_saved_syms = bw_new_ins[:_fw_num_s] + bw_ins_saved_tensors = bw_new_ins[_fw_num_s : _fw_num_s + _fw_num_t] + + fw_t_names = [n.name for n in fw_outs_saved_tensors] + bw_t_names = [n.name for n in bw_ins_saved_tensors] + fw_s_names = [n.name for n in fw_outs_saved_syms] + bw_s_names = [n.name for n in bw_ins_saved_syms] + + def _log_structured_logs(): + if not aot_config.enable_log: + return + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_saved_tensors_hooks_graphs", + "encoding": "string", + }, + payload_fn=lambda: "\n".join(structured_logs), + ) + + if aot_config.enable_log: + structured_logs.append( + f"fw_outs[:num_inner_fwd_outputs]:{fw_outs[:num_inner_fwd_outputs]}" + ) + structured_logs.append(f"fw_outs_packed_tensors:{fw_outs_packed_tensors}") + structured_logs.append(f"fw_t_names:{fw_t_names}") + structured_logs.append(f"bw_t_names:{bw_t_names}") + structured_logs.append(f"fw_s_names:{fw_s_names}") + structured_logs.append(f"bw_s_names:{bw_s_names}") + structured_logs.append(f"\nfw_g_pre_assert:{fw_g}") + structured_logs.append(f"\nbw_g_pre_assert:{bw_g}") + maybe_log_graph( + fw_module, + "Forward graph after transform pre-assert", + aot_config, + lambda: "aot_forward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph after transform pre-assert", + aot_config, + lambda: "aot_backward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + _log_structured_logs() + + assert fw_t_names == bw_t_names + assert fw_s_names == bw_s_names + + fw_g.lint() + bw_g.lint() + fw_module.recompile() + bw_module.recompile() + + +def _log_joint_graph( + fx_g: torch.fx.GraphModule, + aot_config: AOTConfig, +) -> Optional[str]: + """ + Log the joint graph to the structured logger. + Return a str representation of the graph. + """ + joint_graph_str = None + if aot_config.enable_log: + aot_joint_log.info( + "%s", + lazy_format_graph_code( + "Joint graph", + fx_g, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + joint_graph_str = fx_g.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + trace_structured( + "aot_joint_graph", + payload_fn=lambda: joint_graph_str, + ) + return joint_graph_str + + +def _log_fw_bw_graphs( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> tuple[Optional[str], Optional[str]]: + """ + Log the fw and bw graphs to the structured logger. + Return str representations of the graphs. + """ + fw_module_str = None + bw_module_str = None + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(), + ) + aot_graphs_log.info( + "aot_config id: %s, fw_metadata=%s, inner_meta=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(_get_inner_meta(maybe_subclass_meta, fw_metadata)), + ) + + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Backward graph", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + fw_module_str = fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + bw_module_str = bw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_forward_graph", + payload_fn=lambda: fw_module_str, + ) + trace_structured( + "aot_backward_graph", + payload_fn=lambda: bw_module_str, + ) + return fw_module_str, bw_module_str + + +def _aot_stage2a_partition( + fx_g: torch.fx.GraphModule, + joint_inputs: Union[list[Any], tuple[list[Any], list[Any]]], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule, int, int, list[int], list[Any]]: + """ + Partition the joint graph into a forward graph and a backward graph. Returns: + - the forward and backward graphs + - the number of forward outputs and the number of symints saved for backward + - indices of inputs to detach + - adjusted inputs to forward + """ + disable_amp = torch._C._is_any_autocast_enabled() + inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata) + + with torch.no_grad(): + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(), track_graph_compiling(aot_config, "joint"): + # See Note: [Partitioner handling for Subclasses, Part 1] + # See Note: [Recomputing subclass mutation handling] + mutated_inp_runtime_indices = ( + compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata, inner_meta + ) + ) + num_tokens = len(fw_metadata.tokens) + num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices) + num_inner_fwd_outputs = ( + num_mutated_inp_runtime_indices + + inner_meta.num_outputs + + inner_meta.num_intermediate_bases + + inner_meta.num_outputs_rng_offset + + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] + ) + fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config) + + # apply joint_gm callback here + if callable(torch._functorch.config.joint_custom_pass): + # pyrefly: ignore [bad-assignment] + fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs) + + static_lifetime_input_indices = fw_metadata.static_input_indices + fw_module, bw_module = aot_config.partition_fn( + fx_g, + joint_inputs, + num_fwd_outputs=num_inner_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = ( + rng_states[0].meta["val"].device.index + ) + + # See Note [Side-Effectful Tokens in AOTAutograd] + if config.unlift_effect_tokens and ( + num_tokens > 0 or fw_metadata.num_backward_tokens > 0 + ): + unlift_tokens(fw_module, fw_metadata, aot_config, bw_module) + + num_inner_fwd_outputs -= num_tokens + joint_inputs = ( + joint_inputs[0][num_tokens:], + joint_inputs[1], + ) + + maybe_inline_graph_saved_tensors_hooks( + fw_module, + bw_module, + num_inner_fwd_outputs, + inner_meta, + aot_config, + fw_metadata.static_input_indices, + ) + static_lifetime_input_indices = fw_metadata.static_input_indices + + fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) + symint_outs_saved_for_bw = [] + for idx, node in enumerate(fw_outs_saved_for_bw): + if is_sym_node(node): + symint_outs_saved_for_bw.append(node) + elif ( + isinstance(node, torch.fx.Node) + and "val" in getattr(node, "meta", {}) + and isinstance(node.meta["val"], FakeTensor) + ): + # record dynamic tensor activations + dynamic_dims: set[int] = { + dim + for dim, size in enumerate(node.meta["val"].shape) + if not isinstance(size, int) + } + if dynamic_dims: + fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims + + num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + fw_metadata.num_symints_saved_for_bw = num_symints_saved_for_bw + inner_meta.num_symints_saved_for_bw = num_symints_saved_for_bw + if torch._functorch.config.donated_buffer: + fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs + + # Note [Detaching inputs that never need gradients] + # See https://github.com/pytorch/pytorch/issues/97745 + # Suppose we have a function like this that we want to compile: + # + # def f(x, y): + # return torch.mul(x, y.detach()) + # + # What gradients should we compute for x and y? + # By default, AOTAutograd will compute a gradient for **every** input that requires gradients, + # and so we'll compute: + # x_grad_input = y + # y_grad_input = None + # Does this preserve the semantics of eager mode? + # Unfortunately, no. + # Doing the above will cause autograd to **continue** to backprop the autograd tape + # that was generated from constructing y. + # + # This is **different** from what would have happened in eager mode. + # In eager mode, if we backprop through the output of this function, autograd will only traverse + # the bit of the autograd tape corresponding to "x". + # In particular, if a user had previously backpropped through y's autograd tape, + # And then they try to backprop through the output of the above function, + # then we'll hit the dreaded "Trying to backward through the graph a second time" error. + # + # You might think: If autograd sees that a gradient is None, shouldn't it stop early, + # instead of continuing the backprop through the ancestors of that node in the graph? + # + # Autograd has two passes: + # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed + # (2) a second pass that actually goes ahead and executes each node when it becomes ready, + # propagating gradients + # By the time we're executing a node and we see that it produces a None, the set of nodes to execute + # is already locked-in. + # + # The fix: instead, we can recognize statically that the graph we're compiling will never contribute + # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all. + # We can do this by manually detach'ing y before sending it through the `CompiledFunction`. + # + # Note that this solution is not bulletproof. + # It's possible to construct a case where eager may or may not have have tried to autograd through y, + # depending on the actual grad_outputs that were passed in during the backward. + # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`, + # allowing autograd to reuse the graph. + # + # An example of this case is: + # def f(x): + # return x.detach() * 2, x * 3 + # If we were to only backprop through outs[0], in eager, we would stop + # If we backward only on the first output, we shouldn't send a grad through x. + # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3 + # and we will end up with a zero grad at x. + # If we later backprop through the second output, this will also require backprop'ing through x. + # Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time. + _indices_of_inps_to_detach: list[int] = [] + + # reversed() since we expect output at end of graph + bw_output = next(reversed(bw_module.graph.find_nodes(op="output"))) + bw_outs: Sequence[torch.fx.Node] = bw_output.args[0] # type: ignore[assignment] + + # TODO: we should apply the below "detach inputs if their gradients are statically known to be None" + # optimization even if we have subclass inputs/outputs (we do not handle this today). + # Computing which our our inputs get None gradients is a bit more complicated, + # if any of our inputs are subclasses. Why? + # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses. + # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensor tensors, + # so we need to figure out which subclass fw inputs they map to. + if maybe_subclass_meta is None: + num_backward_tokens: int = inner_meta.num_backward_tokens + assert ( + len(bw_outs) + == len(fw_metadata.input_info) + + inner_meta.num_outputs_rng_offset + + num_backward_tokens + ) + bw_outs_no_rng_no_tokens = bw_outs + if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0: + bw_outs_no_rng_no_tokens = bw_outs[ + : -(inner_meta.num_outputs_rng_offset + num_backward_tokens) + ] + assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info) + + for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens): + # If our input experiences a metadata mutation inside the graph (e.g. set_()), + # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation + metadata_mutation_in_graph = ( + fw_metadata.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + and fw_metadata.input_info[i].mutates_storage_metadata + ) + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: + _indices_of_inps_to_detach.append(i) + + return ( + fw_module, + bw_module, + num_fw_outs_saved_for_bw, + num_symints_saved_for_bw, + _indices_of_inps_to_detach, + joint_inputs[0], + ) + + +def _aot_stage2b_fw_compile( + fw_module: torch.fx.GraphModule, + adjusted_flat_args: list[Any], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + num_fw_outs_saved_for_bw: int, + aot_config: AOTConfig, +) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]: + return _aot_stage2b_compile_forward_or_inference( + fw_module, + adjusted_flat_args, + maybe_subclass_meta, + fw_metadata, + aot_config, + is_inference=False, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ) + + +def _aot_stage2b_bw_compile( + bw_module: torch.fx.GraphModule, + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]], + num_symints_saved_for_bw: int, + aot_config: AOTConfig, +) -> tuple[AutogradLazyBackwardCompileInfo, Optional[Callable]]: + """ + Compile the backward graph. Returns: + - the placeholder list for the backward graph + - the compiled backward function + """ + with torch.no_grad(): + # NB: It's important to compile backwards ahead of time, as this may + # add extra guards which we need to apply to the Dynamo cache at + # forwards + with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast(): + placeholder_list = fx_placeholder_vals(bw_module) + + forward_saved_for_backwards_strides = None + if fwd_output_strides is not None: + inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata) + forward_saved_for_backwards_strides = fwd_output_strides[ + inner_meta.tensors_saved_for_backwards_slice + ] + + # saved activations can have different stride to eager if + # the compiler does layout optimization. We should restride the + # tensor passed in for compiling the backward graph using the + # saved tensor's stride. + for i in range(len(placeholder_list)): + ph_arg = placeholder_list[i] + if not isinstance(ph_arg, torch.Tensor): + continue + + if forward_saved_for_backwards_strides is None: + continue + + real_stride = None + # Per all_args calling convention + j = i - num_symints_saved_for_bw + if 0 <= j < len(forward_saved_for_backwards_strides): + real_stride = forward_saved_for_backwards_strides[j] + if real_stride is None: + continue + + # Comparing ph_arg.stride() with real_stride directly may + # cause dynamic dimensions in ph_arg being specialized to static + # value. Using suppress_guards and guard_or_true to avoid that. + + stride_different = False + fake_mode = detect_fake_mode() + suppress_ctx = ( + fake_mode.shape_env.suppress_guards() + if fake_mode is not None and fake_mode.shape_env is not None + else nullcontext() + ) + + # Inductor can choose different strides for activations than + # what backward graph has. if we can't statically tell that + # strides are the same, we assume they are not. + with suppress_ctx: + for k in range(len(ph_arg.stride())): + # real_stride can't be symbolic. + # pyrefly: ignore [index-error] + if guard_or_true(ph_arg.stride()[k] != int(real_stride[k])): + stride_different = True + break + + if stride_different: + # Note that here we use the stride of the real tensor to + # restride a FakeTensor. This does not cause trouble + # for dynamic shape since this code path only get + # executed if layout optimization is enabled. And we + # disable layout optimization for dynamic shape right + # now. + # + # A solution that decide stride order based on real + # tensor's stride and then apply that stride order to + # the FakeTensor does not work smoothly since some + # tensor's layout is not 'dense'. E.g. mixnet_l has a + # tensor with size [8, 64, 112, 112] and strides + # (2408448, 1, 21504, 192). The solution mentioned will + # decide a stride of (802816, 1, 7168, 64) for this + # tensor which is wrong. + + ph_size = ph_arg.size() + + # pyrefly: ignore [bad-argument-type] + placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride) + compiled_bw_func = None + if ( + num_symints_saved_for_bw > 0 + or aot_config.force_non_lazy_backward_lowering + ): + try: + # See Note: [Backward graph lazy lowering] + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + # If bw_module contains lifted constants, they will be real tensors stored as + # GraphModule. Deepcopying tensors under fake mode is not supported and will + # raise when attempting to set storage. + bw_module_copy = copy.deepcopy(bw_module) + compiled_bw_func = aot_config.bw_compiler( + bw_module_copy, placeholder_list + ) + del bw_module_copy + except Exception as e: + if aot_config.force_non_lazy_backward_lowering: + raise + exc = e + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "eager_compile_backwards_failure", + "encoding": "string", + }, + payload_fn=lambda: "\n".join( + traceback.format_exception( + type(exc), exc, exc.__traceback__ + ) + ), + ) + log.warning( + "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", + exc_info=True, + ) + # Compiled autograd will run the bw_module in the backward pass, + # so recompilation need happen anyway if the backward pass is ever + # called. + # + # The reason we do the GraphModule recompilation here is because + # the lazy recompilation will cause issue in the backward pass + # with compiled autograd. + # + # Do the _LazyGraphModule.force_recompile here rather than when + # bw_module is first generated by the partitioner because the bw_module.recompile + # may be called in some code path later and cause the _LazyGraphModule.forward + # becomes the lazy version again. One example is when dynamic shape is enabled + # upfront, the bw_compiler will be called above which can cause extra + # graph module recompilation on bw_module. + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(bw_module) + + saved_context = TracingContext.try_get() + saved_compile_context = CompileContext.try_get() + + lazy_backward_info = AutogradLazyBackwardCompileInfo( + bw_module, + placeholder_list, + saved_context, + saved_compile_context, + ) + + return lazy_backward_info, compiled_bw_func + + +def aot_stage2_autograd( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, +) -> DispatchReturn: + """ + Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers, + and returns a wrapped torch.autograd.Function with a forward and backward. + """ + + fx_g = aot_graph_capture.graph_module + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + fw_metadata = aot_state.fw_metadata + aot_config = aot_state.aot_config + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd") + joint_graph_str = _log_joint_graph(fx_g, aot_config) + + _apply_tensorify_python_scalars(fx_g) + + ( + fw_module, + bw_module, + num_fw_outs_saved_for_bw, + num_symints_saved_for_bw, + _indices_of_inps_to_detach, + adjusted_flat_args, + ) = _aot_stage2a_partition( + fx_g, + aot_graph_capture.updated_flat_args, + maybe_subclass_meta, + fw_metadata, + aot_config, + ) + + fw_module_str, bw_module_str = _log_fw_bw_graphs( + fw_module, bw_module, maybe_subclass_meta, fw_metadata, aot_config + ) + + fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile( + fw_module, + adjusted_flat_args, + maybe_subclass_meta, + fw_metadata, + num_fw_outs_saved_for_bw, + aot_config, + ) + + lazy_backward_info, compiled_bw_func = _aot_stage2b_bw_compile( + bw_module, + maybe_subclass_meta, + fw_metadata, + fwd_output_strides, + num_symints_saved_for_bw, + aot_config, + ) + + try_save_cache_entry, entry = _cache_autograd_info( + aot_config, + aot_state.flat_args, + compiled_fw_func, + compiled_bw_func, + fw_module_str, + bw_module_str, + joint_graph_str, + aot_graph_capture.wrappers, + maybe_subclass_meta, + fw_metadata, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, + bw_module, + ) + + return _aot_stage2c_make_autograd_function( + aot_config, + aot_state.flat_args, + fw_metadata, + maybe_subclass_meta, + aot_graph_capture.wrappers, + compiled_fw_func, + compiled_bw_func, + lazy_backward_info, + try_save_cache_entry, + entry, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, + ) + + +def _aot_stage2c_make_autograd_function( + aot_config, + flat_args, + fw_metadata, + maybe_subclass_meta, + wrappers, + compiled_fw_func, + compiled_bw_func, + lazy_backward_info, + try_save_cache_entry, + entry, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, +): + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + disable_amp = torch._C._is_any_autocast_enabled() + compiled_fn = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + maybe_subclass_meta, + num_symints_saved_for_bw, + backward_state_indices, + disable_amp, + _indices_of_inps_to_detach, + lazy_backward_info, + aot_config, + fw_metadata=fw_metadata, + try_save_cache_entry=try_save_cache_entry, + ) + + if entry is not None: + compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry) + + if config.debug_assert: + flat_requires_grad: list[Optional[bool]] = [ + a.requires_grad if isinstance(a, Tensor) else None for a in flat_args + ] + compiled_fn = DebugAssertWrapper( + flat_requires_grad=flat_requires_grad + ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata) + + compiled_fn = post_compile( + wrappers, + compiled_fn, + aot_config, + runtime_metadata=fw_metadata, + ) + return compiled_fn + + +def _cache_autograd_info( + aot_config, + flat_args, + compiled_fw_func, + compiled_bw_func, + fw_module_str, + bw_module_str, + joint_graph_str, + wrappers, + maybe_subclass_meta, + fw_metadata, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, + bw_module, +): + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + try_save_cache_entry: Optional[Callable] = None + entry: Optional[GenericAOTAutogradResult] = None + + if aot_config.cache_info is not None: + forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns + + # NB: aot_config here is technically not needed as an argument: we could just + # close over aot_config.cache_info, since aot_config never changes. + # But closing over random variables is confusing IMO, so I'm leaving it. + def try_save_cache_entry( # noqa: F811 + compiled_bw_func: Callable, + bw_module: torch.fx.GraphModule, + _fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + ) -> Optional[GenericAOTAutogradResult]: + cache_info = aot_config.cache_info + + def should_save_cache(): + if should_bundle_autograd_cache(): + return True + else: + return hasattr(compiled_fw_func, "_fx_graph_cache_key") and hasattr( + compiled_bw_func, "_fx_graph_cache_key" + ) + + if cache_info is not None and should_save_cache(): + assert forward_time_taken_ns is not None + # TODO: technically, AOTAutograd does a *little* bit of post processing work + # in the backward that isn't measured here. But it's small enough that it's not worth + # the complexity of threading a bunch of times through the code, so we + # use the compiled_bw_func's inductor compile time instead. + # It's possible this changes in the future, in which case we should + # update backward_time_taken_ns to be more inclusive + backward_time_taken_ns = getattr(compiled_bw_func, "_time_taken_ns", 0) + + aot_forward_graph_str: Optional[str] = fw_module_str + aot_backward_graph_str: Optional[str] = bw_module_str + aot_joint_graph_str: Optional[str] = joint_graph_str + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + + entry = AOTAutogradCache.make_entry( + compiled_fw_func, # type: ignore[arg-type] + compiled_bw_func, # type: ignore[arg-type] + aot_joint_graph_str, + aot_forward_graph_str, + aot_backward_graph_str, + _fw_metadata, + wrappers, + maybe_subclass_meta, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + forward_time_taken_ns, + backward_time_taken_ns, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw=num_symints_saved_for_bw, + serialized_bw_module=serialize_graph_module(bw_module), + ) + AOTAutogradCache.save( + cache_info.cache_key, + entry, + remote=should_use_remote_autograd_cache(), + ) + return entry + return None + + if compiled_bw_func is not None: + # If we already compiled the backward, we save its cache entry now + entry = try_save_cache_entry( + compiled_bw_func, bw_module, fw_metadata, aot_config + ) + try_save_cache_entry = None + + return try_save_cache_entry, entry + + +def _aot_stage2b_compile_forward_or_inference( + fw_module: torch.fx.GraphModule, + adjusted_flat_args: list[Any], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + *, + is_inference: bool, + num_fw_outs_saved_for_bw: Optional[int] = None, +) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]: + """ + Compile the forward or inference graph. Returns: + - the output strides of the forward graph + - the compiled forward/inference function + + Args: + fw_module: The forward graph module to compile + adjusted_flat_args: Flattened arguments after adjustments + maybe_subclass_meta: Metadata for tensor subclasses + fw_metadata: View and mutation metadata + aot_config: AOT configuration + is_inference: If True, compile for inference; if False, compile for forward (autograd) + num_fw_outs_saved_for_bw: Number of forward outputs saved for backward (required if not is_inference) + + Before compiling, we run pre_compile for the following wrappers: + - FakifiedOutWrapper + - FunctionalizedRngRuntimeWrapper + After compiling, we run post_compile for the following wrappers: + - EffectTokensWrapper + - AOTDispatchSubclassWrapper + - FunctionalizedRngRuntimeWrapper + - FakifiedOutWrapper + """ + + # Validation + if not is_inference and num_fw_outs_saved_for_bw is None: + raise ValueError( + "num_fw_outs_saved_for_bw must be provided when is_inference=False" + ) + + # Determine grad context, autocast context, tracking mode, compiler + if is_inference: + grad_ctx: Any = nullcontext + autocast_ctx: Any = ( + torch._C._DisableAutocast + if torch._C._is_any_autocast_enabled() + else nullcontext + ) + tracking_mode: str = "inference" + compiler: Any = aot_config.inference_compiler + else: + grad_ctx = torch.no_grad + autocast_ctx = torch._C._DisableAutocast + tracking_mode = "forward" + compiler = aot_config.fw_compiler + + with grad_ctx(), autocast_ctx(), track_graph_compiling(aot_config, tracking_mode): + # Setup wrappers + fakified_out_wrapper = FakifiedOutWrapper() + fakified_out_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Initialize RNG wrapper based on mode + functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper( + return_new_outs=is_inference + ) + + # Add RNG states for forward mode only + if not is_inference and fw_metadata.num_graphsafe_rng_states > 0: + index = fw_metadata.graphsafe_rng_state_index + assert index is not None + rng_states = [ + get_cuda_generator_meta_val(index) + for _ in range(fw_metadata.num_graphsafe_rng_states) + ] + adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] + + functionalized_rng_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Set tracing context + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = _get_inner_meta( + maybe_subclass_meta, fw_metadata + ) + + with TracingContext.report_output_strides() as fwd_output_strides: + compiled_fw_func = compiler(fw_module, adjusted_flat_args) + + # Make boxed if needed + if not getattr(compiled_fw_func, "_boxed_call", False): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + # Set forward output strides if needed + if fakified_out_wrapper.needs_post_compile: + fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides) + + # Apply post-compile wrappers + compiled_fw_func = EffectTokensWrapper().post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = AOTDispatchSubclassWrapper( + fw_only=None, + trace_joint=False, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = functionalized_rng_wrapper.post_compile( + compiled_fw_func, aot_config, runtime_metadata=fw_metadata + ) + + compiled_fw_func = fakified_out_wrapper.post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + return fwd_output_strides, compiled_fw_func diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/logging_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6325b6e6ab2489c175347afe13e05bfbed3c7e8d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/logging_utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +""" +Contains utils for logging in AOTAutograd, including managing the names of the graphs under +compilation, capturing user-friendly tracebacks, and debug messages. +""" + +import collections +from contextlib import contextmanager + +import torch +import torch.fx.traceback as fx_traceback + + +# This is a list since looking forward, we can have this arbitrarily nested. +graph_being_compiled: list[str] = [] +# TODO: It would be nice to reset the numbering every time aot_id goes +# up, but this is annoying to do right now (because we don't know if +# an aot_id will come back from the dead), so right now this also happens +# to be a globally unique number too (at the cost of wobbling if you change +# how the graphs compile) +nth_graph: int = 0 +model_name: str = "model" + + +def set_model_name(name): + global model_name + model_name = name + + +def get_aot_compilation_context() -> tuple[list[str], str, int]: + return list(graph_being_compiled), model_name, nth_graph + + +def get_aot_graph_name() -> str: + """ + Returns the name of the graph being compiled. + """ + global model_name, graph_being_compiled, nth_graph + return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}" + + +get_graph_being_compiled = get_aot_graph_name + + +@contextmanager +def track_graph_compiling(aot_config, graph_name): + global graph_being_compiled + # TODO: Don't shove the aot_id in here; set it in the context + graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"] + old_name = None + if tracing_context := torch._guards.TracingContext.try_get(): + old_name = tracing_context.aot_graph_name + tracing_context.aot_graph_name = graph_being_compiled + has_tracing_context = True + else: + has_tracing_context = False + try: + yield + finally: + global nth_graph + nth_graph += 1 + graph_being_compiled = [] + if has_tracing_context: + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.aot_graph_name = old_name + + +# Set up hooks so that during backward the fx's stack_trace is properly set +callback_set = False + + +def setup_stacktrace_preservation_hooks(roots: list): + def iter_graph(roots): + if not roots: + return + seen = set() + q = collections.deque() # type: ignore[var-annotated] + for node in roots: + if node is not None and node not in seen: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _idx in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def get_callback(saved_stack_): + def callback(): + global callback_set + fx_traceback.set_stack_trace(saved_stack_) + callback_set = False + + return callback + + def get_prehook(stack_, seq_nr): + def prehook(grad_output): + global callback_set + + if not callback_set: + torch.autograd.variable.Variable._execution_engine.queue_callback( # type: ignore[attr-defined] + get_callback(fx_traceback.format_stack()) + ) + callback_set = True + + fx_traceback.set_stack_trace(stack_) + fx_traceback.set_grad_fn_seq_nr(seq_nr) + + return prehook + + def get_posthook(special_stack_, seq_nr): + def posthook(grad_input, grad_output): + fx_traceback.set_stack_trace(special_stack_) + fx_traceback.reset_grad_fn_seq_nr() + + return posthook + + for node in iter_graph(roots): + forward_node_stack = node.metadata.get("traceback_", []) + node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr())) + + special_stack = forward_node_stack.copy() + special_stack.append(fx_traceback.GRADIENT_ACC_SPECIAL_STACK) + node.register_hook(get_posthook(special_stack, node._sequence_nr())) + + +def describe_input(i, aot_config): + if i < aot_config.num_params_buffers: + return f"parameter/buffer {i}" + else: + return f"input {i - aot_config.num_params_buffers}" + + +def format_guard_bug_msg(aot_config, expected): + return ( + f"At compilation time, graph {aot_config.aot_id} was compiled under the " + f"assumption that {expected}, but at runtime this was not the case. " + "This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..86202e2cd319d9a959d1af9e57efca9299624085 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -0,0 +1,2604 @@ +# mypy: allow-untyped-defs +""" +This module defines runtime wrappers, which, based on previous analysis attempts to: +1. process the inputs and outputs +2. apply mutations +3. handle functionalized randomness +4. deduplicate inputs and consolidate views into their bases (see input_output_analysis) +""" + +import builtins +import collections +import contextlib +import copy +import functools +import itertools +import pprint +from collections.abc import Callable +from contextlib import AbstractContextManager, nullcontext +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Optional, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from collections.abc import Sequence + +import torch +import torch.fx as fx +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo import config as dynamo_config +from torch._dynamo.callback import callback_handler, CallbackTrigger +from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context +from torch._guards import ( + compile_context, + CompileContext, + detect_fake_mode, + DuplicateInputs, + tracing, + TracingContext, +) +from torch._prims_common import CUDARngStateHelper +from torch._subclasses import FakeTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .descriptors import ( + AOTInput, + AOTOutput, + DummyAOTInput, + MetadataMutationAOTOutput, + SyntheticBaseAOTInput, + ViewBaseAOTInput, +) +from .functional_utils import gen_alias_from_base +from .graph_capture_wrappers import aot_dispatch_subclass +from .input_output_analysis import ( + compute_overlapping_inputs, + create_synthetic_base_metadata, + remove_dupe_metadata, +) +from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling +from .schemas import ( + AOTConfig, + CompilerWrapper, + FxValue, + InductorWrapper, + InputAliasInfo, + MemoryFormatMeta, + MutationType, + OutputType, + PlainTensorMeta, + SubclassCreationMeta, + SubclassMeta, + TensorAlias, + TraceFn, + ViewAndMutationMeta, +) +from .subclass_utils import ( + requires_subclass_dispatch, + runtime_unwrap_tensor_subclasses, + wrap_tensor_subclasses, +) +from .utils import ( + call_and_expect_output_descs, + call_func_at_runtime_with_args, + make_boxed_func, + partial_flatten_asdict, + simple_wraps, + strict_zip, + without_output_descs, +) + + +zip = strict_zip + + +# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic +# that needs to run after the compiled function. +# +# This function accepts a trace_joint flag, indicating whether or not we're generating the runtime +# epilogue for a forward-only inference graph, or for an autograd.Function.apply function. +# This is because there are some minor differences in how we treat these cases at runtime: +# - resize_() is currently handled in the inference case, but not fully handled in the autograd case. +# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputs +@dataclass +class RuntimeWrapper(CompilerWrapper): + indices_of_inps_to_detach: list[int] + trace_joint: bool + disable_amp: bool + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + return _create_runtime_wrapper( + compiled_fn, + runtime_metadata=runtime_metadata, + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + trace_joint=self.trace_joint, + keep_input_mutations=aot_config.keep_inference_input_mutations, + disable_amp=self.disable_amp, + ) + + +class NoopAliasHandler: + def __init__(self, info, runtime_metadata, trace_joint): + pass + + def __call__(self, orig_inputs, fw_outs, out): + return out + + +def _unwrap_tensoralias(x): + assert isinstance(x, TensorAlias) + return x.alias + + +def _identity(x): + return x + + +class AliasOfInputHandler: + def __init__(self, info, runtime_metadata, trace_joint): + self.base_idx = info.base_idx + self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity + self.requires_grad = info.requires_grad + self.view_meta_sequence = info.view_meta_sequence + self.replay_views = config.view_replay_for_aliased_outputs + + def __call__(self, orig_inputs, fw_outs, out): + aliased_base_tensor = orig_inputs[self.base_idx] + return gen_alias_from_base( + aliased_base_tensor, + self.unwrap_out(out), + self.requires_grad, + self.view_meta_sequence, + replay_views=self.replay_views, + ) + + +class IsInputHandler: + def __init__(self, info, runtime_metadata, trace_joint): + self.base_idx = info.base_idx + self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity + + def __call__(self, orig_inputs, fw_outs, out): + aliased_base_tensor = orig_inputs[self.base_idx] + return aliased_base_tensor + + +class AliasOfIntermediateHandler: + def __init__(self, info, runtime_metadata, trace_joint): + self._unwrap_aliased_base_tensor = _identity + if info.output_type in ( + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + ): + num_user_outputs = len(runtime_metadata.output_info) + self.base_idx = info.base_idx + num_user_outputs + else: + self.base_idx = info.base_idx + if self.base_idx in runtime_metadata.aliased_out_indices: + self._unwrap_aliased_base_tensor = _unwrap_tensoralias + + self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity + self.requires_grad = info.requires_grad + self.view_meta_sequence = info.view_meta_sequence + self.replay_views = config.view_replay_for_aliased_outputs + + def __call__(self, orig_inputs, fw_outs, out): + aliased_base_tensor = fw_outs[self.base_idx] + return gen_alias_from_base( + self._unwrap_aliased_base_tensor(aliased_base_tensor), + self.unwrap_out(out), + self.requires_grad, + self.view_meta_sequence, + replay_views=self.replay_views, + ) + + +_HANDLER_MAP = { + OutputType.non_alias: NoopAliasHandler, + OutputType.unsafe_view_alias: NoopAliasHandler, + OutputType.custom_function_view: NoopAliasHandler, + OutputType.alias_of_input: AliasOfInputHandler, + OutputType.is_input: IsInputHandler, + OutputType.alias_of_intermediate: AliasOfIntermediateHandler, + OutputType.alias_of_intermediate_save_as_output: AliasOfIntermediateHandler, + OutputType.alias_of_intermediate_base_is_user_output: AliasOfIntermediateHandler, +} + + +def make_output_handler(info, runtime_metadata, trace_joint): + handler_type = _HANDLER_MAP[info.output_type] + return handler_type(info, runtime_metadata, trace_joint) + + +# not sure why AOTDispatcher needs to manually set this +def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]): + if hasattr(t, "_dynamo_weak_dynamic_indices"): + # pyrefly: ignore [missing-attribute] + t._dynamo_weak_dynamic_indices |= dims + else: + t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined] + + +def _should_disable_saved_tensors_hooks(): + # Compiled autograd is not supported yet, to be added in future. + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return False + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + hooks = get_hooks() + if are_inline_hooks(hooks): + return True + + return False + + +def _create_runtime_wrapper( + compiled_fn, + *, + runtime_metadata: ViewAndMutationMeta, + indices_of_inps_to_detach: list[int], + trace_joint: bool, + keep_input_mutations: bool, + disable_amp: bool, +): + if not getattr(compiled_fn, "_boxed_call", False): + compiled_fn = make_boxed_func(compiled_fn) + + # Note [Inputs needed in runtime epilogue after list clearing] + # In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to + # wrap the input arguments in a list, and clear the list from within the function. + # Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`. + # + # This is needed for Compiled Autograd since some of the inputs (activations) should be freed early. + # However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs + # **after** the compiled function has finished running. There are two main cases: + # (1) Input mutations: If there are an input mutations that we must run outside of the graph, we need access to the input. + # (2) Output aliasing: Outputs that aliases graph inputs generally must be regenerated outside of the `autograd.Function`, + # and doing so requires us accessing the corresponding input after the compiled artifact has run. + epilogue_args_idx = [] + epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices) + for info in runtime_metadata.output_info: + if ( + info.output_type == OutputType.alias_of_input + or info.output_type == OutputType.is_input + ): + assert isinstance(info.base_idx, int) + epilogue_args_idx.append(info.base_idx) + + if config.unlift_effect_tokens: + assert len(runtime_metadata.tokens) == 0 + + if runtime_metadata.num_outputs_aliased > 0: + output_handlers = tuple( + make_output_handler(info, runtime_metadata, trace_joint) + for info in runtime_metadata.output_info + ) + + def record_runtime_wrapper_prologue_enter() -> Optional[ + AbstractContextManager[None] + ]: + if ( + torch.autograd.profiler._is_profiler_enabled + and dynamo_config.record_runtime_overhead + ): + cm = torch._C._profiler._RecordFunctionFast( + "AOTDispatcher Runtime Wrapper Prologue" + ) + cm.__enter__() + return cm + return None + + def record_runtime_wrapper_prologue_exit( + cm: Optional[AbstractContextManager[None]], + ) -> None: + if cm is not None: + cm.__exit__(None, None, None) + + @simple_wraps(compiled_fn) + def runtime_wrapper(args: list[Any]): + # Create context manager for profiler + cm = record_runtime_wrapper_prologue_enter() + + # stash a ref to each input tensor we plan to use after the compiled function + orig_inputs = {i: args[i] for i in epilogue_args_idx} + + if keep_input_mutations: + mutated_args = ( + args[i] + for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd + ) + torch.autograd.graph.increment_version(mutated_args) + + if trace_joint: + args_ = list(args) + # See Note [Detaching inputs that never need gradients] + for idx in indices_of_inps_to_detach: + if isinstance(args_[idx], torch.Tensor): + args_[idx] = args_[idx].detach() + + # It's possible to have trace_joint inside user specified with no_grad() region, + # if there is a nested with enable_grad(), that forces some outputs to require gradients. + # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. + with ( + torch.autograd._force_original_view_tracking(True), + torch.enable_grad(), + ): + record_runtime_wrapper_prologue_exit(cm) + all_outs = call_func_at_runtime_with_args( + compiled_fn, args_, disable_amp=disable_amp, steal_args=True + ) + else: + # When we have an inference graph, we run with grad disabled. + # It's possible to get an inference graph with inputs that require grad, + # in which case we want to make sure autograd is disabled + # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) + # NOTE: We use _set_grad_enabled directly to reduce runtime overhead + grad_enabled = torch.is_grad_enabled() + try: + if grad_enabled: + torch._C._set_grad_enabled(False) + record_runtime_wrapper_prologue_exit(cm) + all_outs = call_func_at_runtime_with_args( + compiled_fn, args, disable_amp=disable_amp, steal_args=True + ) + finally: + if grad_enabled: + torch._C._set_grad_enabled(True) + del args + + num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices + num_intermediate_bases = runtime_metadata.num_intermediate_bases + + assert ( + len(all_outs) + == num_mutated_runtime_inps + + runtime_metadata.num_outputs + + num_intermediate_bases + ) + + # Step 3: After running the compiled fw, apply updates to mutated inputs + num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices + if num_mutations_to_apply > 0: + updated_inputs = all_outs[:num_mutations_to_apply] + fw_outs = all_outs[num_mutations_to_apply:] + + for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices): + meta = runtime_metadata.input_info[inpt_idx] + if not meta.mutates_data and not meta.mutates_metadata: + continue + original_inpt = orig_inputs[inpt_idx] + updated_inpt = updated_inputs[i] + if meta.mutates_storage_metadata: + # See Note [set_() Input Mutations in AOTAutograd] + # mutates_storage_metadata means our input saw a x.set_(y) call. + # What if x **also** saw a data and/or a metadata mutation? + # (1) If the [meta]data mutation occurred after the set_(), + # then there is no need to copy_() the data. + # When we perform x.set_(x_updated), we are guaranteed that + # x_updated already has the final version of the data/metadata + # (2) If a data mutation occurred before the set_(). + # This case seems very difficult to support. + # TODO: discuss on the PR and decide if we want to tr to + # either support it, or detect and ban it. + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + with torch.no_grad(): + original_inpt.set_(updated_inpt) + continue + if meta.mutates_metadata and not meta.mutates_data: + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + # We need to grab the size/stride/storage_offset from the compiled forward, + # and use that to mutate the metadata of the input + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + if meta.mutates_data and meta.mutates_metadata: + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + assert meta.mutates_data + if meta.is_leaf and original_inpt.requires_grad: + # We can hit this situation in this case: + # def f(x): + # x.detach().mul_(2) + # return x + 1 + # AOTAutograd will see a mutation in the above case, and try to + # apply a copy_() here, in the epilogue. + # But if x required gradients, and is a leaf, then autograd + # will yell at us for trying to mutate it. + # However, it's only possible to end up in this scenario (like the above) + # if all of the mutations to the leaf input were non-autograd-tracking mutations + # (aka mutations under no_grad(), or on detached views). + # In that case, we fully want to hide the mutation from autograd, so detaching is ok. + original_inpt.detach().copy_(updated_inpt) + else: + original_inpt.copy_(updated_inpt) + else: + fw_outs = all_outs + + # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of + # compiling them. + if runtime_metadata.num_outputs_aliased > 0: + # The compiled forward also returned intermediate bases. We don't want to return them to the user. + expect_num_outputs = ( + len(output_handlers) + runtime_metadata.num_intermediate_bases + ) + assert len(fw_outs) == expect_num_outputs + ret_outs = [ + handler(orig_inputs, fw_outs, out) + for out, handler in builtins.zip(fw_outs, output_handlers) + ] + else: + ret_outs = fw_outs + + if runtime_metadata.dynamic_outputs: + for t, o in zip(ret_outs, runtime_metadata.output_info): + if o.dynamic_dims is None: + continue + maybe_mark_dynamic_helper(t, o.dynamic_dims) + if runtime_metadata.grad_enabled_mutation is not None: + torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation) + return ret_outs + + if not (trace_joint and _should_disable_saved_tensors_hooks()): + return runtime_wrapper + + # Disabling saved tensors hooks + @simple_wraps(runtime_wrapper) + def _runtime_wrapper(*args, **kwargs): + with _disable_saved_tensors_hooks(): + return runtime_wrapper(*args, **kwargs) + + return _runtime_wrapper + + +# WARNING: this does NOT operate on TraceFn +@dataclass +class FunctionalizedRngRuntimeWrapper(InductorWrapper): + # TODO: I would love to get rid of this argument, but it's + # Wrapped pretty tightly around our aot_dispatch_autograd logic. + # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices + # for setting placeholder strides(which is done before runtime, before this wrapper runs) + # and for saving tensors for backward (which is done during runtime, after this wrapper runs) + # So in aot_dispatch_autograd, this wrapper can't edit the set of outs without making one + # of those two indices incorrect. + return_new_outs: bool = True + + def pre_compile( + self, + flat_fn: torch.fx.GraphModule, + flat_args, + aot_config, + *, + fw_metadata, + ) -> None: + if config.functionalize_rng_ops: + # Update example inputs for the fw_compiler + fake_mode = detect_fake_mode() + assert fake_mode is not None + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) + flat_args.extend([seed, offset]) + # We are not clearing flat_args here because + # 1) There is a check in the debug compiler at the end + # 2) It does not matter as these are fake tensors + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + @wraps(compiled_fn) + def wrapper(runtime_args: list[Any]): + if runtime_metadata.is_rng_op_functionalized: + # Add the seed and offset to args + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple() + runtime_args.extend([seed, offset]) + out = compiled_fn(runtime_args) + out = self._functionalized_rng_runtime_epilogue( + runtime_metadata, + out, + # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper + runtime_metadata.num_forward_returns, + ) + return out + return compiled_fn(runtime_args) + + return wrapper + + # Calling convention: If we are running functionalized RNG, then outs consists + # of (user_outs, rng_offset) + def _functionalized_rng_runtime_epilogue( + self, + metadata: ViewAndMutationMeta, + outs, + offset_index, + ): + if metadata.is_rng_op_functionalized: + assert metadata.num_outputs_rng_offset == 1 + new_rng_offset = outs[offset_index] + CUDARngStateHelper.set_new_offset(new_rng_offset) + if self.return_new_outs: + user_outs = outs[:offset_index] + outs[offset_index + 1 :] + return user_outs + else: + return outs + + return outs + + +# WARNING: this does NOT operate on TraceFn +@dataclass +class FakifiedOutWrapper(InductorWrapper): + out_metas: list[torch.Tensor] = field(default_factory=list) + # TracingContext.fwd_output_strides + # Generated from actually doing compile + # NB: an entry is None if it's not a Tensor + fwd_output_strides: Optional[list[Optional[list[int]]]] = None + needs_post_compile: bool = True + + def pre_compile( + self, + fw_module: fx.GraphModule, # Must be fw_module from aot_dispatch_*_graph + flat_args, + aot_config, + *, + fw_metadata, + ) -> None: + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context and tracing_context.fakify_first_call: + self.out_metas = [ + n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0]) + ] + else: + self.needs_post_compile = False + + def _compute_output_meta_with_inductor_strides(self): + out = self.out_metas + fwd_output_strides = self.fwd_output_strides + if not fwd_output_strides: + return out + + from torch.fx.experimental.symbolic_shapes import statically_known_true + + for i in range(len(out)): + if not isinstance(out[i], Tensor): + continue + strides = fwd_output_strides[i] + # fwd_output_strides is best effort by Inductor. When an output + # Tensor has unbacked SymInts, Inductor may sometimes be unable + # to compute what the output stride would be. If Inductor doesn't + # have any clear direction on the layout, we don't have to run + # as_strided. To repro without this, run: + # + # python test/distributed/test_dynamo_distributed.py + # TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding + if strides is None: + continue + if all( + statically_known_true(s1 == s2) + for s1, s2 in zip(out[i].stride(), strides) + ): + continue + out[i] = out[i].as_strided(out[i].shape, strides) + return out + + # To be called post compile + def set_fwd_output_strides(self, fwd_output_strides): + self.fwd_output_strides = fwd_output_strides + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if self.needs_post_compile: + assert self.fwd_output_strides is not None + fakified_out = self._compute_output_meta_with_inductor_strides() + + @wraps(compiled_fn) + def wrapper(runtime_args): + nonlocal fakified_out + if fakified_out is not None: + out = fakified_out + fakified_out = None + return out + return compiled_fn(runtime_args) + + return wrapper + # If we don't need to fakify, we can just return the original compiled function + return compiled_fn + + +# This wrapper handles the AOTDispatch runtime logic for tensor subclasses. +# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor, +# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs). +# This function handles the wrapping and unwrapping of tensor subclasses at runtime. +@dataclass +class AOTDispatchSubclassWrapper(CompilerWrapper): + trace_joint: bool + fw_only: Optional[Callable] # Not cached, only used in pre_compile + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + def pre_compile( + self, + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ): + (new_flat_fn, new_flat_args, new_flat_args_descs, subclass_meta) = ( + aot_dispatch_subclass( + flat_fn, + flat_args, + flat_args_descs, + is_joint_structure=self.trace_joint, + meta=fw_metadata, + fw_only=self.fw_only, # type: ignore[arg-type] + ) + ) + self.maybe_subclass_meta = subclass_meta + return new_flat_fn, new_flat_args, new_flat_args_descs, fw_metadata + + def post_compile( + self, + compiled_fn, + _aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if self.maybe_subclass_meta is None: + return compiled_fn + + subclass_metas = runtime_metadata.subclass_fw_graph_out_meta + + @wraps(compiled_fn) + def inner_fn(args: list[Any]): + unwrapped_args = runtime_unwrap_tensor_subclasses( + args, + subclass_metas=runtime_metadata.subclass_inp_meta, + append_symints=True, + ) + args.clear() + # expectation: runtime_fn is a boxed fn + unwrapped_outs = compiled_fn(unwrapped_args) + wrapped_outs = wrap_tensor_subclasses( + unwrapped_outs, + subclass_metas=subclass_metas, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + is_runtime=True, + included_subclass_symints=True, + ) + return wrapped_outs + + # box it + inner_fn._boxed_call = True # type: ignore[attr-defined] + return inner_fn + + +@dataclass +class EffectTokensWrapper(CompilerWrapper): + def post_compile( + self, + compiled_fn, + _aot_config, + *, + runtime_metadata: ViewAndMutationMeta, + ): + num_tokens = len(runtime_metadata.tokens) + + @wraps(compiled_fn) + def inner_fn(args: list[Any]): + if num_tokens > 0: + # Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) + old_args = args + args = [*([None] * num_tokens), *args] + old_args.clear() + + outs = compiled_fn(args) + + # Inductor cache DummyModule can return None + if outs is None: + return None + # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) + return outs[num_tokens:] if num_tokens != 0 else outs + + # box it + inner_fn._boxed_call = True # type: ignore[attr-defined] + return inner_fn + + +# MOTIVATION: +# +# When tracing functions for future execution, one must be careful not to pass +# in the same input tensor multiple times (e.g., f(x, x), as this can result +# in graphs that are ONLY valid if you later pass a new tensor in exactly the +# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct +# tensors that alias each other is a different situation that is covered by +# aot_dispatch_deduplicated_autograd). Here are two examples: +# +# (1) Suppose you have a function: +# +# def f(x, y): +# return x + y +# +# If you make_fx(f)(x, x), you will trace out: +# +# def f(x, y): +# return y + y +# +# Oops! +# +# (2) For most tensors x and y, you can compute f's gradient with respect to +# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However, +# if x is y, you will trace out a program that gets incorrect gradients: +# +# >>> x = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + x, (x, x)) +# (tensor([2.]), tensor([2.])) +# +# In other words, the gradient is double-counted. Deduplicating the arguments +# gives you an appropriate gradient: +# +# >>> y = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + y, (x, y)) +# (tensor([1.]), tensor([1.])) +# +# HOW TO DEDUPLICATE: +# +# There are a few strategies, in order of preference: +# +# 1. For every duplicate argument to the function, detach it into +# a separate leaf tensor, so that it is no longer duplicated. +# +# PRO: The resulting compiled graph works for any configuration +# of duplicated arguments. +# +# CON: It does not (naively) work if you mutate the metadata of inputs: +# +# def f(x, y): +# x.transpose_(0, 1) +# y.transpose_(0, 2) +# +# x = torch.randn(2, 3, 4) +# f(x, x) +# +# The ordering of the transposes inside f dictates whether or not +# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute +# what metadata mutations should get applied to each input; you need to +# assume they aren't duplicates (what we do today) or preserve +# the original metadata mutations exactly in order, so that they work +# for any duplicate configuration. +# +# CON: It does not (naively) work if you mutate the data of inputs. +# In particular, leaf tensors that require grad cannot be mutated, +# this makes it impossible to differentiate with respect to the original +# base. +# +# 2. For every duplicate argument to the function, remove it, so it is +# no longer part of the "true" signature: +# +# PRO: Implemented naively, it still works for metadata/data mutation. +# +# CON: The resulting compiled graph is duplicate-specialized: it only +# works if future calls duplicate arguments in exactly the same way. +# Horribly, Dynamo doesn't guard on this at the moment. But even if +# it did, you could still end up recompiling a bunch of each duplicate. +# +# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if +# Dynamo's guards are not enough. In practice, this seems to cover +# everything. +# +@dataclass +class AOTDedupeWrapper(CompilerWrapper): + keep_arg_mask: list[bool] = field(default_factory=list) + add_dupe_map: list[int] = field(default_factory=list) + old_input_metadata: list[InputAliasInfo] = field(default_factory=list) + needs_post_compile: bool = True + + # NB: Hot path, avoid set lookups here + # TODO: Can avoid the zip here too, probably + def remove_dupe_args(self, args): + return [t for t, keep in zip(args, self.keep_arg_mask) if keep] + + def add_dupe_args(self, args): + return [args[i] for i in self.add_dupe_map] + + def pre_compile( + self, + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[TraceFn, list[FxValue], list[AOTInput], ViewAndMutationMeta]: + # Use information about whether or not flat_fn mutates its arguments + # or not to handle dupe args + + # Strategy 1: For any input that is not mutated, we can leafify it if we + # need to remove a duplicate. + leaf_flat_args: list[FxValue] = [] + leaf_flat_args_descs: list[AOTInput] = [] + args_set = set() + ok = True + + for i, (a, a_desc) in enumerate(zip(flat_args, flat_args_descs)): + if not isinstance(a, torch.Tensor): + leaf_flat_args.append(a) + leaf_flat_args_descs.append(a_desc) + elif a not in args_set: + args_set.add(a) + leaf_flat_args.append(a) + leaf_flat_args_descs.append(a_desc) + elif ( + not fw_metadata.input_info[i].mutates_data + and not fw_metadata.input_info[i].mutates_metadata + ): + leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) + leaf_flat_args_descs.append(a_desc) + else: + ok = False + break + + if ok: + self.needs_post_compile = False + return flat_fn, leaf_flat_args, leaf_flat_args_descs, fw_metadata + + if requires_subclass_dispatch(leaf_flat_args, fw_metadata): + raise RuntimeError( + """\ + Encountered duplicate inputs that are mutated in the graph, but at least one input/output + to the graph is a tensor subclass. This is not supported today. You can try to + remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + # export path: ban duplicate inputs for now, add later if requested. + if aot_config.is_export: + raise RuntimeError( + f"""\ + Encountered duplicated inputs that are mutated in the graph you are trying to export. + This functionality is currently not supported. If needed, please file a github issue. + + fw_metadata={str(fw_metadata)} + """ + ) + + # Strategy 2: Duplicate specialization + # + # When we have duplicate arguments in a function call, we need to handle them specially. + # For example, if we have a function call f(a, b, a, c), we need to: + # + # 1. Remove duplicates to get a deduplicated list [a, b, c] + # 2. Compile our function to work with this deduplicated list + # 3. At runtime, convert incoming arguments with duplicates to the deduplicated form + # 4. Pass the deduplicated arguments to our compiled function + # + # To do this, we need two helper functions: + # + # - remove_dupe_args: Converts [a, b, a, c] -> [a, b, c] + # - add_dupe_args: Converts [a, b, c] -> [a, b, a, c] + # + # For our example [a, b, a, c], we track: + # + # - seen_args = {a: 0, b: 1, c: 2} (maps each unique arg to its first position) + # - add_dupe_map = [0, 1, 0, 2] (tells us how to reconstruct the original list) + # - keep_arg_mask = [True, True, False, True] (tells us which args to keep when deduplicating) + + seen_args: dict[Tensor, int] = {} + # Implicitly map duped arg position (list index) to de-duped arg position + keep_arg_mask: list[bool] = [] + add_dupe_map: list[int] = [] + duped_arg_len = len(flat_args) + + j = 0 # index into deduped_flat_args + for t in flat_args: + if isinstance(t, torch.Tensor): + if t in seen_args: + keep_arg_mask.append(False) + add_dupe_map.append(seen_args[t]) + continue + seen_args[t] = j + + keep_arg_mask.append(True) + add_dupe_map.append(j) + j += 1 + assert len(add_dupe_map) == duped_arg_len, ( + f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" + ) + + self.keep_arg_mask = keep_arg_mask + self.add_dupe_map = add_dupe_map + + deduped_flat_args = self.remove_dupe_args(flat_args) + # TODO: instead of arbitrarily removing args, it might be useful to + # have a record that these were duped, perhaps as a mutable attribute + # on the kept arg? Do this if someone needs it + deduped_flat_args_descs = self.remove_dupe_args(flat_args_descs) + + # Update our input metadata to remove duped input metadata. + updated_fw_metadata = remove_dupe_metadata( + fw_metadata, keep_arg_mask, add_dupe_map + ) + + if ( + tracing_context := TracingContext.try_get() + and aot_config.aot_autograd_arg_pos_to_source + ): + # TODO(voz): This structure is 1:1, we could consider an alternate structure like + # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there, + # which feels like needless complexity for a tiny bit of efficiency at this point. + for dupe_arg_pos, (kept_pos, keep_arg) in enumerate( + zip(add_dupe_map, keep_arg_mask) + ): + if not keep_arg: + dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[ + dupe_arg_pos + ] + kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[ + kept_pos + ] + tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined] + DuplicateInputs(kept_arg_source, dupe_arg_source) + ) + + @simple_wraps(flat_fn) + def wrapped_flat_fn( + *args: FxValue, + ) -> tuple[list[FxValue], list[AOTOutput]]: + outs, out_descs = call_and_expect_output_descs( + flat_fn, self.add_dupe_args(args) + ) + return outs, out_descs + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + without_output_descs(wrapped_flat_fn), + flat_args_descs=deduped_flat_args_descs, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=fw_metadata.keep_input_mutations, + is_train=fw_metadata.is_train, + )(*deduped_flat_args) + assert ref_fw_metadata == updated_fw_metadata, ( + f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" + ) + + return ( + wrapped_flat_fn, + deduped_flat_args, + deduped_flat_args_descs, + updated_fw_metadata, + ) + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if not self.needs_post_compile: + return compiled_fn + + @wraps(compiled_fn) + def wrapped_compiled_fn(args: list[Any]): + deduped_args = self.remove_dupe_args(args) + args.clear() + return compiled_fn(deduped_args) + + wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined] + + # This can be uncommented when we properly guard for duplicates, + # but right now we must not do it. + # if not config.debug_assert: + # return wrapped_compiled_fn + + @wraps(wrapped_compiled_fn) + def debugged_compiled_fn(args): + # Test that the computed remove/add arg functions are an inverse + new_args = self.add_dupe_args(self.remove_dupe_args(args)) + seen: dict[Any, None] = {} + for i, (x, y) in enumerate(zip(new_args, args)): + seen[y] = None + assert x is y, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would be a duplicate of " + f"{describe_input(self.add_dupe_map[i], aot_config)}", + ) + # This is only an error if there is metadata mutation on both of + # the duped arguments; in this case, we need to know what order + # the metadata mutation applies in. You'll get the correct result + # otherwise, because a graph that assumes distinct inputs works if + # you dupe the inputs (the gradient contributions from each input + # will get summed up appropriately.) + # + # TODO: work out how to setup this assert correctly + """ + assert len(seen) == unique_args, format_guard_bug_msg(aot_config, + f"there would be {unique_args} distinct arguments" + ) + """ + return wrapped_compiled_fn(args) + + debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined] + + return debugged_compiled_fn + + +# This layer handles the situation where you have two inputs that alias each other, +# and one of the inputs is mutated. +# We need to take special care to ensure that the mutation is applied to the other aliases in the graph. +# +# pre-condition: AOTDedupWrapper has already run. +# (This function will in theory work if there are duplicate args. +# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs +# would cause us to hit that path more frequently). +@dataclass +class AOTSyntheticBaseWrapper(CompilerWrapper): + # Currently, the only reason we need to plumb this bool is because + # the synthetic base code prohibits more cases in the autograd case than the inference case. + trace_joint: bool # TODO: refactor trace_joint + needs_post_compile: bool = True + aliased_arg_idx_with_metadata_mutations: list[int] = field(default_factory=list) + + def pre_compile( + self, + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[FxValue], list[AOTInput], ViewAndMutationMeta]: + is_inference = not self.trace_joint + ( + flat_args_with_synthetic_bases, + flat_args_descs_with_synthetic_bases, + synthetic_base_info, + ) = merge_view_inputs( + aot_config, + flat_args, + flat_args_descs, + fw_metadata.input_info, + is_inference=is_inference, + ) + + # Happy path: we don't need synthetic bases + if synthetic_base_info is None: + self.needs_post_compile = False + return flat_fn, flat_args, flat_args_descs, fw_metadata + + # export path: ban synthetic bases for now, add later if requested. + if requires_subclass_dispatch(flat_args, fw_metadata): + raise RuntimeError( + """\ + Encountered aliased inputs that are mutated in the graph, but at least one input/output + to the graph is a tensor subclass. This is not supported today. You can try to + remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + if aot_config.is_export: + raise RuntimeError( + f"""\ + Encountered aliased inputs that are mutated in the graph you are trying to export. + This functionality is currently not supported. If needed, please file a github issue. + + synthetic_base_info={str(synthetic_base_info)} + + fw_metadata={str(fw_metadata)} + """ + ) + + assert len(fw_metadata.input_info) == len(synthetic_base_info) + + # Update our forward metadata to take synthetic bases into account + ( + fw_metadata_updated, + aliased_arg_idx_with_metadata_mutations, + ) = create_synthetic_base_metadata( + fw_metadata, + synthetic_base_info, + flat_args, + flat_args_with_synthetic_bases, + flat_args_descs_with_synthetic_bases, + ) + # Save old input args for post-compile + self.old_input_info = fw_metadata.input_info + + self.aliased_arg_idx_with_metadata_mutations = ( + aliased_arg_idx_with_metadata_mutations + ) + replay_views = config.view_replay_for_aliased_outputs + + def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]: + f_args_inner = [] + # pyrefly: ignore [not-iterable] + for inner_idx_or_tuple in synthetic_base_info: + if isinstance(inner_idx_or_tuple, int): + f_args_inner.append(primals[inner_idx_or_tuple]) + else: + inner_base_idx, view_tensor = inner_idx_or_tuple + base = primals[inner_base_idx] + view_arg = gen_alias_from_base( + base, + view_tensor, + view_tensor.requires_grad, + replay_views=replay_views, + ) + f_args_inner.append(view_arg) + return f_args_inner + + @simple_wraps(flat_fn) + def wrapped_flat_fn(*args): + unpacked_args = _unpack_synthetic_bases(args) + # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases) + # is to relieve the downstream logic from having to reason about mutations on inputs that alias + # each other, by replacing aliased inputs with a synthetic base. + # One area where this breaks down a bit however is if one of those aliased inputs + # experienced a metadata mutation. + # We are now obligated to reapply the metadata mutation directly to the user's input; + # it isn't enough to apply mutations back to the synthetic base in the downstream logic. + # + # The way we handle this is by pretending that those aliased inputs that experience metadata mutations + # are additional outputs in the user's forward function. + # The downstream logic will just treat these as "user outputs that alias inputs". + # However, we will manually grab them at runtime here, use them to reapply the metadata mutation + # to the user inputs, and not return them to the user. + aliased_args_with_metadata_mutations = [ + x + for i, x in enumerate(unpacked_args) + if i in self.aliased_arg_idx_with_metadata_mutations + ] + out, out_descs = call_and_expect_output_descs(flat_fn, unpacked_args) + if len(aliased_args_with_metadata_mutations) > 0: + # TODO: record more detailed desc information here + return (*out, *aliased_args_with_metadata_mutations), ( + *out_descs, + *( + [ + MetadataMutationAOTOutput(i) + for i in range( + len(self.aliased_arg_idx_with_metadata_mutations) + ) + ] + ), + ) + else: + return out, out_descs + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + without_output_descs(wrapped_flat_fn), + flat_args_descs=flat_args_descs_with_synthetic_bases, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=fw_metadata.keep_input_mutations, + is_train=fw_metadata.is_train, + )(*flat_args_with_synthetic_bases) + assert ref_fw_metadata == fw_metadata_updated, ( + f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, " + f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}" + ) + return ( + wrapped_flat_fn, + flat_args_with_synthetic_bases, + flat_args_descs_with_synthetic_bases, + fw_metadata_updated, + ) + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if not self.needs_post_compile: + return compiled_fn + + is_inference = not self.trace_joint + + @wraps(compiled_fn) + def wrapped_compiled_fn(args): + # TODO: this sure seems expensive to run at runtime (which + # post_compile seems to imply it does?!) + args_with_synthetic_bases, _, synthetic_base_info = merge_view_inputs( + aot_config, args, None, self.old_input_info, is_inference=is_inference + ) + assert synthetic_base_info is not None + aliased_args_w_metadata_mutations = [ + args[i] for i in self.aliased_arg_idx_with_metadata_mutations + ] + num_aliased_args_with_metadata_mutations = len( + aliased_args_w_metadata_mutations + ) + args.clear() + outs = compiled_fn(args_with_synthetic_bases) + if num_aliased_args_with_metadata_mutations > 0: + # This code does not handle **all** input metadata mutations. + # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases + # (which only happens if at least one aliased input experienced a data mutation). + # e.g: + # def f(a, b): + # a.mul_(2) + # b.t_(1, 0) + # f(x.view(2, 2), x.view(2, 2)) + mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:] + user_outs = outs[:-num_aliased_args_with_metadata_mutations] + for inp, mutated_inp in zip( + aliased_args_w_metadata_mutations, mutated_metadata_inps + ): + inp.as_strided_( + mutated_inp.size(), + mutated_inp.stride(), + mutated_inp.storage_offset(), + ) + return user_outs + return outs + + return wrapped_compiled_fn + + +# Note [Handling mutations on an input that aliases other inputs] +# The easiest example to show-case this edge case is here: +# +# def f(a, b): +# a.mul_(2) +# out = a + b +# return out +# b = torch.ones(...) +# a = b.view(-1) +# f(a, b) +# +# In this situation, if a and b happened to be aliased, we need to trace something different! +# Suppose we had b = a.view(-1) +# (In this case, that means that `a._base is b`) +# +# We need to ensure that the aliasing relationship between a and b is preserved. +# We do that detecting the specific situation above (mutate an input that aliases another input), +# and when we do that, we create a synthetic base argument. Then inside of the traced forward, +# we regenerate a and b off of that base. +# The complete example of the transformed function looks like this: +# +# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views +# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph +# def traced_forward(base): +# a = base.as_strided(...) +# b = base.as_strided(...) +# a_updated = a.mul(2) +# base_updated = torch.as_strided_scatter(base, a_updated, ...) +# b_updated = base_updated.as_strided(...) +# out = a_updated + b_updated +# return a_updated, out +# +# def compiled_fn(a, b): +# // we detect that a is the "differentiable base" here +# base = a +# // In other situations, we might do either: +# // (1) a and b are both views off of some larger differentiable base +# // assert a._base is b._base and a._base is not None +# // base = a._base +# // (2) a and b both don't require gradients. Create a base from the storage +# // assert a._base is None and b._base is None +# // base = torch.Tensor(a.storage()) +# a_updated, out = traced_forward(base) +# a.copy_(a_updated) +# return out +# +# This function: +# (1) Merges input views into a synthetic base argument, when any of those input views are mutated +# (2) Returns metadata telling the autograd.Function how to modify their arguments properly, +# to respect the new calling convention. +# +# The calling convention is as follows. +# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base. +# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN], +# Where the ordering of the bases is determined from the ordering of the original view args. +# baseA will come before baseB if the earliest original argument coming from baseA +# showed up earlier in the argument list than the earliest original argument coming from baseB. +# +# Example, given some tensors a, b, c, d +# call site: +# f(a, c.view(-1), b.view(-1), b, c, d) +# Modified argument list: +# c_base comes first because the first c view came earlier in arg list than the first b view +# a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases +# b_base = torch.Tensor(b.storage()) +# c_base = torch.Tensor(c.storage()) +# f(c_base, b_base, a, d) +def merge_view_inputs( + aot_config: AOTConfig, + fwd_inputs: list[Any], + # This is None when called at runtime from post_compile closure + fwd_inputs_descs: Optional[list[AOTInput]], + mutated_input_info: list[InputAliasInfo], + *, + # The autograd case currently has more restrictions than the inference case. + is_inference: bool, +) -> tuple[ + list[Any], list[AOTInput], Optional[list[Union[int, tuple[int, torch.Tensor]]]] +]: + if fwd_inputs_descs is None: + fwd_inputs_descs = [DummyAOTInput(i) for i in range(len(fwd_inputs))] + + def _are_differentiable_views(view1, view2): + if view1 is view2: + return True + if view1._base is None and view2._base is None: + return False + if view1._base is view2._base or view1._base is view2 or view1 is view2._base: + return True + return False + + def _same_dtype_views(view1, view2): + if view1.dtype != view2.dtype: + return False + if view1._base is not None and view1.dtype != view1._base.dtype: + return False + if view2._base is not None and view2.dtype != view2._base.dtype: + return False + return True + + assert len(fwd_inputs) == len(mutated_input_info) + if not [info for info in mutated_input_info if info.mutates_data]: + # Return early when there are no mutations. + return fwd_inputs, fwd_inputs_descs, None + + storage_ref_to_idx: dict[StorageWeakRef, list[int]] = collections.defaultdict(list) + base_args = [] + other_args = [] + base_args_descs = [] + other_args_descs = [] + for i, (inpt, source) in enumerate(zip(fwd_inputs, fwd_inputs_descs)): + if isinstance(inpt, Tensor): + storage_ref = StorageWeakRef(inpt.untyped_storage()) + storage_ref_to_idx[storage_ref].append(i) + else: + other_args.append(inpt) + other_args_descs.append(source) + # Note [Synthetic Base Info Metadata] + # This list contains metadata that tells you what the i'th argument in the inner calling convention should be. + # It's either: + # - another int (corresponding to the index in the argument list of the element from the outer calling convention) + # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx]) + # idx corresponds to which synthetic base from the outer calling context to view + inner_calling_convention_meta: dict[int, Union[int, tuple[int, torch.Tensor]]] = {} + for aliased_input_indices in storage_ref_to_idx.values(): + if len(aliased_input_indices) <= 1 or not any( + # We only care about mutations that affect all aliases, + # so metadata mutations on an input doesn't require us to do synthetic base handling. + mutated_input_info[inpt_idx].mutates_data + for inpt_idx in aliased_input_indices + ): + other_args.extend( + fwd_inputs[curr_idx] for curr_idx in aliased_input_indices + ) + other_args_descs.extend( + fwd_inputs_descs[curr_idx] for curr_idx in aliased_input_indices + ) + continue + + # Here, we attempt to do a more complicated check to detect false aliasing + # (e.g. if all the tensors have the same storage, but don't actually overlap) + # In theory, we could have a large group of tensors that all share storages, where only *some* of them + # have overlapping memory. + # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair + # of tensors in the current group that shares a storage is non-overlapping. + aliased_input_indices_no_false_sharing = compute_overlapping_inputs( + aot_config, fwd_inputs, aliased_input_indices + ) + if len(aliased_input_indices_no_false_sharing) <= 1: + other_args.extend( + fwd_inputs[curr_idx] for curr_idx in aliased_input_indices + ) + other_args_descs.extend( + fwd_inputs_descs[curr_idx] for curr_idx in aliased_input_indices + ) + continue + + # We detected an input that was mutated, AND aliases with another input. + # we need to replace this set of aliased inputs with a single synthetic base. + # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases + # and error out. We can fix them later. + # These checks are transitive, so we don't need to check every pair. + for idx1, idx2 in zip( + aliased_input_indices, aliased_input_indices[1:], strict=False + ): + view1 = fwd_inputs[idx1] + view2 = fwd_inputs[idx2] + # The "inputs that are aliased but have different differentiable bases" case + # is more complicated and hopefully pretty rare. Not currently handled. + if not is_inference: + assert _are_differentiable_views(view1, view2), ( + "aot_autograd() does not yet handle non-differentiable view input mutations." + ) + # Regenerating views when reinterpreting complex / real tensors seems non-trivial, + # not handling for now + assert _same_dtype_views(view1, view2), ( + "aot_autograd() does not yet handle input mutations on views with different dtypes." + ) + non_none_bases = [ + (i, fwd_inputs[i]._base) + for i in aliased_input_indices + if fwd_inputs[i]._base is not None + ] + aliases_with_none_bases = [ + fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None + ] + synthetic_base_desc: AOTInput + if len(non_none_bases) == 0: + # Case where none of the aliases have a ._base + # we generate a synthetic base without gradients, and generate views off of it + # We hit this case when we have input tensors to the graph that share a storage, + # but do not have a ._base field. + # Wondering when we hit this case? + # The _base field simply says that autograd knows about the aliasing relationship, + # but sometimes we create tensors which are aliased out of the same storage but guaranteed + # to be disjoint. In these cases, we will skip setting up the _base relationship + # for performance reasons (because the fact that the tensors share the same storage + # is unobservable unless you (1) do naughty things with resize_/as_strided + # or (2) look at the storage--as we are doing here.) + # One particular example of this is optimizer steps on the LSTM module: + # LSTM parameters are packed into a contiguous storage for efficiency reasons when + # calling cuDNN kernels, so when these parameters get passed to the optimizer we will + # find they share the same storage, but do not have _base set since they are all disjoint. + # + # NOTE: There is one case where this is unsafe: + # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily + # the same shape as the "actual" base that the tensor came from. + # For the most part this is fine, because we always use as_strided() + # to generate the original aliased inputs again. + # If we were to use view-replay though, this could cause the aliased views + # to have incorrect sizes. + example_idx = aliased_input_indices[0] + example_alias = fwd_inputs[example_idx] + # Note that this function is reused at both trace time and runtime. + # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor. + synthetic_base = torch.empty( + (0,), dtype=example_alias.dtype, device=example_alias.device + ) + # We don't actually have a convenient way of going from storage -> tensor, + # So using set_() here (we suffer some minor overhead, but this case is rare). + synthetic_base.set_(example_alias.untyped_storage()) + synthetic_base_desc = SyntheticBaseAOTInput(fwd_inputs_descs[example_idx]) + else: + # Case where all of the aliases require gradients, and have the same _base. + i, synthetic_base = non_none_bases[0] + synthetic_base_desc = ViewBaseAOTInput(fwd_inputs_descs[i]) + for _, other_base in non_none_bases[1:]: + assert other_base is synthetic_base, ( + "aot_autograd() does not yet handle non-differentiable view input mutations." + ) + for alias in aliases_with_none_bases: + assert alias is synthetic_base, ( + "aot_autograd() does not yet handle non-differentiable view input mutations." + ) + base_args.append(synthetic_base) + base_args_descs.append(synthetic_base_desc) + for curr_view_idx in aliased_input_indices: + curr_view = fwd_inputs[curr_view_idx] + base_idx = len(base_args) - 1 + # We store just enough info here so that we can regenerate the view later. + # Regeneration: curr_view._view_func(args[base_idx]) + inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view) + if len(base_args) == 0: + assert len(other_args) == len(fwd_inputs) + # If no synthetic bases are necessary, just return the original inputs. + return fwd_inputs, fwd_inputs_descs, None + else: + from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr + + def make_hashable(arg): + if isinstance(arg, torch.SymInt): + # Since only nested SymInt objects can be hashed, we wrap them with + # SymIntEqByExpr, which is a hashable wrapper of SymInts. + return SymIntEqByExpr(arg) + return arg + + # Otherwise, return: + # (1) The new args according to the updated calling convention: (synthetic_bases, other_args) + # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention. + # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention. + args_to_functionalization = base_args + other_args + args_to_functionalization_descs = base_args_descs + other_args_descs + + # Map each argument into its old index. + # There may be some repeated arguments, so we collect their indices in a list. + arg_to_old_idx_map = collections.defaultdict(list) + for i, arg in enumerate(fwd_inputs): + arg_to_old_idx_map[make_hashable(arg)].append(i) + # Reverse the list of each argument, so that we can easily pop them one-after-the-other in order. + for hashable_arg in arg_to_old_idx_map: + arg_to_old_idx_map[hashable_arg] = list( + reversed(arg_to_old_idx_map[hashable_arg]) + ) + + for i, other_arg in enumerate(other_args): + new_idx = len(base_args) + i + old_idx = arg_to_old_idx_map[make_hashable(other_arg)].pop() + inner_calling_convention_meta[old_idx] = new_idx + + # post process into a list + post_processed_calling_convention_meta: list[ + Union[int, tuple[int, torch.Tensor]] + ] = [-1 for _ in range(len(inner_calling_convention_meta))] + for k, v in inner_calling_convention_meta.items(): + post_processed_calling_convention_meta[k] = v + # Quick assert: every argument in the inner calling convention should be accounted for. + for x in post_processed_calling_convention_meta: + assert x != -1 + return ( + args_to_functionalization, + args_to_functionalization_descs, + post_processed_calling_convention_meta, + ) + + +# Note: [Backward graph lazy lowering] +# After AOTDispatch traces the backward for graphs requiring autograd, we will lower the graph lazily, +# unless we suspect that inductor might specialize and insert additional guards. When we do lazy +# lowering, we stash the AOT backward graph (bw_module) in this class. +# +# Lowering passes are performed on a deepcopy of this bw_module due to compatibility +# with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645. +@dataclass +class AutogradLazyBackwardCompileInfo: + bw_module: Callable + placeholder_list: list[Any] + saved_context: Optional[TracingContext] + saved_compile_context: Optional[CompileContext] + + +# On an AOT Autograd cache hit, we already have a lowered backward, so there is usually +# no need to keep information around for a new lazy compilation. Except for compiled autograd, +# which wants to retrace this backward into a larger graph, and it needs the graph module to do so. +@dataclass +class CachedAutogradLazyBackwardCompileInfo: + bw_module_fn: Callable + + +def _raise_if_functorch_active(): + # not ideal but prevent the user from seeing a nasty traceback - See #138422 + stack = torch._C._functorch.peek_interpreter_stack() + torch._check( + stack is None, + lambda: ( + "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, " + "which isn't supported. Try wrapping vmap inside torch.compile, or skip compiling the " + "backward function." + ), + ) + + +# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. +def _backward_prologue_functional( + ctx_saved_tensors, ctx_symints, metadata, maybe_subclass_metadata, *flat_args +): + # Calling convention: we expect a grad_out passed to the backward: + # - for every output of the fw that does *not* alias an input or graph intermediate + # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations) + # - for every graph intermediate that we need to use to generate an output later. + # The other outputs in the autograd.Function.forward that do *not* show up in the backward include: + # - outputs that alias inputs or graph intermediates + # - updated inputs due to metadata-only mutations. + # We need to return them in the forward, but ensure that they all do not get gradients in the backward, + # and we filter them out here before passing the remaining grad_outputs into the compiled backward. + _raise_if_functorch_active() + + num_intermediate_bases = metadata.num_intermediate_bases + num_mutated_runtime_inps = metadata.num_mutated_inp_runtime_indices + expected_grad_outs = ( + metadata.num_outputs + num_mutated_runtime_inps + num_intermediate_bases + ) + deterministic = metadata.deterministic + global_deterministic = torch.are_deterministic_algorithms_enabled() + if deterministic is not None: + torch._check( + not (not deterministic and global_deterministic), + lambda: ( + "This compiled backward function is being run with " + "torch.use_deterministic_algorithms(True), " + "but it was previously generated during the forward function while " + "torch.use_deterministic_algorithms(False) was set." + ), + ) + + assert len(flat_args) == expected_grad_outs + out_info = metadata.output_info + + inp_tangents, out_tangents, intermediate_base_tangents = ( + flat_args[:num_mutated_runtime_inps], + flat_args[ + num_mutated_runtime_inps : num_mutated_runtime_inps + metadata.num_outputs + ], + flat_args[num_mutated_runtime_inps + metadata.num_outputs :], + ) + # input_info contains info on *every* input, + # But in the backward(), we are only given grad outputs for every mutated input + # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad + input_info = metadata.input_info + inp_tangents_filtered = [ + x + for x, info_idx in zip( + inp_tangents, + metadata.mutated_inp_runtime_indices, + ) + if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad + ] + # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates + out_tangents_filtered = [ + x + for x, info in zip(out_tangents, out_info) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + # intermediate bases always require gradients, and always participate in the backward graph. + flat_bw_args_with_grads = [ + *inp_tangents_filtered, + *out_tangents_filtered, + *intermediate_base_tangents, + ] + num_flat_bw_args_with_grads = len(flat_bw_args_with_grads) + + # sanity asserts + # metadata_only_inps = [ + # x for x, info_idx in zip(inp_tangents, mutated_inp_indices) + # if not input_info[info_idx].mutates_data + # ] + # aliased_outputs = [ + # x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias] + # assert all(x is None for x in metadata_only_inps) + # assert all(x is None for x in aliased_outputs) + # TODO: replace this with FunctionalizedRngRuntimeWrapper + rng_args = [] + if metadata.is_rng_op_functionalized: + # Add the seed and offset to args + rng_args = CUDARngStateHelper.get_torch_state_as_tuple() + + bw_tokens = [None] * metadata.num_backward_tokens + + # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first + # in the bw output order. + + # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls + # There are tests that count these calls, saving to var. + num_ctx_saved_tensors = len(ctx_saved_tensors) + all_args = [ + *ctx_symints, + *ctx_saved_tensors, + *flat_bw_args_with_grads, + *bw_tokens, + *rng_args, + ] + del ctx_saved_tensors + + # Note: [AOTAutograd Backward Guards] + # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph. + # Doing so requires us to "guess" about some of the metadata of our grad_outputs. + # + # In particular: if an output to the forward is a plain tensor or a subclass, + # its corresponding grad_output in the backward **may or may not** be + # a plain tensor or a subclass. The main cases are: + # (1) If an output is a plain tensor, its grad_out will also be a plain tensor, + # *unless* the output is used in some subclass compute later in the forward graph, + # which will cause its grad_output to become a subclass + # (2) If an output is a subclass, its grad_out will also be a subclass, + # *unless* the output of the forward did not actually participate in the gradient computation, + # in which case autograd will insert a plain tensor of zeros for the grad_output. + # We could avoid this case with `torch.autograd.Function.set_materialize_grads`, + # although this is not turned on today in AOTAutgrad and would require more work. + # + # Today, we make a guess on subclass-ness based on the above examples, + # and hard-error in the backward if we guessed wrong. + # + # In the future, we should add backward guards that would allow us to + # properly handle this case instead of erroring: we would need to retrace the backward graph, + # since we might produce an entirely different trace if our grad_outputs are subclass or not. + del flat_bw_args_with_grads + + tangents_start_idx = ( + len(all_args) - num_flat_bw_args_with_grads - len(rng_args) - len(bw_tokens) + ) + assert tangents_start_idx == len(ctx_symints) + num_ctx_saved_tensors + tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens) + + # TODO: figure out how to refactor the backward properly + # so I can use aot_dispatch_subclass_wrapper() here. + if maybe_subclass_metadata is not None: + tangents = all_args[tangents_start_idx:tangents_end_idx] + + if len(tangents) != len(metadata.subclass_tangent_meta): + raise RuntimeError( + "The grad inputs should be same number as forward output tangents" + ) + + flat_processed_tangents = list( + itertools.chain.from_iterable( + ( + AOTDispatchAutograd.process_runtime_tangent( + t, + m, + )[1] + ) + for t, m in zip( + tangents, + metadata.subclass_tangent_meta, + ) + ) + ) + + all_args = ( + runtime_unwrap_tensor_subclasses( + all_args[:tangents_start_idx], + # SymInts that are inputs to the backward graph are + # already included in the "all_args" list. + # Any symints coming from tensor subclasses should always + # come from primals, and so they will show up as extra + # arguments to the forward graph, and they will be saved + # as activation in the backward graph. + append_symints=False, + ) + + flat_processed_tangents + + runtime_unwrap_tensor_subclasses( + all_args[tangents_end_idx:], + append_symints=False, + ) + ) + else: + all_args = [ + ( + AOTDispatchAutograd.process_runtime_tangent( + t, + metadata.subclass_tangent_meta[i - tangents_start_idx], + )[0] + if (tangents_start_idx <= i < tangents_end_idx) + else t + ) + for i, t in enumerate(all_args) + ] + + # Backward with forward inputs mutations is not supported in double backward. + if ( + torch.is_grad_enabled() + and metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw + ): + raise RuntimeError( + "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True" + ) + + return all_args + + +def initialize_rng_states( + num_rng: int, + graphsafe_idx: int, + fwd_rng_states: list[torch.Generator], + bwd_rng_states: list[torch.Generator], +): + """ + Initialize the cudagraph safe rng states. + + Initialization of rng states should have a few properties: + - the initialization for each rng state should be independent + - the initialization should be deterministic + - the initialization should be based off current rng state, so that independent graphs do not + have equal rng behavior + + We defer initialization of rng states until runtime because compilation is wrapped + with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations + do not give equal randomness. + """ + with torch.utils._python_dispatch._disable_current_modes(): + seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu") + fwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + bwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + + +# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. +def _backward_epilogue_functional( + metadata, maybe_subclass_metadata, out, *, make_subclass_override=None +): + # Toss out the backward output tokens + num_bw_tokens = metadata.num_backward_tokens + if num_bw_tokens > 0: + out = out[:-num_bw_tokens] + + # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile + out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue( + metadata, out, offset_index=len(out) - 1 + ) + out = tuple(out) + + # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. + if maybe_subclass_metadata is not None: + assert maybe_subclass_metadata.grad_input_metas is not None + outs_wrapped = wrap_tensor_subclasses( + out, + subclass_metas=maybe_subclass_metadata.grad_input_metas, + included_subclass_symints=True, + is_runtime=True, + make_subclass_override=make_subclass_override, + ) + return outs_wrapped + return out + + +def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryFormatMeta): + if memory_format.memory_format is not None: + # Coerce to torch.memory_format + if not x.is_contiguous(memory_format=memory_format.memory_format): + x = x.contiguous(memory_format=memory_format.memory_format) + return x + + expected_size = memory_format.size + assert expected_size is not None + expected_stride = memory_format.stride + assert expected_stride is not None + # Expected size and stride are static ints + # ok to use == to compare runtime tensor strides and shapes + + if x.shape == expected_size and x.stride() == expected_stride: + # Runtime tangent size and stride are the same as expected, no need to coerce + return x + + # Empty_strided creates a raw Tensor. + # We are guaranteed that only raw Tensors has expected size and stride. + # Subclasses have only expected memory_format. + restrided = torch.empty_strided( + size=expected_size, + stride=expected_stride, + dtype=x.dtype, + device=x.device, + layout=x.layout, + requires_grad=x.requires_grad, + ) + restrided.copy_(x) + return restrided + + +@contextlib.contextmanager +def _disable_saved_tensors_hooks(): + error_message = ( + "Saved tensors hooks were specialized as GraphModules." + "In this case aot_autograd inlines them in forward and backward graph " + "and disables them during runtime of aot_autograd compiled region." + "If you see this error, that means that there is some unexpected push or pop manipulation " + "during aot_autograd compiled region runtime." + "Compilation with different hooks must result in recompilation." + ) + fail_if_non_empty = False + maybe_prev_message = None + try: + maybe_prev_message = ( + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ) + torch._C._autograd._saved_tensors_hooks_disable( + error_message, fail_if_non_empty + ) + yield + finally: + if maybe_prev_message is None: + torch._C._autograd._saved_tensors_hooks_enable() + else: + torch._C._autograd._saved_tensors_hooks_disable( + maybe_prev_message, fail_if_non_empty + ) + + +@dataclass +class SerializableCompiledFunction: + """ + Represents a result of AOTDispatch after calling the inner compiler + that can be serialized + """ + + compiled_fn: Callable + serialize_fn: Callable + + def __init__(self, compiled_fn: Callable, serialize_fn: Callable): + self.compiled_fn = compiled_fn + self.serialize_fn = serialize_fn + # Equivalent to functools.wraps + functools.update_wrapper( + self, + compiled_fn, + assigned=("__doc__", "__annotations__", "__type_params__"), + ) + + def serialize(self) -> Any: + return self.serialize_fn() + + def __call__(self, *args, **kwargs): + return self.compiled_fn(*args, **kwargs) + + +# This is wrapped in a class just for namespacing purposes +# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly +class AOTDispatchAutograd: + @staticmethod + def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta]): + if not isinstance(x, torch.Tensor): + return x, [x] + + if isinstance(x, FakeTensor): + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) + return x, [x] + + expected_type: Optional[type] = torch.Tensor + expected_meta = None + if isinstance(meta, SubclassCreationMeta): + expected_type = meta.original_subclass_type + expected_meta = meta.meta + + runtime_type = type(x) + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. + if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor + + runtime_meta = None + runtime_subclass_keys: Sequence[str] = [] + + if is_traceable_wrapper_subclass(x): + runtime_subclass_keys, runtime_meta = x.__tensor_flatten__() + + def maybe_coerce(x): + same_type: bool = expected_type == runtime_type + same_meta: bool = expected_meta == runtime_meta + + if same_type and same_meta: + return x + + if not hasattr(x, "__coerce_same_metadata_as_tangent__"): + return None + + if same_type: + # Backward Compatibility, as some Subclass impls can have original 1-arg function. + return x.__coerce_same_metadata_as_tangent__(expected_meta) + + return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type) + + # Coerce to expected type and metadata + orig_x = x + x = maybe_coerce(x) + if x is None: + raise RuntimeError( + f""" +During the backward, we encountered a tensor subclass where we guessed its +metadata incorrectly. + +Expected metadata: {str(expected_meta)}, expected type: {str(expected_type)} + +Runtime metadata: {str(runtime_meta)}, runtime type: {str(runtime_type)} + +shape: {str(orig_x.shape)} +To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__. +""" + ) + + # Coerce to expected memory format + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) + + if not is_traceable_wrapper_subclass(x): + return x, [x] + + assert isinstance(meta, SubclassCreationMeta) + if orig_x is not x: + runtime_subclass_keys = x.__tensor_flatten__()[0] + + assert len(meta.attrs) == len(runtime_subclass_keys) + leaves = [] + for attr, attr_meta in meta.attrs.items(): + elem = getattr(x, attr) + new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent( + elem, attr_meta + ) + if new_elem is not elem: + setattr(x, attr, new_elem) + leaves.extend(elem_leaves) + + return x, leaves + + @staticmethod + def post_compile( + compiled_fw_func, # fw_module after compilation + wrappers + compiled_bw_func, # bw_module after compilation + wrappers + maybe_subclass_meta: Optional[SubclassMeta], + num_symints_saved_for_bw_: int, + backward_state_indices: list[int], + disable_amp: bool, + indices_of_inps_to_detach: list[int], + lazy_backward_info: Optional[ + Union[ + AutogradLazyBackwardCompileInfo, + CachedAutogradLazyBackwardCompileInfo, + ] + ], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, # runtime metadata + try_save_cache_entry: Optional[Callable], # Serialization function + ): + # For additional context see Note [CUDA Graph Safe RNG Functionalization] + # Each pair forward, backward rng states must be equal prior to its invocation on any + # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op, + # running forward then backward advances them the same amount and keeps them equal. + # However, a user may invoke multiple forwards, then backwards, such that they are not in sync. + # Initially we have: + # fwd_state0 == bwd_state0. + # Lets say we run: + # fwd0: fwd_state0 -> fwd_state1 + # fwd1: fwd_state1 -> fwd_state2 + # fwd2: fwd_state2 -> fwd_state3 + # If we now invoke bwd2, + # we need to update bwd_state equal to the rng that was observed in fwd2. + # we save the rng_state fwd_state2 in forward because we detect that it is not the + # current backward state and therefore would not be accessible if we do not save it. + # Similarly, if we are going to update the backward state to a new value, and there is a pending + # forwards which needs its current state, we will save it. + # Within the autograd context, we keep track of the curr iteration so that on backward + # we know what the generator state must be before the backward is run. + num_rng = fw_metadata.num_graphsafe_rng_states + graphsafe_idx = fw_metadata.graphsafe_rng_state_index + fwd_rng_states: list[torch.Generator] = [] + bwd_rng_states: list[torch.Generator] = [] + curr_fwd_iter = itertools.count(0) + backward_state_position = 0 + pending_forwards: set[int] = set() + saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {} + + class CompiledFunction(torch.autograd.Function): + compiled_fw = compiled_fw_func + compiled_bw = compiled_bw_func + metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] + maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta + num_symints_saved_for_bw = num_symints_saved_for_bw_ + _aot_id = aot_config.aot_id + _lazy_backward_info = lazy_backward_info + + @staticmethod + def _compiled_autograd_key(ctx): + return (ctx._autograd_function_id, *ctx.symints) + + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, *deduped_flat_tensor_args): + args = deduped_flat_tensor_args + if backward_state_indices: + bw_state = args[backward_state_indices[0]] + assert isinstance(bw_state, BackwardState) + ctx._compiled_autograd_backward_state = bw_state + + if num_rng: + if len(fwd_rng_states) == 0: + assert graphsafe_idx is not None + initialize_rng_states( + num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states + ) + + _curr_iter = next(curr_fwd_iter) + ctx._curr_iter = _curr_iter + + # if this state is not contained in the backward, + # we need to save it for when its backward pass happens + if _curr_iter != backward_state_position: + saved_backward_tensor_states[_curr_iter] = [ + rng_state.get_state() for rng_state in fwd_rng_states + ] + + pending_forwards.add(_curr_iter) + args = (*args, *fwd_rng_states) + + # There is a pretty complicated calling convention around what the compiled fw returns. + # The full list of outputs and their relative order is: + # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) + # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version + # of the original view, and not the synthetic base + # - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last + # in the fw output order. + fw_outs = call_func_at_runtime_with_args( + CompiledFunction.compiled_fw, + # pyrefly: ignore [bad-argument-type] + args, + disable_amp=disable_amp, + ) + + num_outputs = CompiledFunction.metadata.num_outputs + num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased + num_mutated_runtime_inps = ( + CompiledFunction.metadata.num_mutated_inp_runtime_indices + ) + num_forward_returns = CompiledFunction.metadata.num_forward_returns + + # Partitioners must put symint arguments at the end separate from tensor arguments + tensors_saved_for_backwards = fw_outs[ + CompiledFunction.metadata.tensors_saved_for_backwards_slice + ] + assert all( + isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards + ) + + def mark_dynamic_activations(activations: list[torch.Tensor]): + for ( + idx, + dims, + ) in CompiledFunction.metadata.dynamic_saved_tensors_idxs.items(): + maybe_mark_dynamic_helper(activations[idx], dims) + return activations + + # See Note [Detaching saved tensors in AOTAutograd] + ctx.save_for_backward( + *mark_dynamic_activations( + [ + x.detach() if x._is_view() else x + for x in tensors_saved_for_backwards + ] + ) + ) + symint_outs = fw_outs[ + CompiledFunction.metadata.symints_saved_for_backwards_slice + ] + assert all( + isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) + for x in symint_outs + ), str([type(x) for x in symint_outs]) + ctx.symints = symint_outs + + raw_returns = fw_outs[0:num_forward_returns] + + # Wrap all autograd.Function.forward() outputs that are aliases + # so that autograd.Function doesn't treat them as tensors + if num_mutated_runtime_inps > 0: + for i, idx in enumerate( + CompiledFunction.metadata.mutated_inp_runtime_indices + ): + # We could make this faster by only looping over inputs with metadata-only mutations + # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. + info = CompiledFunction.metadata.input_info[idx] + if info.mutates_metadata and not info.mutates_data: + raw_return_idx = i + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + user_mutated_inputs_raw = raw_returns[ + 0:num_mutated_runtime_inps + ] + mut_inp_infos = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutates_data or x.mutates_metadata + ] + assert len(user_mutated_inputs_raw) == len(mut_inp_infos) + + if CompiledFunction.metadata.num_unsafe_view_outputs > 0: + for idx in CompiledFunction.metadata.unsafe_view_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + o = raw_returns[raw_return_idx] + raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view( + o, o.shape + ) + + if num_outputs_aliased > 0: + for idx in CompiledFunction.metadata.aliased_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + intermediates_raw = raw_returns[ + num_mutated_runtime_inps + num_outputs : + ] + assert not any( + isinstance(x, TensorAlias) for x in intermediates_raw + ) + + # invariant: intermediate bases always require gradients, so we don't have to + # consider marking them as non-differentiable. + raw_returns_not_including_intermediate_bases = raw_returns[ + : num_mutated_runtime_inps + num_outputs + ] + raw_returns_meta = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + CompiledFunction.metadata.output_info + + fw_outs_not_requiring_grad = [ + x + for (i, x) in enumerate( + raw_returns_not_including_intermediate_bases + ) + if isinstance(x, torch.Tensor) + and not raw_returns_meta[i].requires_grad + ] + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + ctx._materialize_non_diff_grads = False + return tuple(raw_returns) + + @staticmethod + def backward(ctx, *flat_args): + all_args = _backward_prologue_functional( + ctx.saved_tensors, + ctx.symints, + CompiledFunction.metadata, + CompiledFunction.maybe_subclass_metadata, + *flat_args, + ) + + if num_rng: + nonlocal backward_state_position, bwd_rng_states + curr_backward_iter = ctx._curr_iter + retain_graph = ( + torch._C._autograd._get_current_graph_task_keep_graph() + ) + + # Save current state if we have a pending forward that needs this state + # or this state may be needed again because of retain graph + if ( + backward_state_position in pending_forwards + and backward_state_position not in saved_backward_tensor_states + and ( + backward_state_position != curr_backward_iter + or retain_graph + ) + ): + saved_backward_tensor_states[backward_state_position] = [ + rng_state.get_state() for rng_state in bwd_rng_states + ] + + # Restore saved states if needed + if curr_backward_iter in saved_backward_tensor_states: + if backward_state_position != curr_backward_iter: + for bwd_state, saved_state in zip( + bwd_rng_states, + saved_backward_tensor_states[curr_backward_iter], + ): + bwd_state.set_state(saved_state) + if not retain_graph: + del saved_backward_tensor_states[curr_backward_iter] + else: + assert backward_state_position == curr_backward_iter + + backward_state_position = curr_backward_iter + 1 + if not retain_graph: + pending_forwards.remove(curr_backward_iter) + all_args.extend(bwd_rng_states) + + def impl_fn(double_ctx=None): + out = CompiledFunction._backward_impl(ctx, all_args) + return _backward_epilogue_functional( + CompiledFunction.metadata, + CompiledFunction.maybe_subclass_metadata, + out, + ) + + needs_grad = torch.is_grad_enabled() and any( + t.requires_grad for t in all_args if isinstance(t, torch.Tensor) + ) + if needs_grad: + # double backward + return CompiledFunction._double_backward(ctx, impl_fn, all_args) + else: + return impl_fn() + + @staticmethod + def _double_backward(ctx, impl_fn, all_args): + # Ensure that the graph is connected, and error if double backward is performed. + # See comment for why once_differentiable is not sufficient: + # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107 + class CompiledFunctionBackward(torch.autograd.Function): + # CompiledFunctionBackward is not yet supported in dynamo skipfiles + _aot_id = aot_config.aot_id + + @staticmethod + # pyrefly: ignore [bad-override] + def forward(double_ctx, *unused_args): + return impl_fn(double_ctx) + + @staticmethod + def backward(double_ctx, *args): + raise RuntimeError( + "torch.compile with aot_autograd does not currently support double backward" + ) + + CompiledFunctionBackward._compiled_autograd_key = ( # type: ignore[method-assign] + CompiledFunction._compiled_autograd_key + ) + + return CompiledFunctionBackward.apply(*all_args) + + @staticmethod + def _backward_impl(ctx, all_args): + # compiled autograd reimplements this function at proxy_call_aot_backward + assert not backward_state_indices, ( + "BackwardState requires CompiledAutograd" + ) + ctx.maybe_clear_saved_tensors() + + saved_tensors_use_once = ( + not torch._C._autograd._get_current_graph_task_keep_graph() + ) + + if CompiledFunction.compiled_bw is None: + assert lazy_backward_info is not None + assert isinstance( + lazy_backward_info, AutogradLazyBackwardCompileInfo + ) + + if ( + hasattr(lazy_backward_info, "saved_context") + and lazy_backward_info.saved_context is not None + ): + assert isinstance( + lazy_backward_info.saved_context, TracingContext + ) + ddp_ctx = lazy_backward_info.saved_context.ddp_optimizer_ctx + if ddp_ctx is not None: + assert ddp_ctx.curr_bucket >= 0, ( + f"expected same # of fw and bw compiles, but found bucket {ddp_ctx.curr_bucket}" + ) + curr_fw_meta = ddp_ctx.metadata_per_bucket[ + ddp_ctx.curr_bucket + ] + # Note [DDPOptimizer and fw_metadata] + # When using the DDPOptimizer, we have a single dynamo graph (and TracingContext), + # but multiple AOTDispatcher graph. + # + # One consequence is that there will be **multiple** fw_metadata objects, one per AOT graph, + # which we stash the fw_metadata on the TracingContext. + # + # Normally what happens is that as we compile AOT graphs 1...N, we clobber the fw_metadata + # for graph i-1 when we start running AOT for graph i. + # Ordinarily this is fine, because inductor no longer needs the metadata from graph i-1. + # + # However, this is a problem for lazy compilation of the backward. During backward compilation, + # we compile the backward lazily at backward runtime, meaning that we will first compile + # backward graph N, N-1, ..., 1. + # We need to ensure that at the time inductor compiles bw graph N-1, it can access + # the corresponding fw_metadta for graph N-1. + # + # We do this by stashing a DDPOptimizerContext, which tracks: + # - the metadata of all N graphs + # - the graph we are currently compiling in our DDPOptimizer region. + ddp_ctx.curr_bucket -= 1 + lazy_backward_info.saved_context.fw_metadata = curr_fw_meta + + if not saved_tensors_use_once: + fw_metadata.bw_donated_idxs = [] + # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd` + if ( + hasattr(lazy_backward_info, "saved_context") + and hasattr(lazy_backward_info.saved_context, "fw_metadata") + and hasattr( + lazy_backward_info.saved_context.fw_metadata, # type: ignore[union-attr] + "bw_donated_idxs", + ) + ): + lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = ( # type: ignore[union-attr] + [] + ) + + bw_module = lazy_backward_info.bw_module + placeholder_list = lazy_backward_info.placeholder_list + saved_context = lazy_backward_info.saved_context + saved_compile_context = lazy_backward_info.saved_compile_context + + context = torch._C._DisableAutocast if disable_amp else nullcontext + metrics_context = get_metrics_context() + with ( + tracing(saved_context), + compile_context(saved_compile_context), + context(), + track_graph_compiling(aot_config, "backward"), + metrics_context, + dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + log_pt2_compile_event=True, + dynamo_compile_column_us="backward_cumulative_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="entire_backward_compile", + ), + callback_handler.install_callbacks( + CallbackTrigger.LAZY_BACKWARD, + str(CompileContext.current_compile_id()), + ), + ): + CompileEventLogger.compilation_metric(is_forward=False) + # See Note: [Backward graph lazy lowering] + CompiledFunction.compiled_bw = aot_config.bw_compiler( + copy.deepcopy(bw_module), placeholder_list + ) + # Maybe save cache entry + if try_save_cache_entry is not None: + try_save_cache_entry( + CompiledFunction.compiled_bw, + bw_module, + fw_metadata, + aot_config, + ) + + if ( + torch._functorch.config.donated_buffer + and not saved_tensors_use_once + and fw_metadata.bw_donated_idxs != [] + ): + torch._check( + False, + lambda: ( + "This backward function was compiled with non-empty donated " + "buffers which requires create_graph=False and retain_graph=False. " + "Please keep backward(create_graph=False, retain_graph=False) " + "across all backward() function calls, or set " + "torch._functorch.config.donated_buffer=False to disable " + "donated buffer." + ), + ) + + out = call_func_at_runtime_with_args( + CompiledFunction.compiled_bw, + all_args, + steal_args=True, + disable_amp=disable_amp, + ) + return out + + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=indices_of_inps_to_detach, + trace_joint=True, + disable_amp=disable_amp, + ).post_compile( + CompiledFunction.apply, + aot_config, + runtime_metadata=fw_metadata, + ) + + return compiled_function + + +@dataclass +class DebugAssertWrapper(CompilerWrapper): + flat_requires_grad: list[Optional[bool]] = field(default_factory=list) + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + @wraps(compiled_fn) + def debug_compiled_function(args: list[Any]): + # TODO: Check aliasing relationships + # TODO: Check strides for metadata mutation + # (NB: ideally, this logic is factored out of this function and + # you move these debug checks there) + + # Check requires grad. Bad case is when we compiled with + # requires_grad = False, but input requires_grad = True + # (vice versa is OK; we compute a gradient and then throw + # it away when it hits the input.) + for i, a in enumerate(args): + can_require_grad = self.flat_requires_grad[i] + if can_require_grad is None: + assert not isinstance(a, Tensor) + elif not can_require_grad: + assert not a.requires_grad, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would not require grad", + ) + + return compiled_fn(args) + + return debug_compiled_function + + +def pre_compile( + wrappers: list[CompilerWrapper], + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[TraceFn, list[FxValue], list[AOTInput], ViewAndMutationMeta]: + """ + Runs a sequence of wrappers on the given function and arguments. + Mutates wrappers in place. + """ + for wrapper in wrappers: + flat_fn, flat_args, flat_args_descs, fw_metadata = wrapper.pre_compile( + flat_fn, flat_args, flat_args_descs, aot_config, fw_metadata=fw_metadata + ) + return flat_fn, flat_args, flat_args_descs, fw_metadata + + +def post_compile( + wrappers: list[CompilerWrapper], + compiled_fn: Callable, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, +) -> tuple[Callable, ViewAndMutationMeta]: + """ + Runs a sequence of wrappers on the given function. Should be called after pre_compile() + """ + for wrapper in reversed(wrappers): + compiled_fn = wrapper.post_compile( + compiled_fn, aot_config, runtime_metadata=runtime_metadata + ) + return compiled_fn, runtime_metadata + + +def make_runtime_safe( + fw_metadata: ViewAndMutationMeta, + maybe_subclass_meta: Optional[SubclassMeta], +): + """ + Calls make_runtime_safe on all ViewAndMutationMetas. + Modifies both arguments. Allows ViewAndMutationMetas to + be safely cached in AOTAutogradCache. + """ + fw_metadata.make_runtime_safe() + if maybe_subclass_meta is not None: + maybe_subclass_meta.fw_metadata.make_runtime_safe() + if maybe_subclass_meta.grad_input_metas: + for meta in maybe_subclass_meta.grad_input_metas: + if isinstance(meta, SubclassCreationMeta): + meta.make_runtime_safe() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/streams.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb76a637bf71ca8b813d68fcae3123159a21114 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/streams.py @@ -0,0 +1,281 @@ +from typing import Any, Optional, TypeAlias + +import torch.fx +import torch.fx.traceback +import torch.utils._pytree as pytree +from torch._dynamo.graph_utils import _get_flat_args +from torch._dynamo.variables.streams import get_current_stream, new_event +from torch.utils._runtime_estimation import ( + _FLOAT_TYPES, + _IGNORE_OPS, + get_compute_time, + get_transfer_time, +) + +from .indexed_dict import IndexedDict + + +Node: TypeAlias = torch.fx.Node +Graph: TypeAlias = torch.fx.Graph + + +def get_roofline_estimate(node: Node) -> float: + assert node.op == "call_function", "non-func node in roofline estimate" + + def map_value(x: Any) -> Any: + return x.meta.get("value", x) if isinstance(x, Node) else x + + func = node.target + if func in _IGNORE_OPS: + return 0.0 + + mapped_args = torch.fx.map_arg(node.args, map_value) + mapped_kwargs = torch.fx.map_arg(node.kwargs, map_value) + flat_args_kwargs = [map_value(x) for x in _get_flat_args(node, {})] + flat_outs, _ = pytree.tree_flatten(node.meta.get("value", node)) + out = node.meta.get("value", node) + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES + } + + return ( + max( + get_transfer_time(flat_args_kwargs, flat_outs), + get_compute_time(func, mapped_args, mapped_kwargs, out, out_dtypes), + ) + / 1e6 + ) + + +def is_gradient_acc(node: Node) -> bool: + return node.meta.get("is_gradient_acc", False) + + +def is_bwd_node(node: Node) -> bool: + tag = node.meta.get("partitioner_tag") + return tag == "is_backward" or tag == "must_be_in_backward" + + +def get_device(node: Node) -> torch.device: + return node.meta["val"].device + + +def get_stream(node: Node) -> Optional[int]: + maybe_annotation = node.meta.get("custom", None) + if maybe_annotation is not None: + return node.meta["custom"].get("stream", None) + else: + return None + + +def get_stream_or_current_stream(node: Node) -> int: + ind = get_stream(node) + if ind is None: + ind = get_current_stream(get_device(node)) + return ind + + +def set_stream(node: Node, ind: int) -> None: + if "custom" in node.meta: + node.meta["custom"].update({"stream": ind}) + else: + node.meta["custom"] = {"stream": ind} + + +def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> Node: + with graph.inserting_after(node): + node = graph.call_function( + torch.ops.streams.record_event.default, + ( + event_ind, + get_stream_or_current_stream(node), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + return node + + +def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> Node: + with graph.inserting_before(node): + node = graph.call_function( + torch.ops.streams.wait_event.default, + ( + event_ind, + get_stream_or_current_stream(node), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + return node + + +def populate_stream_timeline( + stream_to_timeline: dict[Optional[int], IndexedDict[Node, float]], + graph: Graph, + stream_index: Optional[int], +) -> IndexedDict[Node, float]: + if stream_index not in stream_to_timeline: + stream_to_timeline[stream_index] = IndexedDict() + total_time = 0.0 + for node in graph.nodes: + # mlazos: not sure if we should include forward here too but don't think it matters + if is_bwd_node(node) and get_stream(node) == stream_index: + total_time += get_roofline_estimate(node) + stream_to_timeline[stream_index][node] = ( + total_time # NB: total time includes the node's runtime + ) + + return stream_to_timeline[stream_index] + + +# NB: we start all estimates at 0, estimating the total runtime of each stream with timestamps at each node +# we then try and use these timestamps to estimate when to deallocate tensors used in side streams +# See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream +# for details on the problem being addressed. Rather than using the automatic memory management approach of record_stream +# we attempt to find the point which to deallocate based on the estimated timestamps. +def handle_synced_deallocation( + graph: Graph, + stream_to_exec_trace: dict[Optional[int], IndexedDict[Node, float]], + node: Node, + last_usage: Node, +) -> None: + assert is_bwd_node(node), ( + "synced allocations should only be handled on backward nodes" + ) + assert is_bwd_node(last_usage), ( + "synced allocations should only be handled on backward nodes" + ) + allocating_stream = get_stream(node) + side_stream = get_stream(last_usage) + assert allocating_stream != side_stream, ( + "allocating and side stream should be different for synced deallocations" + ) + if not torch.cuda.is_available(): + # fallback to record_stream in this case + with graph.inserting_after(node): + graph.call_function( + torch.ops.streams.record_stream.default, + ( + node, + get_stream_or_current_stream(last_usage), + ), + {}, + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + allocating_stream_trace = populate_stream_timeline( + stream_to_exec_trace, graph, allocating_stream + ) + side_stream_trace = populate_stream_timeline( + stream_to_exec_trace, graph, side_stream + ) + + alloc_ptr = node + target_side_stream_time = side_stream_trace[last_usage] + # linear search from first usage of tensor to a point in time after the side stream has finished + while alloc_ptr is not None: + alloc_time = allocating_stream_trace[alloc_ptr] + + if alloc_time >= target_side_stream_time: + break + elif alloc_time < target_side_stream_time: + next_ptr = allocating_stream_trace.next_key(alloc_ptr) + if next_ptr is not None: + alloc_ptr = next_ptr + else: + break + + wait_event = new_event() + record_node = insert_record_event_after_node(graph, last_usage, wait_event) + with graph.inserting_after(max(alloc_ptr, record_node)): + graph.call_function( + torch.ops.streams.sync_dealloc.default, + (wait_event, get_stream_or_current_stream(alloc_ptr), node), + {}, + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + +def insert_sync( + graph: Graph, + consumer: Node, + producer: Node, + node_to_wait_event_ind: dict[Node, int], +) -> None: + if producer not in node_to_wait_event_ind: + node_to_wait_event_ind[producer] = new_event() + + insert_record_event_after_node( + graph, producer, node_to_wait_event_ind[producer] + ) + insert_wait_event_before_node(graph, consumer, node_to_wait_event_ind[producer]) + + +def assign_backward_streams(gm: torch.fx.GraphModule) -> None: + """Assigns backward streams to gradient accumulation nodes""" + + # NB: iterate in reverse order to more closely match eager + # the user node stream will be populated first + for node in reversed(list(gm.graph.nodes)): + if is_gradient_acc(node): + # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream: + # 1. Match first stream assignment of the first user with a stream + # 2. Match first stream assignment encountered in the args from left to right + # This differs from eager in some cases: + # Specifically the eager code uses the autograd node to determine the stream, + # crucially this does not necessarily correspond to the FX graph node. For example, + # in the backward for an add node with a constant we will passthrough and during backward tracing, + # no op will be added to the FX graph, so our stream assignment will differ in this case. + gradients = _get_flat_args(node, {}) + users = list(node.users.keys()) + + # All gradients will be on same device, they will be coerced if they were not with a .to() node + for neighbor in users + gradients: + ind = get_stream(neighbor) + if ind is not None: + set_stream(node, ind) + break + + +def insert_backward_syncs(gm: torch.fx.GraphModule) -> None: + """Inserts stream syncs for backward nodes if consumer and producer are on different streams""" + node_to_wait_event_ind: dict[Node, int] = {} + for node in gm.graph.nodes: + if is_bwd_node(node): + flat_args = _get_flat_args(node, {}) + cur_node_stream = get_stream(node) + + for arg in flat_args: + if is_bwd_node(arg): + arg_stream = get_stream(arg) + if arg_stream != cur_node_stream and get_device(arg).type != "cpu": + insert_sync(gm.graph, node, arg, node_to_wait_event_ind) + + +def sync_deallocations(gm: torch.fx.GraphModule) -> None: + """Handles https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream""" + # Note: this is only needed if the last usage of a tensor is on a stream other than + # the stream the tensor was allocated on + + # an estimated timestamp from the beginning of graph execution (assuming 0 CPU overhead) + # I think this is fine because you should have large tensors if you're using streams + # although perhaps I could add a constant 10us per op ahead of the first stream op? + # a trace of all the nodes running in a given stream + stream_to_exec_trace: dict[Optional[int], IndexedDict[Node, float]] = {} + for node in gm.graph.nodes: + if is_bwd_node(node): + allocating_stream = get_stream(node) + users = list(node.users.keys()) + if not users: + continue + last_user = max(user for user in users) + if last_user.op == "output": + continue + side_stream = get_stream(last_user) + if allocating_stream != side_stream: + handle_synced_deallocation( + gm.graph, stream_to_exec_trace, node, last_user + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a579888dfade33b49ba6f24d1542bcc24a082f29 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py @@ -0,0 +1,520 @@ +# mypy: allow-untyped-defs +""" +This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes. +AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher, +and this includes tensor subclasses that implement __torch_dispatch__. +""" + +import collections +import typing +from collections.abc import Callable, Iterable +from typing import Any, Optional, TypeGuard, TypeVar, Union + +import torch +import torch.utils._pytree as pytree +from torch import SymInt, Tensor +from torch._subclasses.fake_tensor import get_plain_tensors +from torch.types import IntLikeType +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .descriptors import ( + AOTInput, + AOTOutput, + DummyAOTInput, + SubclassGetAttrAOTInput, + SubclassGetAttrAOTOutput, + SubclassSizeAOTInput, + SubclassSizeAOTOutput, + SubclassStrideAOTInput, + SubclassStrideAOTOutput, +) +from .schemas import ( + FxValue, + MutationType, + PlainTensorMeta, + SubclassCreationMeta, + ViewAndMutationMeta, +) +from .utils import strict_zip + + +zip = strict_zip + +T = TypeVar("T", bound=torch.Tensor) + + +def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: + args_flattened = pytree.arg_tree_leaves(*args) + any_subclass_args = any( + is_traceable_wrapper_subclass(x) + for x in args_flattened + if isinstance(x, Tensor) + ) + from torch._functorch._aot_autograd.schemas import SubclassCreationMeta + + any_subclass_outputs = any( + type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta + ) + # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime. + return any_subclass_args or any_subclass_outputs + + +from .schemas import MemoryFormatMeta + + +def maybe_suggest_memory_format( + t, with_memory_format: bool +) -> Optional[MemoryFormatMeta]: + if not with_memory_format: + return None + + return MemoryFormatMeta.from_tensor(t) + + +def get_subclass_typing_container( + tensor_subclass: torch.Tensor, +) -> dict[type[torch.Tensor], list[type[torch.Tensor]]]: + """ + Given a subclass, returns a recursive dictionary mapping each + inner tensors to its' subclass types. + """ + + def _get_types_for_subclass(tensor_subclass: torch.Tensor) -> None: + if not is_traceable_wrapper_subclass(tensor_subclass): + return + tracker[type(tensor_subclass)].append(tensor_subclass) + inner_keys, _ = tensor_subclass.__tensor_flatten__() + for key in inner_keys: + inner_tensor = getattr(tensor_subclass, key) + _get_types_for_subclass(inner_tensor) + + tracker: dict[Any, list[Any]] = collections.defaultdict(list) + _get_types_for_subclass(tensor_subclass) + return tracker + + +def create_subclass_metadata( + a: Any, start_idx: int, count_symints: bool, with_memory_format: bool = False +): + if not is_traceable_wrapper_subclass(a): + idx = start_idx + 1 + return ( + PlainTensorMeta( + idx, + memory_format=maybe_suggest_memory_format(a, with_memory_format), + ), + idx, + ) + + inner_keys, metadata = a.__tensor_flatten__() + new_start_idx = start_idx + attrs = {} + + for key in inner_keys: + new_subclass_meta, new_start_idx = create_subclass_metadata( + getattr(a, key), + new_start_idx, + count_symints=count_symints, + with_memory_format=with_memory_format, + ) + attrs[key] = new_subclass_meta + + # It *must* be because is_traceable_wrapper_subclass() - but mypy is not smart. + assert isinstance(a, Tensor) + + new_start_idx = ( + new_start_idx + + count_symints * len(enumerate_filter_symints(a.size())) + + count_symints * len(enumerate_filter_symints(a.stride())) + ) + + return ( + SubclassCreationMeta( + flat_tensor_start_idx=start_idx, + arg_count=new_start_idx - start_idx, + included_subclass_symints=count_symints, + attrs=attrs, + meta=metadata, + outer_size=a.size(), # type: ignore[attr-defined, arg-type] + outer_stride=a.stride(), # type: ignore[arg-type] + original_subclass=a, + memory_format=maybe_suggest_memory_format(a, with_memory_format), + ), + new_start_idx, + ) + + +# Given a flat list of arguments, some of which may be tensor subclasses, +# computes metadata about "how to reconstruct the current list of subclasses, +# if we were given their flattened dense tensors instead" +def create_subclass_meta( + curr_args: Union[list[Any], tuple[Any, ...]], + *, + count_symints: bool = True, + with_memory_format: bool = False, +) -> list[Union[PlainTensorMeta, SubclassCreationMeta]]: + idx = 0 + infos: list[Union[PlainTensorMeta, SubclassCreationMeta]] = [] + for a in curr_args: + if is_traceable_wrapper_subclass(a): + assert isinstance(a, Tensor) + start_idx = idx + subclass_meta, _ = create_subclass_metadata( + a, + start_idx, + count_symints=count_symints, + with_memory_format=with_memory_format, + ) + infos.append(subclass_meta) + cnt = subclass_meta.arg_count + else: + infos.append( + PlainTensorMeta( + idx, + memory_format=maybe_suggest_memory_format(a, with_memory_format), + ) + ) + cnt = 1 + idx += cnt + return infos + + +def enumerate_filter_symints(lst: Iterable[IntLikeType]) -> list[tuple[int, SymInt]]: + # Capture all SymInts from the iterable. + def symint_check(s: IntLikeType) -> TypeGuard[SymInt]: + return isinstance(s, SymInt) and not s.node.is_nested_int() + + return [(i, s) for i, s in enumerate(lst) if symint_check(s)] + + +def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list[bool]: + # Non-nested symints are replaced with None in `make_runtime_safe()` + return [s is None for s in lst] + + +# Intended to make it easier to define function that is +# either (AOTInput -> AOTInput) or (AOTOutput -> AOTOutput) +# but not the other combos +AOTDescriptor = TypeVar("AOTDescriptor", AOTInput, AOTOutput) + + +# This function takes in a pytree of arguments and unwraps any tensor +# subclasses. +# +# NOTE: The reason for "append_symints": +# +# * At compile time: we append extra symint args when unwrapping primals +# (but not tangents, because they should always share symints with primals). +# We also append extra symints when unwrapping the subclass outputs of the +# traced function, so we can return them as extra outputs +# +# * At runtime: we similarly append subclass sizes when we unwrap subclass +# primals (but not tangents) on entry to the forward. See the runtime version of +# this function below. +def unwrap_tensor_subclasses( + wrapped_args: list[FxValue], + wrapped_args_descs: list[AOTDescriptor], + *, + append_symints: bool, +) -> tuple[list[FxValue], list[AOTDescriptor]]: + def flatten_subclass( + t: FxValue, + desc: AOTDescriptor, + *, + out: tuple[list[FxValue], list[AOTDescriptor]], + ): + # unwrap a subclass into plain tensors and their size/stride if "append_symint" + # is True + if not is_traceable_wrapper_subclass(t): + out[0].append(t) + out[1].append(desc) + return + + attrs, _ = t.__tensor_flatten__() + + for attr in attrs: + inner_tensor = getattr(t, attr) + n_desc: Any = ( + SubclassGetAttrAOTInput(desc, attr) + if isinstance(desc, AOTInput) + # pyrefly: ignore [bad-argument-type] + else SubclassGetAttrAOTOutput(desc, attr) + ) + flatten_subclass(inner_tensor, n_desc, out=out) + + if append_symints: + sizes = enumerate_filter_symints(t.size()) + strides = enumerate_filter_symints(t.stride()) + out[0].extend(s for _, s in sizes) + out[0].extend(s for _, s in strides) + if isinstance(desc, AOTInput): + out[1].extend(SubclassSizeAOTInput(desc, i) for i, _ in sizes) # type: ignore[misc] + out[1].extend(SubclassStrideAOTInput(desc, i) for i, _ in strides) # type: ignore[misc] + else: + out[1].extend(SubclassSizeAOTOutput(desc, i) for i, _ in sizes) # type: ignore[misc] + out[1].extend(SubclassStrideAOTOutput(desc, i) for i, _ in strides) # type: ignore[misc] + + xs_inner: list[FxValue] = [] + descs_inner: list[AOTDescriptor] = [] + + for x, desc in zip(wrapped_args, wrapped_args_descs): + # pyrefly: ignore [bad-argument-type] + flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner)) + + return xs_inner, descs_inner + + +# subclass_metas is needed at runtime to compute which indices are symints in +# the outer_size/outer_stride +def runtime_unwrap_tensor_subclasses( + wrapped_args: list[Union[Tensor, int]], + *, + append_symints: bool, + subclass_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = None, +): + def flatten_subclass(x: Tensor, meta: Optional[SubclassCreationMeta], *, out): + if not is_traceable_wrapper_subclass(x): + out.append(x) + return out + + assert isinstance(x, Tensor) + + attrs, _ = x.__tensor_flatten__() + + for attr in attrs: + inner_tensor = getattr(x, attr) + # pyrefly: ignore [missing-attribute] + inner_meta = meta.attrs.get(attr) + flatten_subclass(inner_tensor, inner_meta, out=out) + + if append_symints: + assert isinstance(meta, SubclassCreationMeta) + # outer_size + size = x.size() + symint_placeholders = compute_symint_placeholders(meta.outer_size) + assert len(size) == len(symint_placeholders) + out.extend( + [r for (r, is_symint) in zip(size, symint_placeholders) if is_symint] + ) + + # outer_stride + stride = x.stride() + symint_placeholders = compute_symint_placeholders(meta.outer_stride) + assert len(stride) == len(symint_placeholders) + out.extend( + [r for (r, is_symint) in zip(stride, symint_placeholders) if is_symint] + ) + return out + + xs_inner: list[Union[int, Tensor, SymInt]] = [] + + if append_symints: + assert subclass_metas is not None + + for idx, x in enumerate(wrapped_args): + if not is_traceable_wrapper_subclass(x): + xs_inner.append(x) + continue + + if subclass_metas is None: + get_plain_tensors(typing.cast(Tensor, x), out=xs_inner) + else: + meta = subclass_metas[idx] + assert isinstance(meta, SubclassCreationMeta) + flatten_subclass(typing.cast(Tensor, x), meta, out=xs_inner) + + return xs_inner + + +def unwrap_tensor_subclasses_with_indices_to_original(wrapped_args): + ret_unwrapped = [] + ret_indices_to_original = [] + for i, a in enumerate(wrapped_args): + a_unwrapped, _ = unwrap_tensor_subclasses( + [a], [DummyAOTInput(9999)], append_symints=False + ) + ret_unwrapped.extend(a_unwrapped) + n = len(a_unwrapped) + ret_indices_to_original.extend([i] * n) + + return ret_unwrapped, ret_indices_to_original + + +def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices): + static_input_indices = set(static_input_indices) + new_ind = 0 + remapped_static_indices = [] + for i, arg in enumerate(wrapped_args): + num_indices = 1 + if is_traceable_wrapper_subclass(arg): + num_indices = ( + len(get_plain_tensors(typing.cast(Tensor, arg), out=[])) + + len(enumerate_filter_symints(arg.size())) + + len(enumerate_filter_symints(arg.stride())) + ) + + for _ in range(num_indices): + if i in static_input_indices: + remapped_static_indices.append(new_ind) + + new_ind += 1 + + return remapped_static_indices + + +# Turns a flattened list of tensor arguments into (maybe) subclass tensors. +# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in. +def wrap_tensor_subclasses( + unwrapped_args: Union[tuple[Any, ...], list[Any]], + *, + subclass_metas: list[Union[PlainTensorMeta, SubclassCreationMeta]], + num_fw_outs_saved_for_bw: Optional[int] = None, + included_subclass_symints: bool = False, + is_runtime: bool = False, + make_subclass_override: Optional[Callable] = None, +) -> tuple[Any, ...]: + wrapped_args = [] + num_args_tallied = 0 + for subclass_meta in subclass_metas: + if isinstance(subclass_meta, PlainTensorMeta): + wrapped_args.append(unwrapped_args[subclass_meta.unwrapped_idx]) + num_args_tallied += 1 + else: + assert isinstance(subclass_meta, SubclassCreationMeta) + assert subclass_meta.included_subclass_symints == included_subclass_symints + + if make_subclass_override: + wrapped_args.append( + make_subclass_override(subclass_meta, is_runtime, unwrapped_args) + ) + else: + wrapped_args.append( + subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + ) + num_args_tallied += subclass_meta.arg_count + + # Note: [Partitioner handling for Subclasses, Part 2] + # At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw, + # to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them. + # + # When this function is called at runtime in the forward, + # we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs. + # + # One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen? + # Answer: we do it **inside of our compiled autograd.Function**. + # This seems like morally the right place: autograd happens above subclass desugaring, + # so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors. + # + # This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph + # into a forward and backward graph, we end up with some activations that show up as extra outputs + # in the compiled forward graph, that are **not** user outputs. + # These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses. + # + # On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`), + # we computed subclass metadata on every forward output, but this did **not** include activations + # created by the partitioner. + # as a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations), + # but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`. + # We then need to make sure that we return (*wrapped_user_fw_outs, *activations). + if num_fw_outs_saved_for_bw is not None: + assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, ( + f"Expected the number actual unwrapped-subclass outputs {len(unwrapped_args)} to equal " + f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of " + f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})" + ) + activations = unwrapped_args[num_args_tallied:] + if isinstance(wrapped_args, tuple) and isinstance(activations, tuple): + return wrapped_args + activations + return tuple(list(wrapped_args) + list(activations)) + else: + assert len(unwrapped_args) == num_args_tallied, ( + f"Expected {len(unwrapped_args)} == {num_args_tallied}" + ) + return tuple(wrapped_args) + + +# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses. +# This function carefully handles the inference vs. joint cases: +# - when is_joint_structure is True, args is (primals, tangents) +# - when is_joint_structure is False, args is [*primals] +def wrap_tensor_subclasses_maybe_joint( + unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta +) -> Union[tuple[Any, ...], list[Any]]: + # Since this function is reused for both inference and joint graphs, + if is_joint_structure: + assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2 + assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance( + unwrapped_args[1], (tuple, list) + ) + primals, tangents = unwrapped_args[0], unwrapped_args[1] + wrapped_primals = wrap_tensor_subclasses( + primals, + subclass_metas=meta.subclass_inp_meta, + included_subclass_symints=True, + ) + wrapped_tangents = wrap_tensor_subclasses( + tangents, + subclass_metas=meta.subclass_tangent_meta, + included_subclass_symints=False, + ) + return (wrapped_primals, wrapped_tangents) + else: + wrapped_args = wrap_tensor_subclasses( + unwrapped_args, + subclass_metas=meta.subclass_inp_meta, + included_subclass_symints=True, + ) + return wrapped_args + + +def compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata: ViewAndMutationMeta, + inner_metadata: ViewAndMutationMeta, +) -> list[int]: + # Note: [Recomputing subclass mutation handling] + # + # Generally, if a subclass requires grad, its components will not require grad. + # But for the purposes of tracking returned tensors, we should treat those component + # tensors as if they require grad. + # + # For example, if the subclass tensor requires grad and will be mutated in a way that + # requires us to handle the mutation outside of the graph, we need to return it + # from the forward graph. The inner_meta data won't consider the component tensors + # as if they need to be returned, because they don't require grad; but really, we + # should handle those tensors the same way we handle the subclass tensor itself; i.e. + # if we'd include the subclass tensor as part of the outputs, then we should also + # include the component tensors. + # + # To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs + # from the outer subclass tensors and propagating + + updated_input_info = [] + inner_idx = 0 + if not fw_metadata.subclass_inp_meta: + # Sometimes we don't have subclass info, e.g. synthetic_base codepaths + return inner_metadata.mutated_inp_runtime_indices + assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info) + for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta): + if isinstance(inp_meta, PlainTensorMeta): + assert outer_idx < len(fw_metadata.input_info) + if inner_metadata is not None: + assert inner_idx < len(inner_metadata.input_info) + assert ( + inner_metadata.input_info[inner_idx] + == fw_metadata.input_info[outer_idx] + ) + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + else: + assert inp_meta.original_subclass is not None + for _ in range(inp_meta.arg_count): + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + if inner_metadata is not None: + assert len(inner_metadata.input_info) == len(updated_input_info) + + return [ + i + for i, inp in enumerate(updated_input_info) + if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1255a6de8bf6e8f2d695c12c464be9c58aa171f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py @@ -0,0 +1,771 @@ +# mypy: allow-untyped-defs +""" +Contains various utils for AOTAutograd, including those for handling collections. +""" + +import copy +import dataclasses +import logging +import operator +import warnings +from collections.abc import Callable +from contextlib import nullcontext +from functools import wraps +from typing import Any, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch.utils._pytree as pytree +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import getArtifactLogger +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import py_sym_types + +from .descriptors import AOTOutput + + +KNOWN_TYPES = [ + torch.Tensor, + BackwardState, + int, + str, + float, + bool, + type(None), + *py_sym_types, + FakeScriptObject, + torch.ScriptObject, +] + +original_zip = zip + +aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects") +annotation_log = getArtifactLogger(__name__, "annotation") + + +def strict_zip(*iterables, strict=True, **kwargs): + if not strict: + return original_zip(*iterables, **kwargs) + + length = len(iterables[0]) + for iterable in iterables[1:]: + if len(iterable) != length: + raise ValueError( + "The iterables have different lengths and strict mode is enabled." + ) + + return original_zip(*iterables, **kwargs) + + +def _get_symint_hints(exprs): + """ + Get the hints of a list/tuple of int/SymInt. + """ + if isinstance(exprs, (list, tuple)): + return type(exprs)(_get_symint_hints(e) for e in exprs) + elif isinstance(exprs, torch.SymInt): + return exprs.node.shape_env.size_hint(exprs.node.expr) + else: + return exprs + + +def partial_flatten_asdict(obj: Any) -> Any: + if dataclasses.is_dataclass(obj): + return { + field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) + } + elif isinstance(obj, (list, tuple)): + return obj.__class__([partial_flatten_asdict(item) for item in obj]) + elif isinstance(obj, dict): + return {k: partial_flatten_asdict(v) for k, v in obj.items()} + else: + return obj + + +def normalize_as_list(x): + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +def _get_autocast_states(): + return [ + torch.is_autocast_enabled("cuda"), + torch.is_autocast_enabled("cpu"), + torch.get_autocast_dtype("cuda"), + torch.get_autocast_dtype("cpu"), + torch.is_autocast_cache_enabled(), + ] + + +def make_boxed_func(f): + @simple_wraps(f) + def g(args): + return f(*args) + + g._boxed_call = True # type: ignore[attr-defined] + return g + + +def make_boxed_compiler(compiler): + @wraps(compiler) + def f(fx_g, inps): + out_f = compiler(fx_g, inps) + fx_g = make_boxed_func(out_f) + return fx_g + + return f + + +def call_func_at_runtime_with_args( + f, args: Union[tuple[Any], list[Any]], steal_args=False, disable_amp=False +): + if not steal_args: + args = list(args) + assert isinstance(args, list) + + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(): + if getattr(f, "_boxed_call", False): + out = normalize_as_list(f(args)) + else: + # TODO: Please remove soon + # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 + warnings.warn( + "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. " + "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. " + "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.", + stacklevel=2, + ) + out = normalize_as_list(f(*args)) + return out + + +# Inspired by autodidax (thanks!) +class PytreeThunk: + spec: Optional[pytree.TreeSpec] = None + # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. + is_simple: Optional[bool] = ( + None # if the output spec is a tuple/list, we won't bother unflattening it. + ) + is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec + + def set(self, spec: pytree.TreeSpec) -> None: + assert self.spec is None or self.spec == spec + assert spec is not None + self.spec: pytree.TreeSpec = spec + if self.spec.type in {tuple, list} and all( + child.is_leaf() for child in spec.children() + ): + self.is_simple = True + if self.spec.is_leaf(): + self.is_really_simple = True + + def unflatten(self, x: list[Any]) -> Any: + if self.is_really_simple: + return x[0] + if self.is_simple: + return x + assert self.spec is not None + return pytree.tree_unflatten(x, self.spec) + + +# Creates a function that returns flattened inputs and outputs +# Also returns the output tree spec, which is needed to recover the "unflattened" +# output tree structure later. +def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThunk]: + if kwargs is None: + kwargs = {} + # Save the args_spec for flat_tensor_args to unflatten while tracing + _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + out_spec = PytreeThunk() + + def flat_fn(*flat_args): + # The input are flattened tensor args. Prepare the args in the + # order that original function expects. Add static args as well. + # They will appear as tensor constants in the traced graph. + nonlocal out_spec + args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec) + tree_out = fn(*args, **kwargs) + flat_out, spec = pytree.tree_flatten(tree_out) + for i in flat_out: + is_known_type = False + for j in KNOWN_TYPES: + if isinstance(i, j): + is_known_type = True + break + if not is_known_type: + raise RuntimeError( + f"Found {type(i)} in output, which is not a known type. " + "If this type holds tensors, you need to register a pytree for it. " + "See https://github.com/pytorch/functorch/issues/475 for a brief " + "explanation why. If you don't need to register a pytree, please " + "leave a comment explaining your use case and we'll make this more " + "ergonomic to deal with" + ) + out_spec.set(spec) + return flat_out + + # Can't use functools.wraps here because the wrapper has different + # calling convention + if hasattr(fn, "_orig_mod"): + flat_fn._orig_mod = fn._orig_mod # type: ignore[attr-defined] + + return flat_fn, out_spec + + +# This function takes in a tensor t, and returns one of t, t.view(), or t.clone(). +# When tracing the joint forward + backward, for any inputs in the graph that are mutated, +# we need to clone them first (and similarly for metadata-only mutations, we need to view them first). +# The idea is that when we trace the backward, we need to pass in the *original* primals +# to autograd.grad(), before they were mutated. +# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them. +# This means that "idx" here represents the index of the (potentially) synthetic base. +# What we need to do is: +# (1) map the current (post-synthetic-base calling convention) input argument index +# to int index pre-synthetic-base-calling-convention. +# (2) There could be multiple, if this index corresponds to a synthetic base +# that has multiple input aliases. +# (3) If any of those corresponding inputs get metadata mutations, then we clone the base. +def maybe_to_fresh_input(idx, t, meta): + if not isinstance(t, torch.Tensor): + return t + if idx in meta.mutated_inp_runtime_indices: + # We only need to bother cloning mutated inputs that participate in autograd. + if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the mutation + return t.clone() + if meta.input_info[idx] and meta.input_info[idx].mutates_metadata: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the metadata mutation + return t.view(t.shape) + return t + + +def is_with_effects(node): + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.with_effects + ): + return True + elif ( + node.op == "call_function" + and node.target is torch.ops.higher_order.invoke_subgraph + ): + # Check if subgraph has effects by looking in the cache + from torch._guards import InvokeSubgraphCache, TracingContext + + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + effects = invoke_subgraph_cache.get_effects(node.args[1]) + return effects is not None + return False + + +def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None): + # Remove the tokens from the inputs/outputs of the graph since inductor does + # not want these extra inputs/outputs, and replace them with + # _make_token() to create a token, and _sink_tokens() to collect the + # tokens. See Note [Side-Effectful Tokens in AOTAutograd] + # Logic: + # 1. In the case of with_effects: + # Before: + # ``` + # def forward(self, token, arg1_1): + # with_effects = torch.ops.higher_order.with_effects(token, ...) + # getitem = with_effects[0] + # getitem_1 = with_effects[0] + # return (getitem, getitem_1) + # ``` + # + # After: + # ``` + # def forward(self, arg1_1): + # _make_token_default = torch.ops.prims._make_token.default() + # with_effects = torch.ops.higher_order.with_effects(_make_token_default, ...) + # getitem = with_effects[0] + # getitem_1 = with_effects[0] + # _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]); + # return (getitem_1,) + # ``` + # + # 2. In the case of an invoke_subgraph node, we will use the + # InvokeSubgraphCache to determine if the subgraph has effects. Then we will + # turn it into a `with_effects` node. This is so that at the toplevel graph, + # the nodes will have the correct with_effects threading. We will apply this + # pass recursively to submodules so the tokens will be removed from the + # subgraph's inputs. + # + # Before: + # ``` + # def forward(self, token, arg1_1): + # repeated_subgraph0 = self.repeated_subgraph0 + # invoke_subgraph = torch.ops.higher_order.invoke_subgraph( + # repeated_subgraph0, 'subgraph_0', token, x, arg1_1) + # getitem = invoke_subgraph[0] + # getitem_1 = invoke_subgraph[1] + # return (getitem, getitem1) + # ``` + # + # After: + # ``` + # def forward(self, arg1_1): + # _make_token_default = torch.ops.prims._make_token.default() + # repeated_subgraph0 = self.repeated_subgraph0 + # with_effects_1 = torch.ops.higher_order.with_effects( + # _make_token_default, torch.ops.higher_order.invoke_subgraph, + # repeated_subgraph0, 'subgraph_0', arg1_1) + # getitem = with_effects_1[0] + # getitem_1 = with_effects_1[1]; with_effects_1 = None + # _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]) + # return (getitem_1,) + # ``` + # + # 3. The toplevel module should have the following invariants: + # forward: + # expected_num_erased_inputs == len(fw_metadata.tokens) + # expected_num_erased_outputs == len(fw_metadata.tokens) + # backward: + # expected_num_erased_inputs == fw_metadata.num_backward_tokens + # expected_num_erased_outputs == fw_metadata.num_backward_tokens + num_forward_tokens = len(fw_metadata.tokens) + num_backward_tokens = fw_metadata.num_backward_tokens + + def replace_input_token_with_make_token(module, node): + with module.graph.inserting_before(node): + new_token_node = module.graph.call_function( + torch.ops.prims._make_token.default, () + ) + new_token_node.meta["val"] = torch.tensor([]) + new_token_node.meta["tensor_meta"] = torch.tensor([]) + node.replace_all_uses_with(new_token_node) + module.graph.erase_node(node) + + def get_output_tokens(node: torch.fx.Node) -> set[torch.fx.Node]: + output_tokens = set() + for user in list(node.users.keys()): + # Check if this is a getitem accessing index 0 (the token) + if ( + user.op == "call_function" + and user.target is operator.getitem + and len(user.args) > 1 + and user.args[1] == 0 + ): + # Check if this getitem is used in an output + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.add(user) + return output_tokens + + def _unlift_tokens_from_module_helper( + module: torch.fx.GraphModule, + subgraph_str: str, + expected_num_erased: Optional[int], + ): + input_token_nodes = set() + output_token_nodes = set() + + for node in module.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.with_effects + ): + if node.args[0].op == "placeholder": + input_token_nodes.add(node.args[0]) + replace_input_token_with_make_token(module, node.args[0]) + + tokens_from_with_effects = get_output_tokens(node) + output_token_nodes = output_token_nodes | tokens_from_with_effects + + elif ( + node.op == "call_function" + and node.target is torch.ops.higher_order.invoke_subgraph + ): + subgraph_node, identifier, *operands = node.args + + # Check if subgraph has effects by looking in the cache + from torch._guards import InvokeSubgraphCache, TracingContext + + effects = None + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = ( + tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + effects = invoke_subgraph_cache.get_effects(identifier) + + if effects is not None: + # Wrap invoke_subgraph with with_effects + # Before: invoke_subgraph(subgraph, id, token, *args) -> (token_out, result) + # After: with_effects(token, invoke_subgraph, subgraph, id, *args) -> (token_out, result) + # + # Note: The subgraph itself will be unlifted separately when we iterate + # through named_modules() below. + + num_tokens = len(effects) + assert num_tokens == 1, "Multiple token subgraph NYI" + token_args = operands[:num_tokens] + non_token_args = operands[num_tokens:] + + # Create with_effects wrapper around invoke_subgraph + # with_effects(token, op, *args) where op is invoke_subgraph + # Pass the subgraph and non-token args to invoke_subgraph + with module.graph.inserting_before(node): + new_node = module.graph.call_function( + torch.ops.higher_order.with_effects, + ( + token_args[0], # pyrefly: ignore[bad-argument-type] + torch.ops.higher_order.invoke_subgraph, + subgraph_node, + identifier, + *tuple(non_token_args), + ), + ) + node.replace_all_uses_with(new_node) + new_node.meta = node.meta + module.graph.erase_node(node) + + for token in token_args: + if token.op == "placeholder": + input_token_nodes.add(token) + replace_input_token_with_make_token(module, token) + + # Get output tokens from the new with_effects node + tokens_from_invoke_subgraph = get_output_tokens(new_node) + output_token_nodes = ( + output_token_nodes | tokens_from_invoke_subgraph + ) + + output_node = next(reversed(module.graph.find_nodes(op="output"))) + assert output_node is not None + with module.graph.inserting_before(output_node): + module.graph.call_function( + torch.ops.prims._sink_tokens.default, + (list(output_token_nodes),), + ) + new_out_args = tuple( + [out for out in output_node.args[0] if out not in output_token_nodes] + ) + output_node.args = (new_out_args,) + + if expected_num_erased: + assert len(input_token_nodes) == expected_num_erased, ( + f"{subgraph_str} num_erased_inputs:{len(input_token_nodes)} " + f"{input_token_nodes} != expected {expected_num_erased} \n" + f"{fw_module.print_readable(print_output=False)}" + ) + assert len(output_token_nodes) == expected_num_erased, ( + f"{subgraph_str} num_erased_outs:{len(output_token_nodes)} " + f"{output_token_nodes} != expected {expected_num_erased} \n" + f"{fw_module.print_readable(print_output=False)}" + ) + + module.recompile() + + def unlift_tokens_from_module(module, subgraph_str, expected_num_erased): + for name, m in module.named_modules(): + if isinstance(m, torch.fx.GraphModule): + if name == "": + _unlift_tokens_from_module_helper( + m, subgraph_str, expected_num_erased + ) + else: + # Subgraph -- we may or may not have effects applied + _unlift_tokens_from_module_helper(m, f"{subgraph_str}_{name}", None) + + if num_forward_tokens > 0: + if aot_config.enable_log: + from torch._dynamo.utils import lazy_format_graph_code + + aot_graphs_effects_log.debug( + "%s", + lazy_format_graph_code( + "Forward graph before unlifting tokens", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + unlift_tokens_from_module( + fw_module, + "forward", + num_forward_tokens, + ) + + if bw_module is not None and num_backward_tokens > 0: + if aot_config.enable_log: + from torch._dynamo.utils import lazy_format_graph_code + + aot_graphs_effects_log.debug( + "%s", + lazy_format_graph_code( + "Backward graph before unlifting tokens", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + unlift_tokens_from_module(bw_module, "backward", num_backward_tokens) + + # This is sad, but we need to update the metadata to get rid of + # the tokens. + fw_metadata.tokens = {} + fw_metadata.num_backward_tokens = 0 + + +def root_module_when_exporting_non_strict(flat_fn): + # When exporting in non-strict mode, we wrap the root module in a specific pattern. + # See `_aot_export_non_strict` in torch.export._trace.py. + # We look for that wrapping pattern here. + if hasattr(flat_fn, "_orig_mod") and hasattr(flat_fn._orig_mod, "_export_root"): + return flat_fn._orig_mod._export_root + else: + return None + + +def _is_forward_node_with_seq_nr(node: torch.fx.Node) -> bool: + # For now, assume that if nn_module_stack_metadata is populated, this + # node is from the forward. Ignore nodes without `seq_nr`. + # TODO(future): there is likely a less brittle way to do this by walking + # the descendants of graph inputs corresponding to fwd inputs, didn't + # seem obvious at first glance on how to partition graph inputs into + # fwd vs bwd without relying on string names. + return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta + + +def _is_backward_node_with_seq_nr(node: torch.fx.Node) -> bool: + # For now, assume that if nn_module_stack_metadata is not populated, + # this node is from the backward. Ignore nodes without `seq_nr`. + # TODO(future): there is likely a less brittle way to do this, same + # as with the forward. + return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta + + +def _collect_fwd_nodes_from_subgraph( + fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node] +) -> None: + """Collect forward nodes from a single subgraph into the global mapping.""" + for node in fx_g.graph.nodes: + if not _is_forward_node_with_seq_nr(node): + continue + seq_nr = node.meta["seq_nr"] + if seq_nr in fwd_seq_nr_to_node: + # If we already saw an op with the current `seq_nr`, that means + # that the current op did not create an autograd node, and there + # is no corresponding backward node, so we skip. + continue + fwd_seq_nr_to_node[seq_nr] = node + + +def _copy_metadata_to_bw_nodes_in_subgraph( + fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node] +) -> None: + """Copy metadata from forward nodes to backward nodes in a single subgraph.""" + for node in fx_g.graph.nodes: + annotation_log.debug("node: %s", node.name) + seq_nr = node.meta.get("seq_nr") + annotation_log.debug("seq_nr: %s", seq_nr) + + if not _is_backward_node_with_seq_nr(node): + continue + + # We exclude gradient accumulation nodes from copying tags + if node.meta.get("is_gradient_acc", False): + annotation_log.debug("is_gradient_acc") + continue + + # fwd_node should always exist, but handle non-existence just in case + fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"]) + if fwd_node is not None: + node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack") + node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") + # TODO: better to change to a specific field of custom? + custom = fwd_node.meta.get("custom") + if custom is not None: + node.meta["custom"] = copy.deepcopy(custom) + + +def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None: + """ + Input: `fx_g` which contains the joint fwd+bwd FX graph created by + aot_autograd. + + This function walks the graph and copies over metadata from forward nodes + to backward nodes, using the `seq_nr` field as a one-to-many mapping + from forward node to backward node. This metadata is useful for performance + profiling and debugging. + + This function supports matching forward and backward nodes across different + subgraphs (e.g., in recursive submodules from HOPs), enabling backward nodes + in any submodule to match forward nodes in any submodule. + """ + + # Build a global mapping of seq_nr to forward nodes across all subgraphs + fwd_seq_nr_to_node: dict[str, torch.fx.Node] = {} + + # First pass: collect all forward nodes from all subgraphs + for submod in fx_g.modules(): + if isinstance(submod, torch.fx.GraphModule): + _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node) + + if annotation_log.isEnabledFor(logging.DEBUG): + for k, v in fwd_seq_nr_to_node.items(): + annotation_log.debug("forward:: key: %s, value: %s", k, v) + + # Second pass: copy metadata to backward nodes in all subgraphs + # using the global forward mapping + for submod in fx_g.modules(): + if isinstance(submod, torch.fx.GraphModule): + _copy_metadata_to_bw_nodes_in_subgraph(submod, fwd_seq_nr_to_node) + + +def register_buffer_assignment_hook(mod, assigned_buffers): + """ + Register a hook that intercepts buffer assignments. + This is used to detect when a buffer is assigned to, and then we can + map that buffer to the corresponding proxy node in the graph. + """ + + def _map_assigned_buffer_to_proxy(_mod, name, buffer): + # We intercept buffer assignments on the root module through this hook. + if _mod._buffers is mod._buffers: + # either buffer is a functional tensor, which wraps a fake tensor + if isinstance(buffer, FunctionalTensor): + buffer = buffer.from_functional() + # or buffer is a fake tensor + assert isinstance(buffer, FakeTensor) + # The fake tensor in turn is associated with a proxy node. + proxy_mode = torch.fx.experimental.proxy_tensor.get_proxy_mode() + assert proxy_mode is not None + proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot( + buffer, proxy_mode.tracer + ).proxy.node + # We map the assigned buffer to this proxy node. + assigned_buffers[name] = proxy.name + return buffer + + return torch.nn.modules.module.register_module_buffer_registration_hook( + _map_assigned_buffer_to_proxy + ) + + +def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool: + """ + Checks if the module contains any metadata mutation ops. + """ + for node in module.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "tags") + and torch.Tag.inplace_view in node.target.tags + ): + return True + return False + + +def get_cuda_generator_meta_val(device_idx: int): + """ + Get a generator value to use as a meta val + + newly cloned generator will not contain tensors. it is only Generators that are + registered to a CUDAGraph that contain tensors. since this does not contain Tensor + it is fine to use in the meta. + """ + return torch.cuda.default_generators[device_idx].clone_state() + + +def top_saved_tensors_hooks(): + return torch._C._autograd._top_saved_tensors_default_hooks(True) + + +def saved_tensors_hooks_are_inlineable(hooks) -> bool: + if not hooks: + return False + pack, unpack = hooks + return isinstance(pack, torch.fx.GraphModule) and isinstance( + unpack, torch.fx.GraphModule + ) + + +_P = ParamSpec("_P") +_T = TypeVar("_T") +_S = TypeVar("_S") + + +def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]: + @wraps(f) + @simple_wraps(f) + def inner(*args, **kwargs): + # pyrefly: ignore [invalid-param-spec] + return f(*args, **kwargs)[0] + + # pyrefly: ignore [bad-return] + return inner + + +_P2 = ParamSpec("_P2") +_R = TypeVar("_R") +_R2 = TypeVar("_R2") + + +def simple_wraps( + f: Callable[_P, _R], +) -> Callable[[Callable[_P2, _R2]], Callable[_P2, _R2]]: + # NB: omit ('__module__', '__name__', '__qualname__') for ease of + # debugging + return wraps(f, assigned=("__doc__", "__annotations__", "__type_params__")) + + +def call_and_expect_output_descs(fn, args): + outs_pair = fn(*args) + assert isinstance(outs_pair, tuple) and len(outs_pair) == 2, (fn, outs_pair) + outs, outs_descs = outs_pair + # The Tensor tests protects against the test when there are no outputs + out_vals, out_spec = pytree.tree_flatten(outs) + out_desc_vals, out_desc_spec = pytree.tree_flatten(outs_descs) + assert out_spec == out_desc_spec, ( + fn_wrappers(fn), + outs, + outs_descs, + out_spec, + out_desc_spec, + ) + assert not any(isinstance(x, AOTOutput) for x in out_vals), ( + fn_wrappers(fn), + outs, + outs_descs, + out_vals, + ) + assert all( + isinstance(d, AOTOutput) + for (x, d) in zip(out_vals, out_desc_vals) + if isinstance(x, (torch.Tensor, torch.SymInt)) or type(x) is int + ), (fn_wrappers(fn), outs, outs_descs, out_vals, out_desc_vals) + return outs_pair + + +def fn_wrappers(fn): + fns = [fn] + f = fn + while hasattr(f, "__wrapped__"): + f = f.__wrapped__ + fns.append(f) + return fns diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..174a725775ef213c21d279f7fbd2c95281a0414f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d650bd755d3d94c03e2e14588e2f05d37e4fdce Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25c54a98401a6f405438cee62bf5ea32a37185a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..384120b1cbe7f651b159b28e0984fcb28d8b025d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d736dd0d9d5e23be90c6917b580b5fa03fb80f14 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59b4cf886ca3eb65755b8fa5006f5be585b56ab7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa792305c94a32d7781b9427a99205890b7e521b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd8eb8b5d18a87ed3edc4c0c87f431bd1afd5950 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..253d97a357a4f64c29dd1b7d714b9aae9ab0ce6d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34b59aaf3a0cd60c9886335ac09754e2bf00cfda Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c825cf87875d2e8394baff05e0c60af1e17e764 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b068c841a95f3cefc8789112e76244944e41e39 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d173ea1091d63145b560d2d159bb71d295ca7d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fa61eec9688b36d6798535195d6be99f4612221 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/local_map.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/local_map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1490730f132d803e5dbdb0dc58c98dbbe005a6f0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/local_map.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d26c908834fe35f2c44fa5cfda3effdcb0aa1e1a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78e32821f78f47efea1f5faeb906224030700836 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/partitioner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/partitioner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b42f76b2f697f5b30bee62973a8669747176718e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/partitioner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/print.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/print.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f9bf49057731a73a54d78913d7aab3ade317b6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/print.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2882a80eb11e35b0fff2f704f1c2b72b29a5856 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9280fc2b7ebc818ab7bf25f7f4252ca15a986ca1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a3f94283c4a6a6d85a5ab75f63cb52dacfbbe21 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03542f10143dd2ff23d382c7160968918e07e3af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef74035f6a7bf83aa46b9742333208ae83902ac0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c789d1d11362dbfbe5237f299837a1d3dc1b40a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f1ad5054514d2414b225c4cc8acb08a6348b782 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20dc0c2142685bf4ceccb80cd7253fcdbe77f2b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce989db9381c76d5dbeb6f65a7883c23efab3d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d624fe65430dab3b2d8287f7bd1582ca98d4548 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/async_compile.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/async_compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e089cfbe222fa918bf2c303a8360d29f062be00 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/async_compile.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/augmented_graph_helper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/augmented_graph_helper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdfed38891999f18f4b56ce1be824c61dc17e829 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/augmented_graph_helper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/choices.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/choices.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..109cb3c520c442583ef7ba549554006d04737096 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/choices.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5908b39e433b610019dc5c9ada19c089a4617364 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19d68e32a6b68b25a95bb1d01b6037840927c050 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afc4df7d285fc41a1ec468def4521ee23a82bdcb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f0b6593e65361e934b9d5a8fcb0280b5c4316b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd78ffb060946fd8cc051e3e4f65242c58fb64be Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb56810eeeeccbece7a7af3331efac686fe56519 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f53e4887476ab97f25afa8bcff0dabc4258c9d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fde3739d5b1092a16fea3ed56c897f2f6f31fa8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef8cc54ac5a4d9a85379f7a4b31527f65592f76f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48addd8c6c60c00df3d9c8b4882c1103fd69c6c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01892024e99a47c9f091e350fa01768d45ae3c7b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_gemm_autotune_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_gemm_autotune_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d13aef26cacf4fb4aad704bb246e568f3db672ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/remote_gemm_autotune_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/sizevars.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/sizevars.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abb657a1a16855a69e65450b07f47696a86cdb98 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/sizevars.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67fa727aee9624c9a9290183088634e065db9135 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09d8ef3eb2c038c264f93296e185329e976a1aef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11f10ba2da9cad4cd84739167263b0c0689d4ba0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/device_info.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/device_info.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5edf1e7fd26d3f902d15af82a3d0c615d20c6f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/device_info.py @@ -0,0 +1,216 @@ +import logging +from dataclasses import dataclass +from typing import Optional, Union + +import torch + + +log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class DeviceInfo: + """ + Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not, + then the higher number is reported. Sparsity is not considered. + + + Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace. + For example, + """ + + tops: dict[Union[torch.dtype, str], float] + dram_bw_gbs: float + dram_gb: float + + +# Indexing is based on `torch.cuda.get_device_name()` +# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on. +_device_mapping: dict[str, DeviceInfo] = { + # Source: + # @lint-ignore https://www.nvidia.com/en-us/data-center/h100/ + "NVIDIA H100": DeviceInfo( + tops={ + torch.float64: 67.0, + torch.float32: 67.5, + "torch.tf32": 156.0, + torch.bfloat16: 1979.0, + torch.float16: 1979.0, + torch.float8_e8m0fnu: 3958.0, + torch.float8_e8m0fnu: 3958.0, + torch.float8_e4m3fnuz: 3958.0, + torch.float8_e5m2: 3958.0, + torch.float8_e5m2fnuz: 3958.0, + torch.float8_e8m0fnu: 3958.0, + torch.int8: 3958.0, + }, + dram_bw_gbs=3350, + dram_gb=80, + ), + # Source: + # @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/ + # nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + "NVIDIA A100": DeviceInfo( + tops={ + torch.float64: 19.5, + torch.float32: 19.5, + torch.bfloat16: 312.5, + torch.float16: 312.5, + # Not in datasheet: float8 + torch.int8: 624.0, + "torch.tf32": 156.0, + }, + dram_bw_gbs=2039.0, + dram_gb=80.0, + ), + # Source: + # @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet + "NVIDIA L4": DeviceInfo( + tops={ + # This is a guess, not in datasheet + torch.float64: 15.1, + torch.float32: 30.3, + "torch.tf32": 120.0, + torch.bfloat16: 242.0, + torch.float16: 242.0, + torch.float8_e8m0fnu: 485.0, + torch.float8_e8m0fnu: 485.0, + torch.float8_e4m3fnuz: 485.0, + torch.float8_e5m2: 485.0, + torch.float8_e5m2fnuz: 485.0, + torch.float8_e8m0fnu: 485.0, + torch.int8: 485.0, + }, + dram_bw_gbs=3350, + dram_gb=24, + ), + # Source: + # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\ + # /instinct-tech-docs/product-briefs/amd-instinct-mi350x-gpu-brochure.pdf + "AMD MI350X": DeviceInfo( + tops={ + torch.float64: 72.1, + torch.float32: 144.2, + # not specified, fall back to float32 numbers + "torch.tf32": 144.2, + torch.bfloat16: 2309.6, + torch.float16: 2309.6, + torch.float8_e8m0fnu: 4614.0, + torch.float8_e8m0fnu: 4614.0, + torch.float8_e4m3fnuz: 4614.0, + torch.float8_e5m2: 4614.0, + torch.float8_e5m2fnuz: 4614.0, + torch.float8_e8m0fnu: 4614.0, + torch.int8: 4614.0, + }, + dram_bw_gbs=8000.0, + dram_gb=288.0, + ), + # Source: + # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\ + # /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf + "AMD MI300A": DeviceInfo( + tops={ + torch.float64: 122.6, + torch.float32: 122.6, + "torch.tf32": 490.3, + torch.bfloat16: 980.6, + torch.float16: 980.6, + torch.float8_e8m0fnu: 1961.2, + torch.float8_e8m0fnu: 1961.2, + torch.float8_e4m3fnuz: 1961.2, + torch.float8_e5m2: 1961.2, + torch.float8_e5m2fnuz: 1961.2, + torch.float8_e8m0fnu: 1961.2, + torch.int8: 1961.2, + }, + dram_bw_gbs=5300.0, + dram_gb=128.0, + ), + # Source: + # @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\ + # instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf + "AMD MI300X": DeviceInfo( + tops={ + torch.float64: 163.4, + torch.float32: 163.4, + "torch.tf32": 653.7, + torch.bfloat16: 1307.4, + torch.float16: 1307.4, + torch.float8_e8m0fnu: 2614.9, + torch.float8_e8m0fnu: 2614.9, + torch.float8_e4m3fnuz: 2614.9, + torch.float8_e5m2: 2614.9, + torch.float8_e5m2fnuz: 2614.9, + torch.float8_e8m0fnu: 2614.9, + torch.int8: 2614.9, + }, + dram_bw_gbs=5300.0, + dram_gb=192.0, + ), + # Source: + # @lint-ignore https://www.amd.com/content/dam/amd/\ + # en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf + "AMD MI210X": DeviceInfo( + tops={ + torch.float64: 45.3, + torch.float32: 45.3, + # not specified, fall back to float32 numbers + "torch.tf32": 45.3, + torch.bfloat16: 181.0, + torch.float16: 181.0, + # not specified, fall back to float16 numbers + torch.float8_e8m0fnu: 181.0, + torch.float8_e8m0fnu: 181.0, + torch.float8_e4m3fnuz: 181.0, + torch.float8_e5m2: 181.0, + torch.float8_e5m2fnuz: 181.0, + torch.float8_e8m0fnu: 181.0, + torch.int8: 181.0, + }, + # pcie4.0x16 + dram_bw_gbs=1600.0, + dram_gb=64.0, + ), +} +_device_mapping["AMD INSTINCT MI350X"] = _device_mapping["AMD MI350X"] +_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"] +_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"] + + +def lookup_device_info(name: str) -> Optional[DeviceInfo]: + """ + Problem: when diffing profiles between amd and nvidia, we don't have access to the device information + of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated + to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices. + If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping. + name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name(). + """ + return _device_mapping.get(name) + + +def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]: + """ + Get the theoretical TFLOPS of the device for a given dtype. This can throw an exception if the device + is not in the datasheet list above. + """ + name: Optional[str] = torch.cuda.get_device_name() + if name is None: + log.info("No device found, returning None") + return None + device_info = lookup_device_info(name) + if device_info is None: + log_str = f"Device {name} not in datasheet, returning None" + log.info(log_str) + return None + if dtype not in device_info.tops: + log.info( + "Device %s does not have a datasheet entry for %s, returning None", + name, + dtype, + ) + return None + + return device_info.tops[ + "torch.tf32" if dtype == torch.float32 and is_tf32 else dtype + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/profile_analysis.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/profile_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6ec39003bdb2447b72c9aed892e1db01474be0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/profile_analysis.py @@ -0,0 +1,823 @@ +import json +import logging +import math +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info +from torch._inductor.utils import tabulate_2d, zip_dicts +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils.flop_counter import flop_registry + + +log = logging.getLogger(__name__) + + +ATEN_PREFIX = "aten::" + + +@dataclass +class ProfileEvent: + category: str + key: str + self_device_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +# adapters convert the json trace into a format that works with flops_counter +ArgsType = tuple[tuple[Any, ...], dict[Any, Any]] +AdapterType = Callable[[tuple[Any, ...], tuple[Any, ...]], ArgsType] +adapters_map: dict[str, AdapterType] = {} + + +def parse_list(lst: str) -> list[int]: + lst = lst.replace("[", "").replace("]", "") + substrings = lst.split(",") + + return [int(substring.strip()) for substring in substrings] + + +def register_adapter( + aten: Union[str, list[str]], +) -> Callable[ + [AdapterType], + AdapterType, +]: + def decorator(func: AdapterType) -> AdapterType: + # pyrefly: ignore [unknown-name] + global _adapters_map + + if isinstance(aten, str): + adapters_map[aten] = func + else: + for at in aten: + adapters_map[at] = func + return func + + return decorator + + +@register_adapter(["_slow_conv2d_forward"]) +def _slow_conv2d_adapter( + shapes: tuple[Any, ...], concrete: tuple[Any, ...] +) -> tuple[tuple[Any], dict[Any, Any]]: + tmp = list(shapes) + tmp.append(False) + tmp2 = list(concrete) + if len(tmp2) < 5: + raise ParseException("slow conv2d has less than 5 concrete inputs") + tmp2[3] = tmp2[4] + return conv_adapter(tuple(tmp), tuple(tmp2)) + + +@register_adapter( + ["convolution", "_convolution", "cudnn_convolution", "convolution_overrideable"] +) +def conv_adapter( + shapes: tuple[Any, ...], concrete: tuple[Any, ...] +) -> tuple[tuple[Any], dict[Any, Any]]: + tmp = list(shapes) + if len(tmp) == 4: + transposed = False + elif len(tmp) > 6: + transposed = bool(tmp[6]) + tmp[6] = transposed + else: + raise ParseException(f"Convolution has the wrong number of inputs: {len(tmp)}") + + kwargs: dict[Any, Any] = {} + if not transposed: + # calculate output shape if not transposed. + def conv_out_dims(x: int, kernel: int, stride: int) -> int: + return (x - kernel) // stride + 1 + + stride = parse_list(concrete[3]) + inp = shapes[0] + w = shapes[1] + out_x_y = [conv_out_dims(*args) for args in zip(inp[2:], w[2:], stride)] + out = [inp[0], w[0]] + out_x_y # we only need the xy values + kwargs["out_val"] = out + + return tuple(tmp), kwargs + + +def default_adapter( + shapes: tuple[Any], concrete: tuple[Any] +) -> tuple[tuple[Any], dict[Any, Any]]: + return shapes, {} + + +@register_adapter("addmm") +def addmm_adapter( + shapes: tuple[Any], concrete: tuple[Any] +) -> tuple[tuple[Any], dict[Any, Any]]: + tmp = list(shapes)[:3] + return tuple(tmp), {} + + +@register_adapter("bmm") +def bmm_adapter( + shapes: tuple[Any], concrete: tuple[Any] +) -> tuple[tuple[Any], dict[Any, Any]]: + tmp = list(shapes) + return tuple(tmp[:2]), {} + + +@register_adapter("baddbmm") +def baddbmm_adapter( + shapes: tuple[Any], concrete: tuple[Any] +) -> tuple[tuple[Any], dict[Any, Any]]: + tmp = list(shapes)[:3] + return tuple(tmp), {} + + +@register_adapter("mm") +def mm_adapter( + shapes: tuple[Any], concrete: tuple[Any] +) -> tuple[tuple[Any], dict[Any, Any]]: + return shapes, {} + + +def _parse_kernel_name(name: str) -> Optional[str]: + """ + parse the name of the kernel from the event name. + """ + if name.startswith(ATEN_PREFIX): + return name[len(ATEN_PREFIX) :] + elif "conv" in name: + return "convolution" + elif "addmm" in name: + return "addmm" + elif "bmm" in name: + return "bmm" + elif "baddbmm" in name: + return "baddbmm" + elif "_mm" in name: + return "mm" + else: + return None + + +def _calculate_flops(event: dict[str, Any]) -> int: + """ + This function has to parse the kernel name, which is error prone. There doesn't seem to be another solution that + will support all the different backends that can generate kernels, so make sure to update this function when new + ops and backends are desired. + """ + name = event["name"] + if "kernel_flop" in event["args"] and event["args"]["kernel_flop"] != 0: + return event["args"]["kernel_flop"] + op_name = _parse_kernel_name(name) + if op_name is None: + return 0 + + op_obj = getattr(torch.ops.aten, op_name, None) + if op_obj is None or op_obj not in flop_registry: + return 0 + + flop_function = flop_registry[op_obj] + + if "Input Dims" not in event["args"] or "Concrete Inputs" not in event["args"]: + return 0 + input_shapes = event["args"]["Input Dims"] + concrete = event["args"]["Concrete Inputs"] + if op_name in adapters_map: + try: + args, kwargs = adapters_map[op_name](input_shapes, concrete) + except ParseException as e: + msg = f"Failed to parse {op_name} with {e}" + log.warning(msg) + return 0 + else: + try: + args, kwargs = default_adapter(input_shapes, concrete) + except ParseException as e: + msg = f"Failed to parse {op_name} with {e}" + log.warning(msg) + return 0 + return flop_function(*args, **kwargs) + + +def _get_size_from_string(type_string: str) -> int: + if not hasattr(torch, type_string): + return 1 + else: + return getattr(torch, type_string).itemsize + + +def _default_estimate_gb(event: dict[str, Any]) -> float: + sizes_and_types = zip(event["args"]["Input Dims"], event["args"]["Input type"]) + bw = 0 + for size, typ in sizes_and_types: + isize = _get_size_from_string(typ) + bw += isize * math.prod(pytree.tree_flatten(size)[0]) + return bw / 1e9 + + +def _estimate_gb(event: dict[str, Any]) -> float: + """ + Our best effort to estimate the gb, should be refactored soon with MemoryCounter. + """ + name = event["name"] + if "kernel_num_gb" in event["args"] and event["args"]["kernel_num_gb"] != 0: + return event["args"]["kernel_num_gb"] + if "Input type" not in event["args"] or "Input Dims" not in event["args"]: + return 0 + op_name = _parse_kernel_name(name) + if op_name is None: + return _default_estimate_gb(event) + + op_obj = getattr(torch.ops.aten, op_name, None) + if op_obj is None: + return _default_estimate_gb(event) + + if "Input Dims" not in event["args"] or "Concrete Inputs" not in event["args"]: + return _default_estimate_gb(event) + input_shapes = event["args"]["Input Dims"] + + # NOTE these will be refactored into a similar object to FlopCounter soon + def mm_formula(M: int, N: int, K: int, size: int) -> int: + return 2 * (M * K + N * K + M * N) * size + + if op_name == "addmm": + add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0]) + add_type_size = _get_size_from_string(event["args"]["Input type"][0]) + M = input_shapes[1][0] + N = input_shapes[1][1] + assert input_shapes[1][1] == input_shapes[2][0] + K = input_shapes[2][1] + mul_type_size = _get_size_from_string(event["args"]["Input type"][1]) + return (mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size) / 1e9 + elif op_name == "mm": + M = input_shapes[0][0] + N = input_shapes[0][1] + assert input_shapes[0][1] == input_shapes[1][0] + K = input_shapes[1][1] + type_size = _get_size_from_string(event["args"]["Input type"][0]) + return mm_formula(M, N, K, type_size) / 1e9 + elif op_name == "baddbmm": + add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0]) + add_type_size = _get_size_from_string(event["args"]["Input type"][0]) + B = input_shapes[0][0] + M = input_shapes[1][1] + N = input_shapes[1][2] + K = input_shapes[2][2] + mul_type_size = _get_size_from_string(event["args"]["Input type"][1]) + return ( + B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size + ) / 1e9 + elif op_name == "bmm": + add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0]) + add_type_size = _get_size_from_string(event["args"]["Input type"][0]) + B = input_shapes[0][0] + M = input_shapes[0][1] + N = input_shapes[0][2] + K = input_shapes[1][2] + mul_type_size = _get_size_from_string(event["args"]["Input type"][1]) + return ( + B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size + ) / 1e9 + elif op_name in [ + "convolution", + "_convolution", + "cudnn_convolution", + "_slow_conv2d_forward", + ]: + concrete = event["args"]["Concrete Inputs"] + + def conv_out_dim(x: int, kernel: int, stride: int) -> int: + return (x - kernel) // stride + 1 + + stride = parse_list( + concrete[3] if op_name != "_slow_conv2d_forward" else concrete[4] + ) + inp = input_shapes[0] + w = input_shapes[1] + out_x_y = [conv_out_dim(*args) for args in zip(inp[2:], w[2:], stride)] + out = [inp[0], w[0]] + out_x_y + # each output element reads in * w * w chunk + input_reads = out[0] * out[1] * out[2] * out[3] * inp[1] * w[2] * w[3] + # Assume weights are in cache, so only read once + weight_reads = w[0] * w[1] * w[2] * w[3] + return (input_reads + weight_reads) / 1e9 + + return _default_estimate_gb(event) + + +def _create_extern_mapping( + data: dict[str, Any], +) -> defaultdict[int, list[dict[str, Any]]]: + """ + compute a mapping from external ids to non kernels, which contain the information we need to estimate flops etc + """ + extern_mapping: defaultdict[int, list[dict[str, Any]]] = defaultdict(list) + for event in data["traceEvents"]: + if ( + "args" not in event + or "External id" not in event["args"] + or event["cat"] != "cpu_op" + ): + continue + if len(extern_mapping[event["args"]["External id"]]) > 0: + raise ParseException("duplicate external id in event") + extern_mapping[event["args"]["External id"]].append(event) + return extern_mapping + + +def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]: + extern_mapping = _create_extern_mapping(data) + + for event in data["traceEvents"]: + if "cat" not in event or event["cat"] != "kernel": + continue + if "args" not in event: + raise ParseException(f"kernel has no args: {event}") + if "External id" not in event["args"]: + event_str = f"kernel has no External id: {event}" + log.info(event_str) + continue + + external_op = extern_mapping[event["args"]["External id"]][0] + flops = _calculate_flops(external_op) + if flops == 0: + flops = _calculate_flops(event) + external_op["args"]["kernel_flop"] = flops + external_op["args"]["kernel_num_gb"] = _estimate_gb(external_op) + event["args"]["kernel_flop"] = external_op["args"]["kernel_flop"] + event["args"]["kernel_num_gb"] = external_op["args"]["kernel_num_gb"] + return data + + +_dtype_map = { + "float": torch.float, + "float32": torch.float, + "int": torch.int, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int, + "long": torch.long, + "long int": torch.long, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float64": torch.double, +} + + +@dataclass(frozen=True) +class KernelStats: + flops: int + bw: float + latency: float # us + achieved_flops: float + achieved_bandwidth: float + + +KernelNameMap = defaultdict[str, OrderedSet[KernelStats]] + + +@dataclass(frozen=False) +class Device: + name: str + index: int + info: Optional[DeviceInfo] + stats: KernelNameMap + + def __repr__(self) -> str: + return f"Device({self.name}, {self.index}): {self.info}" + + +DeviceMap = dict[int, Device] +Table = tuple[list[str], dict[str, list[str]]] + + +class JsonProfile: + _devices: DeviceMap + + def __init__( + self, + path: str, + benchmark_name: Optional[str] = None, + dtype: Optional[Union[torch.dtype, str]] = None, + ): + """ + Convenience class for running common operations on chrome/perfetto json traces. + """ + self.path = path + with open(path) as f: + self.data = json.load(f) + self.events = self.data["traceEvents"] + self.benchmark_name = benchmark_name + if dtype is None: + self.dtype = None + elif isinstance(dtype, torch.dtype): + # pyrefly: ignore [bad-assignment] + self.dtype = dtype + else: + # pyrefly: ignore [bad-assignment] + self.dtype = _dtype_map.get(dtype) + self._create_devices() + + def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]: + """ + Each op has a list of dtypes for each input arg. We need to convert these into a single dtype for flop estimation. + Issues: + - converting the strings to concrete torch.dtypes + - What if we have float32, float, float16 all in the inputs? Our choice is to use the largest buffer dtype. + """ + + if ( + "Input Dims" not in event["args"] + or "Input type" not in event["args"] + or "Concrete Inputs" not in event["args"] + ): + if "bfloat16" in event["name"]: + return torch.bfloat16 + elif "float16" in event["name"]: + return torch.float16 + else: + return None + + input_sizes = event["args"]["Input Dims"] + input_types = event["args"]["Input type"] + concrete_inputs = event["args"]["Concrete Inputs"] + assert len(input_sizes) == len(input_types) + assert len(input_types) == len(concrete_inputs) + + if len(input_sizes) == 0: + raise RuntimeError("Empty input_sizes and input_types") + + biggest_size = 0 + biggest_index = 0 + for i in range(len(input_sizes)): + if concrete_inputs[i] != "": + # concrete inputs are usually small tensors, so we can just skip + continue + my_size = input_sizes[i] + total_size = sum(parse_list(my_size)) + if total_size > biggest_size: + biggest_size = total_size + biggest_index = i + ret_type = input_types[biggest_index] + if ret_type in _dtype_map: + return _dtype_map[ret_type] + raise RuntimeError(f"Unknown type: {ret_type}. Please add to _dtype_map.") + + def _create_devices(self) -> None: + self._devices = {} + for dev in self.data["deviceProperties"]: + name = dev["name"] + device_info = lookup_device_info(name) + + if device_info is None: + log.info( + "Unsupported device in profile: %s, please consider contributing to _device_mapping.", + name, + ) + self._devices[dev["id"]] = Device( + name, dev["id"], device_info, defaultdict(OrderedSet) + ) + + def calculate_flops(self, event: dict[str, Any]) -> int: + return _calculate_flops(event) + + def estimate_gb(self, event: dict[str, Any]) -> float: + return _estimate_gb(event) + + def augment_trace(self) -> None: + self.data = _augment_trace_helper(self.data) + + def _compute_stats(self) -> None: + """populates the name -> stats map""" + for event in self.events: + if "cat" not in event or "args" not in event or event["cat"] != "kernel": + continue + if "device" not in event["args"]: + continue + dev_tmp = event["args"]["device"] + if dev_tmp not in self._devices: + continue + dev = self._devices[event["args"]["device"]] + + dur = event["dur"] # us + if "kernel_flop" in event["args"]: + assert dur != 0 + # 1,000,000us/s * flop / us + op_flops = event["args"]["kernel_flop"] / (dur / 1e6) + else: + op_flops = 0 + + if "kernel_num_gb" in event["args"]: + assert dur != 0 + # 1,000,000us/s * gb = gb/s + op_gbps = event["args"]["kernel_num_gb"] / (dur / 1e6) + else: + op_gbps = 0 + + if dev.info is not None: + dtype = self.convert_dtype(event) or self.dtype + if dtype is None: + raise RuntimeError( + "dtype is not found on tensor and default dtype is not set" + ) + achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype]) + achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs + else: + achieved_flops = 0 + achieved_bandwidth = 0 + + if "name" not in event["args"]: + continue + dev.stats[event["name"]].add( + KernelStats( + flops=op_flops, + bw=op_gbps, + latency=dur, + achieved_bandwidth=achieved_bandwidth, + achieved_flops=achieved_flops, + ) + ) + + def _create_single_table(self, dev: Device) -> Table: + """Create a table with the devices mapped to indices.""" + headers = [ + "Kernel Name", + "Kernel Count", + "FLOPS", + "Kernel Reads (GB)", + "Dur (us)", + "Achieved FLOPS %", + "Achieved Bandwidth %", + ] + rows: dict[str, list[str]] = {} + + def safe_div_format(x: float, y: float) -> str: + if y == 0: + return "0.0" + return f"{x / y:.4f}" + + for kernel_name, stats_set in dev.stats.items(): + ker_count = 0 + flops = 0 + flops_count = 0 + achieved_flops = 0.0 + bw = 0.0 + bw_count = 0 + achieved_bandwidth = 0.0 + latency = 0.0 + for stats in stats_set: + if stats.flops != 0: + flops += stats.flops + achieved_flops += stats.achieved_flops + flops_count += 1 + if stats.bw != 0: + bw += stats.bw + achieved_bandwidth += stats.achieved_bandwidth + bw_count += 1 + latency += stats.latency + ker_count += 1 + assert ker_count != 0 + rows[kernel_name] = [ + str(ker_count), + safe_div_format(flops, flops_count), + safe_div_format(bw, bw_count), + safe_div_format(latency, ker_count), + safe_div_format(achieved_flops, flops_count), + safe_div_format(achieved_bandwidth, bw_count), + ] + + return headers, rows + + def _create_tables(self, devs: DeviceMap) -> dict[int, Table]: + return {idx: self._create_single_table(dev) for idx, dev in devs.items()} + + def _combine_tables( + self, table1: Table, table1_name: str, table2: Table, table2_name: str + ) -> Table: + new_headers = ( + ["Kernel Name"] + + [f"{table1_name} {head}" for head in table1[0][1:]] + + [f"{table2_name} {head}" for head in table2[0][1:]] + ) + t1_length = len(table1[0][1:]) + t2_length = len(table2[0][1:]) + new_rows = {} + + for key, row1, row2 in zip_dicts( + table1[1], + table2[1], + d1_default=["Empty"] * t1_length, + d2_default=["Empty"] * t2_length, + ): + assert row1 is not None + assert row2 is not None + new_rows[key] = row1 + row2 + return new_headers, new_rows + + def report( + self, other: Optional["JsonProfile"] = None, name_limit: int = 40 + ) -> str: + def create_ret( + table_headers: list[str], table_rows: dict[str, list[str]] + ) -> str: + table_flattened = [ + [kernel_name[:name_limit], *kernel_vals] + for kernel_name, kernel_vals in table_rows.items() + ] + return tabulate_2d(table_flattened, headers=table_headers) + + if other is not None: + self._compute_stats() + other._compute_stats() + + self_tables = self._create_tables(self._devices) + other_tables = self._create_tables(other._devices) + + self_name = ( + self.benchmark_name if self.benchmark_name is not None else "Table 1" + ) + other_name = ( + other.benchmark_name if other.benchmark_name is not None else "Table 2" + ) + + ret = [] + assert self._devices.keys() == other._devices.keys() + for device_idx, t1, t2 in zip_dicts( + self_tables, other_tables, d1_default=None, d2_default=None + ): + assert t1 is not None + assert t2 is not None + table_headers, table_rows = self._combine_tables( + t1, self_name, t2, other_name + ) + tab_string = create_ret(table_headers, table_rows) + # pyrefly: ignore [bad-argument-type] + ret.append(f"{self._devices[device_idx]}:\n{tab_string}") + return "\n".join(ret) + self._compute_stats() + + self_tables = self._create_tables(self._devices) + + ret = [] + for idx, table in self_tables.items(): + table_headers, table_rows = table + tab_string = create_ret(table_headers, table_rows) + # pyrefly: ignore [bad-argument-type] + ret.append(f"{self._devices[idx]}:\n{tab_string}") + return "\n".join(ret) + + def dump(self, out: str) -> None: + with open(out, "w") as f: + json.dump(self.data, f) + + def combine_with(self, other: "JsonProfile") -> "JsonProfile": + """ + Combine this profile with another profile by merging their trace events. + Returns a new JsonProfile object with combined data. + """ + # Create a new combined data structure + combined_data = { + "traceEvents": self.data["traceEvents"] + other.data["traceEvents"], + "deviceProperties": self.data.get("deviceProperties", []), + } + + # Merge device properties, avoiding duplicates + other_device_props = other.data.get("deviceProperties", []) + existing_device_ids = OrderedSet( + [dev["id"] for dev in combined_data["deviceProperties"]] + ) + + for device_prop in other_device_props: + if device_prop["id"] not in existing_device_ids: + combined_data["deviceProperties"].append(device_prop) + + # Copy any other top-level properties from the first profile + for key, value in self.data.items(): + if key not in combined_data: + combined_data[key] = value + + import os + + # Create a temporary file to write the combined data + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as tmp_file: + json.dump(combined_data, tmp_file) + tmp_path = tmp_file.name + + try: + # Create new JsonProfile from the combined data + combined_profile = JsonProfile( + tmp_path, + benchmark_name=f"{self.benchmark_name or 'Profile1'}_+_{other.benchmark_name or 'Profile2'}", + dtype=self.dtype or other.dtype, + ) + return combined_profile + finally: + # Clean up temporary file + os.unlink(tmp_path) + + +class ParseException(RuntimeError): + pass + + +def main() -> None: + """ + Main function for the profile analysis script. + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--diff", + nargs=5, + metavar=( + "input_file1", + "name1", + "input_file2", + "name2", + "dtype", + ), + help="Two json traces to compare with, specified as ", + ) + parser.add_argument( + "--name_limit", + type=int, + help="the maximum name size in the final report", + ) + parser.add_argument( + "--augment_trace", + "-a", + nargs=3, + metavar=("input_file", "output_file", "dtype"), + help="Augment a trace with inductor meta information. Provide input and output file paths.", + ) + parser.add_argument( + "--analysis", + nargs=2, + metavar=("input_file", "dtype"), + help="Run analysis on a single trace, specified as ", + ) + parser.add_argument( + "--combine", + nargs="+", + metavar=("input_files", "output_file"), + help="Combine multiple profiles into a single profile by merging trace events. Specify as \ + [input_file3 ...] . The last argument is the output file, all preceding arguments are \ +input files to combine.", + ) + args = parser.parse_args() + + if args.diff: + p1 = JsonProfile(args.diff[0], args.diff[1], dtype=args.diff[4]) + p1.augment_trace() + p2 = JsonProfile(args.diff[2], args.diff[3], dtype=args.diff[4]) + p2.augment_trace() + if args.name_limit: + print(p1.report(p2, name_limit=args.name_limit)) + else: + print(p1.report(p2)) + if args.analysis: + p1 = JsonProfile( + args.analysis[0], + dtype=args.analysis[1], + ) + p1.augment_trace() + if args.name_limit: + print(p1.report(name_limit=args.name_limit)) + else: + print(p1.report()) + if args.augment_trace: + p = JsonProfile(args.augment_trace[0], dtype=args.augment_trace[2]) + p.augment_trace() + p.dump(args.augment_trace[1]) + if args.combine: + input_files = args.combine[:-1] # All arguments except the last one + output_file = args.combine[-1] # Last argument is the output file + + if len(input_files) < 2: + print("Error: At least 2 input files are required for combining") + return + + # Load the first profile + combined = JsonProfile(input_files[0], dtype=None) + + # Iteratively combine with all other profiles + for input_file in input_files[1:]: + profile = JsonProfile(input_file, dtype=None) + combined = combined.combine_with(profile) + + combined.dump(output_file) + print(f"Successfully combined {', '.join(input_files)} into {output_file}") + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..0c12ca77cf2db28bbe1fc10cca44774ced5c102f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py @@ -0,0 +1,316 @@ +import json +import os +from collections.abc import Callable +from functools import partial +from typing import Any, Optional + +import torch +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + AHOperation, + Choice, + CHOICE_COL, + Feedback, + FEEDBACK_COL, + get_metadata_str_from_log, +) +from torch._inductor.autoheuristic.learned_heuristic_controller import ( + LearnedHeuristicController, +) +from torch._inductor.ir import ChoiceCaller +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import get_gpu_shared_memory + + +class LocalFeedback: + """ + To be able to collect data for a choice, a function providing feedback given a choice has to be provided. + LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice + (see pad_mm.py, where the autotuning happens locally, for an example). + """ + + def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None: + self.feedback_fn = feedback_fn + + def __call__(self, choice: Choice) -> Feedback: + return self.feedback_fn(choice) + + +class InconsistentMetadata(Exception): + """ + Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does + not match the metadata it would store if the file didn't exist. + """ + + +class AutoHeuristic: + """ + AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and + generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train + a heuristic (see torchgen/autoheuristic/). + """ + + collected_feedback: dict[Choice, Feedback] + + def __init__( + self, + fallback: Callable[[], Choice], + choices: list[Choice], + feedback: Optional[LocalFeedback], + context: AHContext, + name: str, + augment_context: Optional[list[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + Initializes an instance of the AutoHeuristic class. + + Args: + fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or + AutoHeuristic is in data collection mode. + choices: A list of possible choices the heuristic can make. + feedback: An instance of LocalFeedback that provides feedback for a given choice. + context: Context to store with each choice and feedback. + name: A string that identifies the heuristic. + augment_context: An optional list of AHOperation instances that augment the context. + precondition: A callable that returns a boolean indicating whether AutoHeuristic should run. + """ + self.fallback = fallback + self.choices = choices + self.feedback = feedback + self.context = context + self.name = name + self.collected_feedback = {} + self.augment_context = augment_context + self.metadata = AHMetadata( + get_gpu_shared_memory(), + torch.cuda.get_device_capability(), + self.choices, + self.name, + ) + self.precondition = precondition + + if not self.satisfies_precondition(): + return + + if torch._inductor.config.autoheuristic_log_path == "DEFAULT": + self.log_path = self.get_default_log_path() + else: + self.log_path = torch._inductor.config.autoheuristic_log_path + + if torch._inductor.config.collect_autoheuristic(self.name): + if self.feedback is not None: + for choice in self.choices: + feedback_val = self.feedback(choice) + self.save_data(choice, feedback_val) + + def satisfies_precondition(self) -> bool: + return self.precondition is None or self.precondition( + self.metadata, self.context + ) + + def get_choice(self) -> Choice: + """ + Returns the chosen option based on the value of autoheuristic_use. + If self.name is one of the comma separated strings in autoheuristic_use, + it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option. + """ + + if not self.satisfies_precondition(): + return self.fallback() + + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + decision = controller.get_decision() + if decision not in self.choices: + # TODO(AlnisM): We might want to allow this in the future + return self.fallback() + if decision is not None: + return decision + return self.fallback() + + def get_top_k_choices( + self, top_k: int, always_included: Optional[list[str]] = None + ) -> Optional[list[Choice]]: + if not self.satisfies_precondition(): + return None + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + choices = controller.get_decisions_ranked(top_k) + if choices is None: + return None + if always_included is not None: + for choice in always_included: + if choice not in choices: + choices.append(choice) + return choices + return None + + def get_collected_feedback(self, choice: Choice) -> Any: + return self.collected_feedback.get(choice, None) + + @staticmethod + def get_device_identifier() -> str: + # a heuristic might work well for one GPU, but not for another + # we store the collected data per GPU model and learn a heuristic per GPU model + + # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names + device_name = torch.cuda.get_device_name().replace(" ", "_") + return device_name + + def get_default_log_path(self) -> str: + device_name = self.get_device_identifier() + path = f"{cache_dir()}/autoheuristic/{device_name}/" + os.makedirs(path, exist_ok=True) + path += f"{self.name}.txt" + return path + + def serialize_metadata(self) -> str: + metadata_dict = self.metadata.to_dict() + ( + num_features, + cat_features, + ) = self.context.get_numerical_and_categorical_features() + metadata_dict["numerical_features"] = num_features + metadata_dict["categorical_features"] = cat_features + return json.dumps(metadata_dict) + + def save_data(self, choice: Choice, feedback_val: Feedback) -> None: + self.collected_feedback[choice] = feedback_val + log_path = self.log_path + + lines = [] + log_exists = os.path.exists(log_path) + if log_exists: + # if log already exists, make sure it is consistent + metadata = self.serialize_metadata() + existing_metadata = get_metadata_str_from_log(self.log_path) + if existing_metadata != metadata: + raise InconsistentMetadata( + "Given metadata does not match existing metadata" + ) + else: + lines.append(self.serialize_metadata()) + feature_header = self.context.get_feature_names_csv() + header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL + lines.append(header) + + line = "" + feature_values = self.context.get_feature_values_csv() + line += feature_values + "," + choice + "," + str(feedback_val) + lines.append(line) + + with open(log_path, "a") as f: + f.write("\n".join(lines) + "\n") + + +class AutoHeuristicSelectAlgorithm(AutoHeuristic): + """ + AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic + when one wants to use AutoHeuristic for kernel choice selection. + """ + + def __init__( + self, + fallback: Callable[[], Optional[ChoiceCaller]], + choices: list[ChoiceCaller], + input_nodes: list[Any], + context: AHContext, + name: str, + augment_context: Optional[list[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + The arguments choices, input_nodes and name have to match the ones used in the call to + autotune_select_algorithm(), e.g. if the following call is made + autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes + have to be used here. + """ + self.input_nodes = input_nodes + self.choicestr2choice: dict[str, ChoiceCaller] = {} + for choice in choices: + self.choicestr2choice[choice.autoheuristic_id()] = choice + choices_str = list(self.choicestr2choice.keys()) + + def fallback_str() -> str: + fallback_choice = fallback() + if fallback_choice is None: + # TODO: Find a nicer way to handle this + return "unsure" + return fallback_choice.autoheuristic_id() + + super().__init__( + fallback_str, + choices_str, + None, + context, + name, + augment_context, + precondition, + ) + + if ( + torch._inductor.config.collect_autoheuristic(self.name) + and self.satisfies_precondition() + ): + self.register_global_feedback(input_nodes, choices) + + def register_global_feedback( + self, input_nodes: list[Any], choices: list[ChoiceCaller] + ) -> None: + """ + Registers a callback in select_algorithm, which is called with the timing of each choice. + """ + + from torch._inductor.select_algorithm import ( + add_feedback_saver, + create_inputs_key, + create_precompile_key, + ) + + def store_global_feedback( + ah_inputs_key: str, + ah_precompile_key: str, + timings: dict[ChoiceCaller, float], + name: str, + input_nodes: list[Any], + choices: list[ChoiceCaller], + ) -> None: + current_inputs_key = create_inputs_key(input_nodes) + if current_inputs_key != ah_inputs_key: + return + current_precompile_key = create_precompile_key( + name, current_inputs_key, choices + ) + if current_precompile_key != ah_precompile_key: + return + for choice, time in timings.items(): + self.save_data(choice.autoheuristic_id(), time) + + inputs_key = create_inputs_key(input_nodes) + precompile_key = create_precompile_key(self.name, inputs_key, choices) + feedback_saver = partial(store_global_feedback, inputs_key, precompile_key) + add_feedback_saver(feedback_saver) + + def get_choice_caller(self) -> Optional[ChoiceCaller]: + choice = self.get_choice() + return self.choicestr2choice.get(choice, None) + + def get_top_k_choices_caller( + self, top_k: int, always_included: Optional[list[str]] = None + ) -> Optional[list[ChoiceCaller]]: + choices = self.get_top_k_choices(top_k, always_included) + if choices is None: + return None + return [self.choicestr2choice[choice] for choice in choices] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0435fe44b4035a8f338f503340fc351019252a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py @@ -0,0 +1,340 @@ +import functools +from collections.abc import Callable +from typing import Any + +import torch + + +Feedback = float +Choice = str +Value = Any + +CHOICE_COL = "choice" +FEEDBACK_COL = "feedback" + + +class AHFeature: + """ + The context, that AutoHeuristic stores, is a list of features. AutoHeuristic needs to know whether a feature is + categorical (i.e., not a continuous variable) to learn a machine learning model. + """ + + def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None: + self.name = name + self.value = value + self.is_categorical = is_categorical + + +class AHOperation: + """ + AHOperation can be used to augment the data collected by AutoHeuristic. + One might for example store features like m, k, n, but also want to use + features like m*n, or k*n, to learn a heuristic. Instead of storing features + that can be created from the collected data, one can use AHOperation to + create new features from the collected data. + """ + + def __init__( + self, name: str, func: Callable[[Any], Value], is_categorical: bool = False + ) -> None: + self.name = name + self.func = func + self.is_categorical = is_categorical + + def apply_operation(self, data: Any) -> None: + data[self.name] = self.func(data) + + +class AHContext: + """ + This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will + store the context and the collected feedback. The context could be something like the shape of a tensor, i.e., + information that will help to learn a heuristic. + """ + + features: list[AHFeature] + context_dict: dict[str, Value] + + def __init__(self) -> None: + self.features = [] + self.context_dict = {} + + def add_feature( + self, name: str, value: Value, is_categorical: bool = False + ) -> None: + self.features.append(AHFeature(name, value, is_categorical=is_categorical)) + self.context_dict[name] = value + + def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]: + numerical_features = [] + categorical_features = [] + for feature in self.features: + if feature.is_categorical: + categorical_features.append(feature.name) + else: + numerical_features.append(feature.name) + + return numerical_features, categorical_features + + def get_feature_names_csv(self) -> str: + return ",".join(feature.name for feature in self.features) + + def get_feature_values_csv(self) -> str: + return ",".join(str(feature.value) for feature in self.features) + + def get_value(self, name: str) -> Value: + return self.context_dict[name] + + def apply_operations(self, operations: list[AHOperation]) -> None: + for op in operations: + op.apply_operation(self.context_dict) + + +class AHMetadata: + def __init__( + self, + shared_memory: Any, + device_capa: tuple[int, int], + choices: list[Choice], + name: str, + ) -> None: + # use amount of shared_memory and device_capability to identify GPU + # TODO(AlnisM): there might be a better way to do this + self.shared_memory = shared_memory + self.device_capa = device_capa + self.choices = choices + self.name = name + + def to_dict(self) -> dict[str, Value]: + return { + "shared_memory": self.shared_memory, + "device_capa": self.device_capa, + "name": self.name, + } + + +def get_metadata_str_from_log(log_path: str) -> str: + with open(log_path, newline="") as file: + json_string = file.readline().strip() + return json_string + + +def check_minsize(context: AHContext, minsize: int) -> bool: + return ( + context.get_value("m") >= minsize + and context.get_value("k") >= minsize + and context.get_value("n") >= minsize + ) + + +def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0): + # A100 precondition + return check_minsize(context, 512) + elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0): + # H100 precondition + return check_minsize(context, 768) + return True + + +def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + m = context.get_value("m") + k = context.get_value("k") + n = context.get_value("n") + if m > 128 or k < 1024 or n < 1024: + return False + mat1_iscontig = context.get_value("mat1_iscontig") + mat2_iscontig = context.get_value("mat2_iscontig") + return mat1_iscontig and not mat2_iscontig + + +def get_mult_dims_ops() -> list[AHOperation]: + m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"]) + m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"]) + k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"]) + return [m_times_k_op, m_times_n_op, k_times_n_op] + + +def get_arith_intensity(data: Any) -> float: + m = data["m"] + k = data["k"] + n = data["n"] + if m == 0 or k == 0 or n == 0: + return 0.0 + return m * k * n / (m * k + k * n + m * n) + + +def pad_mm_operations() -> list[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + k_div_m_times_n_op = AHOperation( + "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"]) + ) + + def bfloat_perf_hit(data: Any) -> bool: + m = data["m"] + k = data["k"] + n = data["n"] + is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16" + return k > (m * 1024) and k > (n * 1024) and is_bfloat + + bfloat_perf_hit_op = AHOperation( + "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True + ) + + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + dims_need_padding_ops = get_dims_need_padding_ops() + dims_multiple_ops = get_dims_multiple_ops() + is_contig_ops = get_is_contig_ops() + + ah_operations = mult_dims_ops + [ + k_div_m_times_n_op, + bfloat_perf_hit_op, + arith_intensity_op, + ] + ah_operations.extend(dims_need_padding_ops) + ah_operations.extend(dims_multiple_ops) + ah_operations.extend(is_contig_ops) + return ah_operations + + +def between_op(data: Any, dim: str, lower: int, upper: int) -> bool: + return data[dim] >= lower and data[dim] <= upper + + +def between_ops() -> list[AHOperation]: + dims = ["m", "k", "n"] + limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)] + ah_operations = [] + for dim in dims: + for lower, upper in limits: + between_op_fn = functools.partial( + between_op, dim=dim, lower=lower, upper=upper + ) + # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot + between_op_name = f"{lower}LEQ{dim}LEQ{upper}" + ah_operations.append( + AHOperation(between_op_name, between_op_fn, is_categorical=True) + ) + return ah_operations + + +def pow2_op(data: Any, dim: str, exponent: int) -> bool: + return data[dim] == 2**exponent + + +def mm_operations() -> list[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + return mult_dims_ops + [arith_intensity_op] + + +def mixed_mm_operations() -> list[AHOperation]: + return mm_operations() + between_ops() + + +def is_multiple(data: Any, dim: str, mult: int) -> bool: + return data[dim] % mult == 0 + + +def get_dims_multiple_ops() -> list[AHOperation]: + multiples = [2, 4, 8, 16, 32] + dims = ["m", "k", "n"] + dims_multiple_ops = [] + for dim in dims: + for mult in multiples: + is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult) + dims_multiple_op = AHOperation( + f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True + ) + dims_multiple_ops.append(dims_multiple_op) + return dims_multiple_ops + + +def get_dims_need_padding_ops() -> list[AHOperation]: + def mat1_innermost_needs_padding_fn(data: Any) -> bool: + mat1_stride_0 = data["mat1_stride_0"] + mat1_stride_1 = data["mat1_stride_1"] + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + mat1_innermost_needs_padding = False + if mat1_stride_0 == 1 and m_padded_length != 0: + mat1_innermost_needs_padding = True + if mat1_stride_1 == 1 and k_padded_length != 0: + mat1_innermost_needs_padding = True + return mat1_innermost_needs_padding + + mat1_innermost_op = AHOperation( + "mat1_innermost_needs_padding", + mat1_innermost_needs_padding_fn, + is_categorical=True, + ) + + def mat2_innermost_needs_padding_fn(data: Any) -> bool: + mat2_stride_0 = data["mat2_stride_0"] + mat2_stride_1 = data["mat2_stride_1"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + mat2_innermost_needs_padding = False + if mat2_stride_0 == 1 and k_padded_length != 0: + mat2_innermost_needs_padding = True + if mat2_stride_1 == 1 and n_padded_length != 0: + mat2_innermost_needs_padding = True + return mat2_innermost_needs_padding + + mat2_innermost_op = AHOperation( + "mat2_innermost_needs_padding", + mat2_innermost_needs_padding_fn, + is_categorical=True, + ) + + def num_dims_needs_padding_fn(data: Any) -> int: + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + num_dims_needs_padding = 0 + if m_padded_length != 0: + num_dims_needs_padding += 1 + if k_padded_length != 0: + num_dims_needs_padding += 1 + if n_padded_length != 0: + num_dims_needs_padding += 1 + return num_dims_needs_padding + + num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn) + return [mat1_innermost_op, mat2_innermost_op, num_dims_op] + + +def get_is_contig_ops() -> list[AHOperation]: + def mat1_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat1_stride_0"] + stride_1 = data["mat1_stride_1"] + k = data["k"] + return stride_0 == k and stride_1 == 1 + + mat1_is_contig_op = AHOperation( + "mat1_iscontig", mat1_is_contig_fn, is_categorical=True + ) + + def mat2_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat2_stride_0"] + stride_1 = data["mat2_stride_1"] + n = data["n"] + return stride_0 == n and stride_1 == 1 + + mat2_is_contig_op = AHOperation( + "mat2_iscontig", mat2_is_contig_fn, is_categorical=True + ) + + return [mat1_is_contig_op, mat2_is_contig_op] + + +def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None: + for i, s in enumerate(stride): + context.add_feature(f"{name}_stride_{i}", s) + + +def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None: + using_tf32 = "not_float_32" + if dtype == torch.float32: + using_tf32 = torch.backends.cuda.matmul.allow_tf32 + context.add_feature("using_tf32", using_tf32, is_categorical=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..50c11eb9a712afafee7479987a6832e412cc393a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py @@ -0,0 +1,119 @@ +import importlib +import inspect +import pkgutil +from collections import defaultdict +from typing import Any, Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic + + +def find_and_instantiate_subclasses( + package_name: str, base_class: Any +) -> list[LearnedHeuristic]: + instances = [] + + package = importlib.import_module(package_name) + for _, module_name, _ in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + try: + module_basename = module_name.split(".")[-1] + if not module_basename.startswith("_"): + # learned heuristics start with an underscore + continue + module = importlib.import_module(module_name) + + # look for classes that are subclasses of base_class + for _name, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, base_class) + and obj != base_class + ): + instance = obj() + instances.append(instance) + except Exception as e: + print(f"Error processing module {module_name}: {e}") + + return instances + + +class LearnedHeuristicController: + """ + Class that finds and instantiates all learned heuristics. It also provides + a way to get the decision of a learned heuristic. + """ + + existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list) + """ + A dictionary that stores all the learned heuristics for each optimization. + The key is the optimization name, and the value is a list of LearnedHeuristic objects. + """ + + heuristics_initialized: bool = False + """ + A flag that indicates whether the learned heuristics have been initialized. + Set to true when the get_decision() function is called for the first time. + """ + + def __init__( + self, + metadata: AHMetadata, + context: AHContext, + ) -> None: + self.metadata = metadata + self.context = context + + def get_heuristics(self, name: str) -> list[LearnedHeuristic]: + """ + Returns a list of learned heuristics for the given optimization name. + """ + + if not LearnedHeuristicController.heuristics_initialized: + # learned heuristics are generated into the following package + learned_heuristics_package = "torch._inductor.autoheuristic.artifacts" + + # learned heuristics have to be of type LearnedHeuristic + base_class = LearnedHeuristic + found_heuristics = find_and_instantiate_subclasses( + learned_heuristics_package, base_class + ) + + for learned_heuristic in found_heuristics: + opt_name = learned_heuristic.get_name() + LearnedHeuristicController.existing_heuristics[opt_name].append( + learned_heuristic + ) + LearnedHeuristicController.heuristics_initialized = True + + return LearnedHeuristicController.existing_heuristics[name] + + def get_decision(self) -> Optional[Choice]: + """ + Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure + which choice to make. + """ + + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + return heuristic.get_decision(self.context, self.metadata.choices) + return None + + def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]: + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + choices = heuristic.get_decisions_ranked(self.context) + if choices is None: + return None + avail_choices = [ + choice for choice in choices if choice in self.metadata.choices + ] + return avail_choices[:top_k] + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..84a941b076c314d9961af916a5a559e9948c0e00 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -0,0 +1,89 @@ +import operator +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) + + +class LearnedHeuristic: + """ + LearnedHeuristic is a base class for all learned heuristics. + """ + + def __init__(self) -> None: + pass + + def check_precondition( + self, + metadata: AHMetadata, + context: AHContext, + ) -> bool: + return True + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + return None + + def get_confidence_threshold(self) -> float: + return 1.0 + + def get_name(self) -> str: + return "" + + def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: + return None + + +class LearnedHeuristicRegression(LearnedHeuristic): + def get_feedback(self, context: AHContext, choice: Choice) -> float: + return 1.0 + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + choice2feedback = {} + for choice in choices: + predicted_feedback = self.get_feedback(context, choice) + choice2feedback[choice] = predicted_feedback + sorted_choices_feedback = sorted( + choice2feedback.items(), key=operator.itemgetter(1) + ) + highest_feedback = sorted_choices_feedback[-1][1] + second_highest_feedback = sorted_choices_feedback[-2][1] + if highest_feedback / second_highest_feedback > self.get_confidence_threshold(): + return sorted_choices_feedback[-1][0] + # We are not sure which choice is the best one + return None + + +class LearnedHeuristicDecision(LearnedHeuristic): + def get_choice(self, idx: int) -> Optional[str]: + return None + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + best_choices = self.get_best_choices(context) + if not best_choices: + return None + (best_choice_proba, best_choice_idx) = best_choices[0] + if best_choice_proba <= self.get_confidence_threshold(): + return None + return self.get_choice(best_choice_idx) + + def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: + feedback_idx_list = self.get_best_choices(context) + if feedback_idx_list is None: + return None + choices = [ + self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list + ] + choices = [choice for choice in choices if choice is not None] + return choices + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + return [] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a8acf64ab1ae1106dac19d64f934550a9fdcdb3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6634a81091c7dffa981588d82c10f84fee120814 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/block_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/block_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2c609ad698dc3d2efabc7faa1f7b34526e5dd23 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/block_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e2d64eb241cc39436fe07cdef42dcb1489c168a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e24ec1f16079efe45280ce97287a2d8d3c74833 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef5f57353adbda2a6145b1766fd5546ba36a025b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d846280a9dd894ac9d107a7343a9dcd512aa219f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60fd5cc7809ed9565fb1cd22b4c3d5eb9d7c4f86 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b1b7df052f439203819a5e789726df0b7acc7a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d57105277c2bb94ec3eb8a633ffaa5477d08ee4d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42bbf988bf02c037d6b2f56818ce32f30fd308c1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247c85d9fab49734149f08d1e4121728950e37a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f22da2e3337a158fa0271f93781b7706244e9da3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..093bb4be58d74d9b63b06a23e4e26b9f932ddb3b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e4ba169848f53dbd2ed70e1d721b7875ec5c65c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..138eca734ddca4b36d428ad08ad37a4117715a05 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c74b922cd7fab8266a5b6a952ff39eef66045649 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cfa16b291a1ca46594626fd3a9f33752165f86d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3eef5292163b775d21c88e14785d5b1407afe35 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/mps.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/mps.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36eec2634937dfb2b934fe8c37d7ed9f21d762dd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/mps.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acdb6a7807671a1fbd06981fdb3e9b9a8d07b78b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a59d5d9d8f3bf24b3f43274515b403c4a3f04201 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/pallas.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/pallas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efdf1c5d2ee61e4b3459d6fbcfb43f878fd67cb4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/pallas.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/python_wrapper_mtia.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/python_wrapper_mtia.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adcc80560030e896b5359ac5c67b03093019081e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/python_wrapper_mtia.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/segmented_tree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/segmented_tree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6598704394fd3ef1f6f052c2104b3e72cd4bf655 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/segmented_tree.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d57d09d37e90625fc53b8916b0e5df1d0d70430 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/subgraph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/subgraph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..439e7fba1af34a93576c55aab3acfd46e9560288 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/subgraph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ed5f326a4fbd0d0832e04f553a4acaf932fa950 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e327c0e046a58c239ac655954b77456a170182c4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e9b45db07c4c3fec0bda24537bc7da9a670c03f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe0deb18f59b5040ca37c202312019afc4223355 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eca4f85ced9260e9122db73366ab4a136b7cc4ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py @@ -0,0 +1,36 @@ +import re + +import torch + + +# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: +# "... +# from ..codecache import CudaKernelParamCache +# ..." +# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache + + +def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: + if torch.version.hip is None and not force_hipify: + return source_codes + + try: + from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE + except ImportError: + # hipify not available for non-AMD builds + return source_codes + + def c2_repl(m: re.Match[str]) -> object: + return PYTORCH_MAP[m.group(0)] + + # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, + # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching + # keyword at the beginning of code line. However, this can happen in codegen, + # which will cause the pattern to not match. + + # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example + # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA" + RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)") + + source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) # type: ignore[arg-type] + return source_codes diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..515ab89d1f2d1187fe2855733ecbbc523fbeae0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,488 @@ +// Definition of AOTI runtime interface functions + +#include +#include + +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + std::cerr << "Error: " << e.what() << '\n'; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred.\n"; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK( \ + actual_size == expected_size, \ + "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(false); + } + AOTINoGradGuard(const AOTINoGradGuard&) = delete; + AOTINoGradGuard(AOTINoGradGuard&&) noexcept = delete; + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + AOTINoGradGuard& operator=(const AOTINoGradGuard&) = delete; + AOTINoGradGuard& operator=(AOTINoGradGuard&&) noexcept = delete; + bool prev_mode{aoti_torch_grad_mode_is_enabled()}; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir) { + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, + is_cpu ? "cpu" : "cuda", + cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir) { + + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0\n"; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto* container = new torch::aot_inductor::AOTInductorModelContainer( + num_models, std::string(device_str), cubin_dir_opt); + *container_handle = + reinterpret_cast(container); + }) +} + + +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto* container = + reinterpret_cast( + container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_single_threaded( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* type) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( + AOTInductorModelContainerHandle container_handle, + size_t idx, + size_t* data_size) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *data_size = container->constant_data_size(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive) { + auto* container = + reinterpret_cast( + container_handle); + auto constants_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { const auto ret = container->extract_constants_map(use_inactive); + for (const auto& pair: ret) { + constants_map->emplace(pair.first, pair.second); + } + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update, /* user_managed = */ true); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBufferPairs( + AOTInductorModelContainerHandle container_handle, + const AOTInductorConstantMapEntry* pairs, + size_t num_pairs, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast(container_handle); + // Build a local unordered_map inside + std::unordered_map input_map; + input_map.reserve(num_pairs); + for (size_t i = 0; i < num_pairs; ++i) { + input_map.emplace(pairs[i].name, pairs[i].handle); + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + input_map, use_inactive, validate_full_update, /*user_managed=*/true); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, + constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->free_inactive_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->swap_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** in_spec, + const char** out_spec) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast*>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, + constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + "" + ); + + if (input_map) { + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType) nullptr, + nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast( + model_handle); + delete model; + })} + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); + }) +} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = + reinterpret_cast*>( + constant_map_handle); + + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize( + AOTInductorModelContainerHandle container_handle, + uint64_t* ret_size) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_size = container->constant_blob_size(); }) +} + + +// Load weights from a single blob in weight_blob_ptr +AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob( + AOTInductorModelContainerHandle container_handle, + const uint8_t* weight_blob_ptr){ + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + {container->update_constants_from_blob(weight_blob_ptr); }) + } + + +} // extern "C" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..b47c8325e21545a9ca30f513a22b22480b4d6ab0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py @@ -0,0 +1,192 @@ +import collections +import functools +import textwrap +from typing import Optional + +import sympy +from sympy import Expr, Symbol + +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from ..utils import sympy_dot, sympy_subs +from ..virtualized import V + + +class BlockPatternMatcher: + """ + Matches block indexing expressions. + """ + + _indexing_wild_signed_int = functools.partial( + sympy.Wild, properties=[lambda x: x.is_integer] + ) + _indexing_wild_unsigned_int = functools.partial( + sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative] + ) + + @classmethod + def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: + """ + Given a sympy expression, return the subexpression comprised only of terms + involving the specified symbol. + + For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, + this returns `x * 5 + x ** 2`. + """ + expr = cls._preprocess(expr) + return sympy.S.Zero + sum( + term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols + ) + + @staticmethod + def get_slice_numels(dims: list[Expr]) -> list[Expr]: + """ + Compute the cumulative size of each dimension's slice. + This proceeds from the last dim up to the second. + """ + numels = collections.deque([sympy.S.One]) + for dim in dims[:0:-1]: + numel = dim * numels[0] + numels.appendleft(numel) + return [*numels] + + @staticmethod + def _preprocess(expr: Expr) -> Expr: + # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y. + return expr.expand(identity=True) + + @classmethod + def match_mod_div_block_expr( + cls, + index: Expr, + index_var: Symbol, + numel: Expr, + num_dims: int, + ) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]: + """ + Matches modular indexing expressions, converting them to implied block dimensions and strides. + See triton.py for more information. + """ + index = cls._preprocess(index) + + # Pattern match to find the strides and offset. + wild_unsigned_int = functools.partial( + cls._indexing_wild_unsigned_int, exclude=[index_var] + ) + wild_signed_int = functools.partial( + cls._indexing_wild_signed_int, exclude=[index_var] + ) + dims: list[Expr] = [ + wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims) + ] + strides: list[Expr] = [ + wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims) + ] + + # The first dimension's index is computed by division. + # The remaining are computed by modulo. + slice_numels = cls.get_slice_numels(dims[:num_dims]) + block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [ + ModularIndexing(index_var, numel, dim) + for dim, numel in zip(dims[1:], slice_numels[1:]) + ] + + # Calculate a linear index from block indices. + match_expr = sympy_dot(strides, block_index_exprs) + + # Heuristic: if the number of dimensions is high, check that the minimum requirements + # are met before attempting an expensive full match. see triton.py:match_mod_div_block + # for more details. In short, here we check that each subexpression in sympy.Add contains + # only FloorDiv or ModularIndexing expressions. + if num_dims >= 5: + stride = sympy.symbols("stride", cls=wild_signed_int) + denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int) + mod_div_pattern = stride * ModularIndexing(index_var, denom, other) + floor_div_pattern = stride * FloorDiv(index_var, denom) + first_dim_floor_div_matched = False + match_failed = False + for arg in sympy.Add.make_args(index): + if arg.match(floor_div_pattern): + # There should only be a single FloorDiv(index, denom) expression + # corresponding to the first dimension + if first_dim_floor_div_matched: + match_failed = True + break + first_dim_floor_div_matched = True + elif arg.match(mod_div_pattern): + continue + else: + match_failed = True + break + + if match_failed: + return None + + # Pattern match. + match = index.match(match_expr) + if match is None: + return None + + # Provide default values for unmatched dims and strides. + for dim in dims[1:]: + if dim not in match: + match[dim] = sympy.S.One + for stride in strides[1:]: + if stride not in match: + match[stride] = sympy.S.Zero + + sizevars = V.graph.sizevars + + def get_match(expr: Expr) -> Expr: + return sizevars.lookup_precomputed_size(match[expr]) + + # Replace wildcards with matched expressions. + dims = [dims[0]] + [get_match(dim) for dim in dims[1:]] + strides = [get_match(stride) for stride in strides] + slice_numels = cls.get_slice_numels(dims) + block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs] + + # The leading dimension is not directly matched in our expression. + # We solve for it by dividing the range tree numel by the product of + # all other dimensions. We quit if they are not known to be divisible. + assert dims[0] not in match, "Expected not to match the leading dimension!" + if not sizevars.statically_known_multiple_of(numel, slice_numels[0]): + return None + dims[0] = numel / slice_numels[0] + + # Sanity check that we can recover the index from the matched subexpressions. + matched_index = sympy_dot(strides, block_index_exprs) + assert sizevars.statically_known_equals( + # New precomputed replacements may be generated when the `get_match` function + # above is called, but the `index` that is being matched has not been updated. + # So remove them when checking for equivalence e.g. if ps0=3*s0 and + # index=3*s0*expr, matched_index=ps0*expr, then index == matched_index + sizevars.remove_precomputed_replacements(matched_index), + sizevars.remove_precomputed_replacements(index), + ), textwrap.dedent( + f""" + Invalid match! + Index: {index} + Matched expression: {matched_index} + """ + ) + + return dims, strides, block_index_exprs + + @classmethod + def match_affine_block_expr( + cls, + index: Expr, + index_var: Symbol, + ) -> Optional[Expr]: + """ + Matches simple expressions of the form stride * index, returning the + stride. + """ + index = cls._preprocess(index) + stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var]) + m = index.match(index_var * stride) + if m is None: + return None + + return m[stride] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e27336af8eab90cf38d6799515df6f6992da0ee5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,2918 @@ +from __future__ import annotations + +import atexit +import contextlib +import dataclasses +import enum +import functools +import itertools +import logging +import math +import operator +import os +import re +import tempfile +from abc import ABC, abstractmethod +from enum import auto, Enum +from itertools import chain +from typing import ( + Any, + cast, + ClassVar, + Generic, + NamedTuple, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import Self, TypeVar + +import sympy + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from torch.utils._config_module import ConfigModule +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .. import config, metrics +from ..dtype_propagation import DtypePropagationOpsHandler +from ..ops_handler import BasicMathOpsMixin, DefaultHandler +from ..shape_propagation import ShapePropagationOpsHandler +from ..utils import ( + boolean_ops, + DeferredLineBase, + generate_assert, + get_current_backend, + IndentedBuffer, + ir_dataclass, + ScopedDict, + sympy_dot, + sympy_index_symbol, + sympy_subs, + triton_type, + unique, +) +from ..virtualized import ( + NullHandler, + ops, + OpsHandler, + OpsValue, + ReductionType, + StoreMode, + V, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator, MutableMapping, Sequence + + from torch.fx import GraphModule + + from ..custom_graph_pass import CustomGraphModulePass + from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode + from ..loop_body import LoopBody + from ..scheduler import BaseScheduling, Scheduler, SchedulerNode + from ..shape_propagation import BlockShapeType + from .wrapper import PythonWrapperCodegen + + _T = TypeVar("_T") + SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling] + WrapperConstructor = type[PythonWrapperCodegen] + SymbolLike = Union[str, sympy.Symbol] + + # OpVarT should really be Union[CSEVariable, str], however this + # causes typing errors in subclasses (defined in other files). + OpVarT = str + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +log = logging.getLogger(__name__) + + +def data_type_logger(msg: str) -> None: + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class FileBackedGraphModule: + """ + Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these + map back to a GraphModule instead of Python source. + """ + + gm: GraphModule + compiled_fn: Callable[..., Any] + + def __post_init__(self) -> None: + # Write the code to a file for compatibility with debugging utilities. + # The file is deleted upon program termination. + self.tempfile = tempfile.NamedTemporaryFile( # noqa: SIM115 + mode="w+", suffix=".py", delete=False + ) + atexit.register(os.remove, self.tempfile.name) + with self.tempfile as f: + f.write(self.value) + + @property + def __file__(self) -> str: + return self.tempfile.name + + def call(self, args: list[Any]) -> Any: + return self.compiled_fn(*args) + + @property + def value(self) -> str: + return self.gm.code + + +class WorkspaceZeroMode(enum.Enum): + UNINITIALIZED = 0 + ZERO_ON_CALL = 1 # kernel may leave workspace dirty + ZERO_PER_GRAPH = 2 # must be re-zeroed by kernel + + @staticmethod + def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode: + if a == b or b == WorkspaceZeroMode.UNINITIALIZED: + return a + if a == WorkspaceZeroMode.UNINITIALIZED: + return b + raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})") + + @staticmethod + def from_bool(zero_fill: bool) -> WorkspaceZeroMode: + if zero_fill: + return WorkspaceZeroMode.ZERO_ON_CALL + return WorkspaceZeroMode.UNINITIALIZED + + +class CodegenSymbol(ABC): + """ + An IR object possibly corresponding to a variable in the wrapper code. + """ + + @abstractmethod + def get_name(self) -> str: + pass + + @abstractmethod + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + pass + + +@ir_dataclass(frozen=True) +class WorkspaceArg(CodegenSymbol): + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + + Args: + nbytes: The size of the buffer in bytes. + zero_fill: Whether the buffer should be initialized to zero. + + """ + + count: sympy.Expr + zero_mode: WorkspaceZeroMode + device: torch.device + outer_name: str + inner_name: str = "ws_ptr" + dtype: torch.dtype = torch.uint8 + + @staticmethod + def unique_name(prefix: str = "workspace_") -> str: + return f"{prefix}{next(V.graph.workspace_id)}" + + @staticmethod + def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool: + return ( + a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device + ) + + @staticmethod + def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + return WorkspaceArg( + count=a.count + b.count, + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + @staticmethod + def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + assert ( + a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name + ) + return WorkspaceArg( + count=sympy.Max(a.count, b.count), + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code + def get_device(self) -> torch.device: + return self.device + + get_device_or_error = get_device + + def get_dtype(self) -> torch.dtype: + return self.dtype + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.get_layout().get_example() + + def get_layout(self) -> FixedLayout: + from ..ir import FixedLayout + + return FixedLayout( + device=self.device, + dtype=self.dtype, + size=[self.count], + stride=[1], + ) + + @property + def layout(self) -> FixedLayout: + return self.get_layout() + + get_output_spec = get_layout + maybe_get_output_spec = get_layout + maybe_get_layout = get_layout + + def get_offset(self) -> sympy.Expr: + return sympy.S.Zero + + def get_size(self) -> list[sympy.Expr]: + return [self.count] + + def get_stride(self) -> list[sympy.Expr]: + return [sympy.S.One] + + def get_name(self) -> str: + return self.outer_name + + def get_is_pinned(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> list[str]: + return [] + + +class TritonScratchWorkspace: + def __init__(self, size: int, generate_dtype_str: Callable[..., str]): + self.size = size + self._generate_dtype_str = generate_dtype_str + + def generate_dtype_str(self) -> str: + return self._generate_dtype_str() + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.S.Zero # c++ only + alias_of: Optional[str] = None # halide only + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + @property + def alias_of(self) -> Optional[str]: + return None + + +@dataclasses.dataclass +class ConstexprArg: + name: str + + +@dataclasses.dataclass +class TMADescriptorArg: + name: str + api_type: str # "experimental" or "stable" + block_shape: Optional[list[sympy.Expr]] # only needed for "stable" + dtype: Optional[torch.dtype] # only needed for "stable" + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: SchedulingConstructor + wrapper_codegen: WrapperConstructor + cpp_wrapper_codegen: Optional[WrapperConstructor] = None + fx_wrapper_codegen: Optional[WrapperConstructor] = None + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg] + +device_codegens: dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name: str) -> str: + raise NotImplementedError + + def set_device(self, device_idx: int) -> str: + raise NotImplementedError + + def synchronize(self) -> str: + raise NotImplementedError + + def device_guard(self, device_idx: int) -> str: + raise NotImplementedError + + def cpp_device_guard(self) -> str: + raise NotImplementedError + + def cpp_aoti_device_guard(self) -> str: + raise NotImplementedError + + def cpp_stream_guard(self) -> str: + raise NotImplementedError + + def cpp_aoti_stream_guard(self) -> str: + raise NotImplementedError + + def cpp_getStreamFromExternal(self) -> str: + raise NotImplementedError + + def kernel_header(self) -> str: + raise NotImplementedError + + def kernel_driver(self) -> str: + raise NotImplementedError + + def cpp_stream_type(self) -> str: + raise NotImplementedError + + def aoti_get_stream(self) -> str: + raise NotImplementedError + + def cpp_kernel_type(self) -> str: + raise NotImplementedError + + def cpp_device_ptr(self) -> str: + raise NotImplementedError + + def tma_descriptor_helpers(self) -> str: + raise NotImplementedError + + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None + ) -> Optional[tuple[list[str], str]]: + # optionally return (scratch definition, arg name) + raise NotImplementedError + + +device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} +custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} +custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts are necessary to generate its specific code. +# +# Kernel code generation is determined by different Scheduling. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. +# +# For the Wrapper, Inductor provides a PythonWrapperCodegen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from PythonWrapperCodegen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or PythonWrapperCodegen. So the Scheduling and PythonWrapperCodegen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. +# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, + device_scheduling: SchedulingConstructor, + device_wrapper_codegen: WrapperConstructor, + device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, + device_fx_wrapper_codegen: Optional[WrapperConstructor] = None, + device_custom_pass: Optional[CustomGraphModulePass] = None, + device_custom_config: Optional[ConfigModule] = None, +) -> None: + device_codegens[device] = DeviceCodegen( + device_scheduling, + device_wrapper_codegen, + device_cpp_wrapper_codegen, + device_fx_wrapper_codegen, + ) + custom_backend_passes[device] = device_custom_pass + if device_custom_config: + assert ( + isinstance(device_custom_config, ConfigModule) + and device_custom_config is not config + ), ( + f"{device_custom_config=} cannot be the same as the default inductor config {config=}" + ) + custom_backend_codegen_configs[device] = device_custom_config + + +class BackendFeature(Enum): + FOREACH = auto() + BUCKETIZE = auto() + INPLACE_BUFFERS = auto() + MASKED_SCATTER_WITH_INDEX = auto() + SCAN = auto() + SORT = auto() + TUPLE_REDUCTION = auto() + PREFER_STORE_LOOP_ORDER = auto() + TRITON_TEMPLATES = auto() + REDUCE_TO_SINGLE_ELEMENT = auto() + + +def get_backend_features( + device: Union[torch.device, str, None], +) -> OrderedSet[BackendFeature]: + if device is None: + return OrderedSet() + init_backend_registration() + if isinstance(device, torch.device): + device_type = device.type + else: + assert isinstance(device, str), type(device) + device_type = device + device = torch.device(device_type) + scheduling_ctor = get_scheduling_for_device(device_type) + assert scheduling_ctor + scheduling = scheduling_ctor(None) + return scheduling.get_backend_features(device) + + +def has_backend_feature( + device: Union[torch.device, str, None], feature: BackendFeature +) -> bool: + """See also V.graph.has_feature""" + assert isinstance(feature, BackendFeature) + return feature in get_backend_features(device) + + +def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]: + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device( + device: str, cpp_wrapper: bool = False, fx_wrapper: bool = False +) -> Optional[WrapperConstructor]: + if device in device_codegens: + wrapper_codegen_obj: DeviceCodegen = device_codegens[device] + if fx_wrapper: + return wrapper_codegen_obj.fx_wrapper_codegen + elif cpp_wrapper: + return wrapper_codegen_obj.cpp_wrapper_codegen + else: + return wrapper_codegen_obj.wrapper_codegen + return None + + +def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]: + return custom_backend_passes.get(device) + + +def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]: + return custom_backend_codegen_configs.get(device) + + +@functools.cache +def init_backend_registration() -> None: + """ + Register the backend for different devices, including the scheduling + for kernel code generation and the host side wrapper code generation. + """ + from .cpp import CppScheduling + from .cpp_wrapper_cpu import CppWrapperCpu + from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef + from .cpp_wrapper_gpu import CppWrapperGpu + from .cpp_wrapper_mps import CppWrapperMps + from .cuda_combined_scheduling import CUDACombinedScheduling + from .halide import HalideScheduling + from .mps import MetalScheduling + from .pallas import PallasScheduling + from .python_wrapper_mtia import PythonWrapperMtia + from .triton import TritonScheduling + from .wrapper import PythonWrapperCodegen + from .wrapper_fxir import WrapperFxCodegen + + if get_scheduling_for_device("cpu") is None: + cpu_backends = { + "cpp": CppScheduling, + "halide": HalideScheduling, + "triton": TritonScheduling, + "pallas": PallasScheduling, + } + register_backend_for_device( + "cpu", + lambda scheduling: cpu_backends[config.cpu_backend](scheduling), + PythonWrapperCodegen, + CppWrapperCpuArrayRef + if config.aot_inductor.allow_stack_allocation + else CppWrapperCpu, + WrapperFxCodegen, + ) + + if get_scheduling_for_device("cuda") is None: + # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation + cuda_backends = { + "triton": CUDACombinedScheduling, + "halide": HalideScheduling, + "pallas": PallasScheduling, + } + register_backend_for_device( + "cuda", + lambda scheduling: cuda_backends[config.cuda_backend](scheduling), + PythonWrapperCodegen, + CppWrapperGpu, + WrapperFxCodegen, + ) + + if get_scheduling_for_device("xpu") is None: + register_backend_for_device( + "xpu", + TritonScheduling, + PythonWrapperCodegen, + CppWrapperGpu, + WrapperFxCodegen, + ) + + if get_scheduling_for_device("mps") is None: + register_backend_for_device( + "mps", + MetalScheduling, + PythonWrapperCodegen, + CppWrapperMps, + WrapperFxCodegen, + ) + + if get_scheduling_for_device("mtia") is None: + register_backend_for_device( + "mtia", + TritonScheduling, + PythonWrapperMtia, + CppWrapperGpu, + WrapperFxCodegen, + ) + + private_backend = torch._C._get_privateuse1_backend_name() + if ( + private_backend != "privateuseone" + and get_scheduling_for_device(private_backend) is None + ): + from torch.utils.backend_registration import _get_custom_mod_func + + try: + device_scheduling = _get_custom_mod_func("Scheduling") + wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen") + cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen") + fx_wrapper_codegen = _get_custom_mod_func("WrapperFxCodegen") + if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: + register_backend_for_device( + private_backend, + device_scheduling, + wrapper_codegen, + cpp_wrapper_codegen, + fx_wrapper_codegen, + ) + except RuntimeError: + pass + + +def index_prevent_reordering( + index: Sequence[sympy.Expr], + index_vars: Sequence[sympy.Expr], + sizes: Sequence[sympy.Expr], +) -> list[sympy.Expr]: + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides( + device: str, device_op_overrides: DeviceOpOverrides +) -> None: + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str) -> DeviceOpOverrides: + assert isinstance(device, str), type(device) + + if not device_op_overrides_dict: + from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401 + from .cuda import device_op_overrides # noqa: F401 + from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401 + from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 + + return device_op_overrides_dict[device] + + +DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +def deduce_output_dtype_by_name( + op_name: str, + *args: Any, + **kwargs: Any, +) -> Optional[torch.dtype]: + """ + Given op name and a list of input dtypes, deduce the output dtype + """ + if op_name in boolean_ops(): + return torch.bool + elif op_name in ( + "to_dtype", + "index_expr", + ): + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "rand", + "randn", + ): + return torch.float + elif op_name in ( + "get_index", + "randint64", + "load_seed", + ): + return torch.int64 + elif op_name == "reduction": + return kwargs["dtype"] if "dtype" in kwargs else args[1] + elif op_name == "constant": + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "load", + "store", + "store_reduction", + ): + buf_name = args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] + return None + + +def check_dtype( + buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype +) -> None: + backend = get_current_backend() + if config.test_configs.runtime_triton_dtype_assert and backend == "triton": + buffer.writeline(f"tl.static_assert({var}.dtype == {triton_type(dtype)})") + elif config.test_configs.static_cpp_dtype_assert and backend == "cpp": + from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP + + assert isinstance(var, CppCSEVariable), type(var) + if dtype == torch.bool: + if var.is_vec: + is_same_dt = f"IsVecMaskType::value" + else: + # operator&(bool, bool) returns int and it can be used as boolean in C++ + is_same_dt = f"std::is_same_v || std::is_same_v" + else: + c_var_type = f"decltype({var})" + if var.is_vec: + c_var_type = f"typename {c_var_type}::value_type" + is_same_dt = f"std::is_same_v<{c_var_type}, {DTYPE_TO_CPP[dtype]}>" + + buffer.writeline(f"static_assert({is_same_dt});") + + +def check_shape( + buffer: IndentedBuffer, var: CSEVariableType, shape: BlockShapeType +) -> None: + backend = get_current_backend() + assert shape is not None + if config.test_configs.runtime_triton_shape_assert and backend == "triton": + shape_str = ( + ", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]}," + ) + buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))") + + +def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None: + backend = get_current_backend() + if backend == "triton": + msg = "NaN or Inf found" + buffer.writeline( + f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')" + ) + + +class DataTypePropagation: + def __init__(self, body: LoopBody) -> None: + self.body = body + self.graphs: dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]: + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propagated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propagated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype: + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]: + if node.op == "placeholder": + return None + + if node.target == "output" and len(node.args) != 1: + # we can infer output node if it only have 1 arg + return None + + if node.target is operator.getitem: + node_arg = node.args[0] + assert isinstance(node_arg, torch.fx.Node), type(node_arg) + return self.deduce_node_dtype(node_arg) + + assert isinstance(node.target, str), type(node.target) + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + if ( + output_dtype := deduce_output_dtype_by_name( + node.target, + *node.args, + **node.kwargs, + ) + ) is not None: + return output_dtype + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]: + assert graph.nodes + graph_dtype: Optional[torch.dtype] = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self) -> Optional[torch.dtype]: + return self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]: + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]: + from ..loop_body import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode), type(node) + assert isinstance(node._body, LoopBody), type(node._body) + return DataTypePropagation.propagate_loopbody(node._body) + + +class PythonPrinter(_PythonPrinter): + def doprint( + self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True + ) -> str: + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> str: + if isinstance(item, sympy.Mod): + # use parenthesis to enforce precedence. + # in sympy 1.13.3, -2*Mod(x,y) becomes -2*x%y, which is wrong. + return f"({self._print(item)})" + else: + return super().parenthesize(item, level, strict) + + +class OpDecompositions: + """ + Decomposes inductor ops + """ + + @staticmethod + def identity(value: OpVarT) -> OpVarT: + # used to trigger cse + return value + + @staticmethod + def reciprocal(x: OpVarT) -> OpVarT: + return ops.truediv(ops.constant(1, torch.int32), x) + + @staticmethod + def square(x: OpVarT) -> OpVarT: + return ops.mul(x, x) + + @staticmethod + def erfc(x: OpVarT) -> OpVarT: + return ops.sub(ops.constant(1, torch.float32), ops.erf(x)) + + @staticmethod + def erfcx(x: OpVarT) -> OpVarT: + return ops.mul(ops.exp(ops.square(x)), ops.erfc(x)) + + @staticmethod + def expm1(x: OpVarT) -> OpVarT: + return ops.sub(ops.exp(x), ops.constant(1, torch.float32)) + + @staticmethod + def log10(x: OpVarT) -> OpVarT: + return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32)) + + @staticmethod + def log2(x: OpVarT) -> OpVarT: + return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32)) + + @staticmethod + def exp2(x: OpVarT) -> OpVarT: + return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32))) + + @staticmethod + def log1p(x: OpVarT) -> OpVarT: + return ops.log(ops.add(x, ops.constant(1, torch.int32))) + + @staticmethod + def sigmoid(x: OpVarT) -> OpVarT: + one = ops.constant(1, torch.int32) + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + + @staticmethod + def relu(x: OpVarT) -> OpVarT: + return ops.maximum(x, ops.constant(0, torch.int32)) + + @staticmethod + def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT: + # for backends that don't override this (halide) + return ops.add(ops.mul(x, y), z) + + @staticmethod + def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def remainder(a: OpVarT, b: OpVarT) -> OpVarT: + r = ops.mod(a, b) + cond = ops.and_( + ops.ne(r, ops.constant(0, torch.int32)), + ops.ne(ops.signbit(r), ops.signbit(b)), + ) + return ops.where(cond, ops.add(r, b), r) + + @staticmethod + def round_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.round(a), dtype) + + +_RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE) + + +def _all_in_parens(string: str) -> bool: + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + +class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): + @staticmethod + def paren(string: OpVarT) -> OpVarT: + if ( + isinstance(string, CSEVariable) + or _RE_PAREN_NOT_NEEDED.fullmatch(string) + or _all_in_parens(string) + ): + # don't put extra parens for strings that are already wrapped in parens + # pyrefly: ignore [bad-return] + return string + return f"({string})" + + @staticmethod + def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT: + return repr(value) + + @staticmethod + def bitwise_not(x: OpVarT) -> OpVarT: + return f"~{OpOverrides.paren(x)}" + + @staticmethod + def logical_not(a: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(a)} == 0" + + @staticmethod + def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}" + + @staticmethod + def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT: + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + + @staticmethod + def load_seed(name: str, offset: OpVarT) -> OpVarT: + return ops.load(name, sympy.Integer(offset)) + + def indirect_indexing( + self, + var: OpVarT, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + return sympy_index_symbol(str(var)) + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError( + f"{type(self).__name__}: check_bounds should be handled by CSEProxy" + ) + + def load(self, name: str, index: sympy.Expr) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: load should be handled by CSEProxy" + ) + + def store( + self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None + ) -> None: + raise NotImplementedError( + f"{type(self).__name__}: store should be handled by CSEProxy" + ) + + def device_assert_async(self, cond: CSEVariable, msg: str) -> None: + raise NotImplementedError( + f"{type(self).__name__}: device_assert_async should be handled by CSEProxy" + ) + + def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None: + raise NotImplementedError( + f"{type(self).__name__}: store_reduction should be handled by CSEProxy" + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[OpVarT, tuple[OpVarT, ...]], + ) -> Union[OpVarT, tuple[OpVarT, ...]]: + raise NotImplementedError( + f"{type(self).__name__}: reduction should be handled by CSEProxy" + ) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[OpVarT, ...], tuple[OpVarT, ...]], + tuple[OpVarT, ...], + ], + values: tuple[OpVarT, ...], + ) -> tuple[OpVarT, ...]: + raise NotImplementedError( + f"{type(self).__name__}: scan should be handled by CSEProxy" + ) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[OpVarT, ...], + stable: bool, + descending: bool, + ) -> tuple[OpVarT, ...]: + raise NotImplementedError( + f"{type(self).__name__}: sort should be handled by CSEProxy" + ) + + def bucketize( + self, + values: OpVarT, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: OpVarT, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[OpVarT] = None, + ) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: bucketize should be handled by CSEProxy" + ) + + def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: halide_clamp only implemented for Halide backend" + ) + + def dot(self, x: OpVarT, y: OpVarT) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: dot only implemented for Triton backend" + ) + + def inline_asm_elementwise( + self, + *inputs: OpVarT, + asm: str, + constraints: Optional[str] = None, + dtype: torch.dtype = torch.float32, + is_pure: bool = True, + pack: int = 1, + ) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" + ) + + def output(self, *args: OpVarT) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear at codegen time" + ) + + def placeholder(self, index: int) -> OpVarT: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear at codegen time" + ) + + @staticmethod + def _unimplemented(name: str) -> Callable[..., OpVarT]: + def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__} does not implement ops.{name}" + ) + + unimplemented.__name__ = name + unimplemented.is_unimplemented = True # type: ignore[attr-defined] + return unimplemented + + @classmethod + def _is_unimplemented(cls, name: str) -> bool: + fn = getattr(cls, name, None) + default_fn = getattr(OpsHandler, name, None) + return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False) + + @classmethod + def _initialize_pointwise_overrides(cls, target: str) -> None: + assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if impl is None: + if cls._is_unimplemented(funcname): + setattr(cls, funcname, cls._unimplemented(funcname)) + else: + assert funcname not in cls.__dict__, ( + f"multiple definitions of {funcname} on {cls.__name__}" + ) + impl.__name__ = funcname + setattr(cls, funcname, staticmethod(impl)) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: Callable[..., str] + # None when not impl in libdevice/triton + triton: Optional[Callable[..., str]] = None + # None when not impl in aten/.../vec + cppvec: Optional[Callable[..., str]] = None + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + halide: Optional[Callable[..., str]] = None + mps: Optional[Callable[..., str]] = None + + +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too +pointwise_overrides_data: dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j0_forward({x})", + triton=lambda x: f"libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j1_forward({x})", + triton=lambda x: f"libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y0_forward({x})", + triton=lambda x: f"libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y1_forward({x})", + triton=lambda x: f"libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_digamma({x})", + cppvec=lambda x: f"{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_erfcx({x})", + triton=lambda x: f"libdevice.erfcx({x})", + name="special_erfcx", + ), + fma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})", + cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})", + triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})", + name="fma", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + cppvec=lambda x: f"{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0e({x})", + cppvec=lambda x: f"{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i0_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i1_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, + y: f"{x} == 0 ? calc_digamma({y}) : ({x} == 1 ? trigamma({y}) : calc_polygamma({y}, {x}))", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})", + name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})", + name="special_laguerre_polynomial_l", + ), +) + + +def is_buffer_removed(name: str) -> bool: + return any( + name in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ) + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" + + def __init__(self, name: str, line: str): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self) -> Optional[str]: + if not is_buffer_removed(self.name): + return self.line + return None + + def _new_line(self, line: str) -> DeferredLine: + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]: + @contextlib.contextmanager + def ctx() -> Iterator[None]: + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: list[str] + + +@dataclasses.dataclass +class ArgName: + name: str + # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list + is_constexpr: bool = False + + def full_name(self) -> str: + return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}" + + +class RemovedArg: + def __str__(self) -> str: + return "REMOVED" + + +REMOVED = RemovedArg() + + +class KernelArgs: + @staticmethod + def _lookup( + prefix: str, + odict: Union[dict[_T, Union[str, RemovedArg]], dict[_T, str]], + name: _T, + ) -> str: + result: Union[str, RemovedArg] = odict.get(name, REMOVED) + if isinstance(result, RemovedArg): + odict[name] = new_result = f"{prefix}{len(odict)}" + return new_result + return result + + def __init__(self) -> None: + self.input_buffers: dict[str, str] = {} + self.output_buffers: dict[str, Union[str, RemovedArg]] = {} + self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {} + self.sizevars: dict[sympy.Expr, str] = {} + self.workspace_args: list[WorkspaceArg] = [] + + def __repr__(self) -> str: + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + @staticmethod + def _buffer_is_marked_removed(name: Any) -> bool: + # this function is needed by MTIA + return isinstance(name, RemovedArg) + + def input(self, name: str) -> str: + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return cast(str, self.output_buffers[name]) + if name in self.inplace_buffers: + return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name: str) -> str: + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name: str, output_name: str) -> None: + if input_name in V.graph.unaligned_buffers: + V.graph.unaligned_buffers.add(output_name) + assert output_name not in self.inplace_buffers, output_name + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + assert not isinstance(buf, RemovedArg) + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + alive_buffers = [ + val + for val in self.inplace_buffers.values() + if not isinstance(val, RemovedArg) + ] + removed_buffers = [ + val + for val in self.inplace_buffers.values() + if isinstance(val, RemovedArg) + ] + inplace_buffer_idx = len(unique(alive_buffers)) + len(removed_buffers) + buf = InplacedBuffer( + f"in_out_ptr{inplace_buffer_idx}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace( + self, nelem: sympy.Expr, zero_fill: bool, dtype: torch.dtype = torch.uint8 + ) -> tuple[str, str, int]: + """ + Allocate or extend a workspace buffer of nelem elements. + + This function manages the allocation of a workspace buffer. It either creates + a new WorkspaceArg or extends an existing one. + + Note: + - Calling this function will in-place mutate the args by adding or updating + a WorkspaceArg. + - The codegen for generating the Python argdefs and call_defs will check + this field and allocate the buffer accordingly. + - A new argument "ws_ptr" will be present in the generated code. + + Args: + nelem (sympy.Expr): The number of elements to allocate. + zero_fill (bool): Whether to initialize the buffer to zero. + dtype (torch.dtype): the dtype of the workspace tensor + + Returns: + Tuple[str, str, int]: A tuple containing: + - "ws_ptr": A string identifier for the workspace pointer. + - "workspace_{i}": agraph level unique identifier for + the workspace tensor. + - offset: An integer representing the item offset in the workspace. + """ + arg = WorkspaceArg( + count=nelem, + zero_mode=WorkspaceZeroMode.from_bool(zero_fill), + device=V.graph.get_current_device_or_throw(), + outer_name=WorkspaceArg.unique_name(), + dtype=dtype, + ) + for i, existing_arg in enumerate(self.workspace_args): + if WorkspaceArg.can_join(existing_arg, arg): + offset = existing_arg.count + self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg) + return existing_arg.inner_name, existing_arg.outer_name, offset + assert ( + existing_arg.inner_name != arg.inner_name + and existing_arg.outer_name != arg.outer_name + ), existing_arg + self.workspace_args.append(arg) + return arg.inner_name, arg.outer_name, 0 + + def semaphores(self, min_size: sympy.Expr) -> str: + """ + Lazily allocate a graph-wide semaphores buffer with at least min_size. This is a single buffer shared by + all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit. + + Warning: multiple calls to this function will return the same buffer. + + Args: + min_size: the number of int32 semaphores required + + Returns: + name of the semaphores buffer + """ + current_device = V.graph.get_current_device_or_throw() + arg = WorkspaceArg( + count=min_size, + zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH, + dtype=torch.uint32, + inner_name="sem_ptr", + outer_name=f"semaphores_{current_device.type}_{current_device.index}", + device=current_device, + ) + for existing_arg in self.workspace_args: + if existing_arg.inner_name == arg.inner_name: + assert arg == existing_arg, (arg, existing_arg) + self.workspace_args.append(arg) + return arg.inner_name + + def seed_offset(self, name: str, value: int) -> str: + assert isinstance(value, int), (type(value), value) + # here we are lifting a constant integer into an arg to the kernel to try to get additional cache hits + value = sympy.Integer(value) + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name: sympy.Symbol) -> str: + assert isinstance(name, sympy.Symbol), (type(name), name) + if name.name == "seed": + self.sizevars[name] = "seed" # don't manage the name of seeds + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self) -> Iterator[str]: + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def arg_name(self, name: str) -> Optional[str]: + """ + Returns inner name of a given outer name. + """ + inplaced = self.inplace_buffers.get(name, None) + if inplaced is not None and not isinstance(inplaced, RemovedArg): + return inplaced.inner_name + output_name = self.output_buffers.get(name, None) + if output_name is not None and not isinstance(output_name, RemovedArg): + return output_name + return self.input_buffers.get(name, None) + + def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str: + return buf + + def wrap_size_arg(self, size: SymbolLike) -> str: + return str(size) + + def cpp_argdefs( + self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None + ) -> tuple[list[str], list[str], list[str]]: + from .cpp_utils import INDEX_TYPE + + if dtype_to_cpp_type is None: + from .cpp_utils import DTYPE_TO_CPP + + dtype_to_cpp_type = DTYPE_TO_CPP + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, maybe_inner in self.output_buffers.items(): + if outer in self.inplace_buffers or isinstance(maybe_inner, RemovedArg): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"{cpp_dtype}* {maybe_inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + if isinstance(outer, sympy.Symbol) and symbol_is_type( + outer, (SymT.UNBACKED_FLOAT) + ): + arg_defs.append(f"const float {inner}") + arg_types.append("const float") + else: + arg_defs.append(f"const {INDEX_TYPE} {inner}") + arg_types.append(f"const {INDEX_TYPE}") + call_args.append(self.wrap_size_arg(outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert not self.workspace_args, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs( + self, + ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]: + arg_defs: list[ArgName] = [] + call_args: list[str] = [] + arg_types: list[Any] = [] + precompile_args: list[KernelArgType] = [] + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + arg_defs.append(ArgName(inplaced.inner_name)) + call_args.append(inplaced.other_names[-1]) + arg_types.append(V.graph.get_dtype(inplaced.other_names[-1])) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), + # pyrefly: ignore [bad-argument-type] + self.output_buffers.items(), + ): + if outer in self.inplace_buffers or isinstance(inner, RemovedArg): + continue + arg_defs.append(ArgName(inner)) + call_args.append(outer) + arg_types.append(V.graph.get_dtype(outer)) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(ArgName(inner)) + call_args.append(outer) + arg_types.append(type(outer)) + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + for arg in self.workspace_args: + arg_defs.append(ArgName(arg.inner_name)) + call_args.append(arg.outer_name) + precompile_args.append(arg) + arg_types.append(arg.dtype) + return arg_defs, call_args, precompile_args, arg_types + + def aliases(self) -> Iterator[tuple[str, str]]: + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + for other in inplaced.other_names: + if ( + other in V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield cast(str, self.output_buffers[other]), inplaced.inner_name + + def is_removed(self, name: str) -> bool: + return isinstance( + self.output_buffers.get(name, REMOVED), RemovedArg + ) and isinstance(self.inplace_buffers.get(name, REMOVED), RemovedArg) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. + def live_output_buffers(self) -> OrderedSet[str]: + live_outs: OrderedSet[str] = OrderedSet() + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or isinstance(inner, RemovedArg): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + To do so, the backends can simply overload `Kernel.create_cse_var` + The "CSEVariable.update_on_args" method gives you a hook for annotations + See example of TritonCSEVariable in triton.py + """ + + def __init__( + self, + name: str, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + shape: BlockShapeType = None, + ): + super().__init__() + assert isinstance(bounds, ValueRanges), type(bounds) + self.name = name + self.bounds = bounds + self.use_count = 1 # track how many times this expression is used + self.dtype = dtype + self.shape = shape + + def __str__(self) -> str: + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + return isinstance(other, CSEVariable) and other.name == self.name + + def update_on_args(self, name: str, args: Any, kwargs: Any) -> None: + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r})" + + +AugmentedKeyT = TypeVar("AugmentedKeyT", default=str) +CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable) + +if TYPE_CHECKING: + ReductionCacheKey = tuple[ + torch.dtype, + ReductionType, + Union[CSEVariable, tuple[CSEVariable, ...]], + ] + + +class CSE(Generic[CSEVariableType, AugmentedKeyT]): + """Common subexpression elimination""" + + def __init__( + self, + prefix: str = "", + suffix: str = "", + name_prefix: str = "tmp", + iter_buffers: Optional[itertools.count[int]] = None, + store_cache: Optional[MutableMapping[str, CSEVariableType]] = None, + reduction_cache: Optional[ + MutableMapping[ReductionCacheKey, CSEVariableType] + ] = None, + varname_map: Optional[dict[str, CSEVariableType]] = None, + ): + self.prefix = prefix + self.suffix = suffix + self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {} + self.name_prefix = name_prefix + self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {} + self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = ( + reduction_cache or {} + ) + self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count() + self.invalidated_stores: OrderedSet[str] = OrderedSet() + self.varname_map: dict[str, CSEVariableType] = varname_map or {} + + def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None: + for name, tmp in [*self.store_cache.items()]: + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + if keep_vars: + self._cache = {k: v for k, v in self._cache.items() if v in keep_vars} + else: + self._cache = {} + + def clone(self) -> Self: + return type(self)( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + reduction_cache=self.reduction_cache, + ) + + def scoped_copy(self) -> Self: + """Return a copy of using ScopedDict so changes to *_cache aren't visible in self""" + new_cse = self.clone() + new_cse._cache = ScopedDict(self._cache) + new_cse.reduction_cache = ScopedDict(self.reduction_cache) + new_cse.store_cache = ScopedDict(self.store_cache) + return new_cse + + def augment_key(self, cache_key: str) -> AugmentedKeyT: + "Override this method to augment cache key with backend specifics" + return cast(AugmentedKeyT, cache_key) + + def put(self, cache_key: str, val: CSEVariableType) -> None: + self._cache[self.augment_key(cache_key)] = val + + def contains(self, cache_key: str) -> bool: + return self.augment_key(cache_key) in self._cache + + def try_get(self, cache_key: str) -> Optional[CSEVariableType]: + return self._cache.get(self.augment_key(cache_key), None) + + def get(self, cache_key: str) -> CSEVariableType: + return self._cache[self.augment_key(cache_key)] + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write: bool = True, + assignment: bool = True, + dtype: Optional[torch.dtype] = None, + shape: BlockShapeType = None, + ) -> CSEVariableType: + if isinstance(expr, OpsValue): + expr = expr.value + + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + expr.use_count += 1 + return cast(CSEVariableType, expr) + elif isinstance(expr, IndentedBuffer): + cache_key = expr.getvalue() + elif isinstance(expr, DeferredLineBase): + cache_key = expr.line + else: + assert isinstance(expr, str) + cache_key = expr + var = self.try_get(cache_key) + if shape is None and not assignment: + # since there's no assignment to a variable, use any shape here + # other than None to avoid the unknown shape failures + shape = () + if not var: + var = self.newvar(bounds, dtype, shape) + self.put(cache_key, var) + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + elif isinstance(expr, DeferredLineBase): + assert assignment + buffer.writeline( + expr._new_line(f"{self.prefix}{var} = {expr.line}{self.suffix}") + ) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + + # cpp backend cannot determine is_vec at this point + if ( + assignment + and ( + config.test_configs.runtime_triton_dtype_assert + or config.test_configs.static_cpp_dtype_assert + ) + and dtype is not None + and get_current_backend() != "cpp" + ): + check_dtype(buffer, var, dtype) + + else: + var.bounds = var.bounds.tighten(bounds) + var.use_count += 1 + + return var + + def newvar( + self, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + shape: BlockShapeType = None, + ) -> CSEVariableType: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype, shape) + self.varname_map[var_name] = var + return var + + def namedvar( + self, + name: str, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + shape: BlockShapeType = None, + ) -> CSEVariableType: + torch._check_value( + name not in self.varname_map, lambda: f"duplicate name: {name}" + ) + var = V.kernel.create_cse_var(name, bounds, dtype, shape) + self.varname_map[name] = var + return var + + +class CodeGen: + def __init__(self) -> None: + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self) -> Self: + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class Kernel(CodeGen, Generic[CSEVariableType]): + newvar_prefix: str = "" + suffix: str = "" + overrides: Optional[Callable[[], OpsHandler[Any]]] = None + + def __init__( + self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True + ) -> None: + super().__init__() + if increase_kernel_count: + # pyrefly: ignore [bad-assignment] + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + + self.atomic_add_found = False + self.num_load = 0 + self.num_store = 0 + self.num_reduction = 0 + + self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.store_buffer_names: OrderedSet[str] = OrderedSet() + self._load_mask: Optional[str] = None + self._load_other: Union[None, int, float] = None + # OrderedSet in set_current_node + self.current_node: Optional[SchedulerNode] = None + self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None + + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers: dict[str, str] = {} + # Set minimum number of elements processed per thread. + self.min_elem_per_thread = 1 + self.kernel_name: Optional[str] = None + + @contextlib.contextmanager + def set_current_node(self, node: SchedulerNode) -> Iterator[None]: + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers( + self, + lb: IndentedBuffer, + cb: Optional[IndentedBuffer] = None, + sb: Optional[IndentedBuffer] = None, + ) -> Iterator[None]: + if cb is None: + cb = lb + if disallow_stores := sb is None: + sb = IndentedBuffer() + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = cse.scoped_copy() + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + # pyrefly: ignore [unbound-name] + if disallow_stores: + assert not sb, "unexpected store inside swap_buffers" + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError + + def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable: + """A load the depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + raise NotImplementedError + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError + + def device_assert_async(self, cond: CSEVariable, msg: str) -> None: + raise NotImplementedError( + f"{type(self).__name__}: device_assert_async should be handled by CSEProxy" + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + raise NotImplementedError + + def partial_accumulate( + self, + name: str, + reduction_type: ReductionType, + value: CSEVariable, + extra_meta: dict[str, Any], + ) -> None: + raise NotImplementedError + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] + ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + raise NotImplementedError + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + raise NotImplementedError + + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + raise NotImplementedError + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError + + @property + def assert_function(self) -> str: + raise NotImplementedError + + def indirect_assert( + self, + var: Union[CSEVariable, str], + lower: Optional[str], + upper: Optional[str], + mask: Optional[Union[CSEVariable, str]] = None, + ) -> str: + if isinstance(var, CSEVariable): + var = str(var) + assert isinstance(var, str), type(var) + assert lower is None or isinstance(lower, str) + assert upper is None or isinstance(upper, str) + if lower and upper: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is supported by triton + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower} <= {var} < {upper}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = cond + else: + assert upper + cond = f"{var} < {upper}" + cond_print = cond + + if mask: + cond = f"({cond}) | ~({mask})" + + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError + + def __enter__(self) -> Self: + super().__enter__() + assert self.overrides + self.exit_stack.enter_context( + V.set_ops_handler(CSEProxy(self, self.overrides())) + ) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def remove_kernel_local_buffers(self) -> None: + """ + Any buffers that are both created and have a last use in the + same kernel can be removed. + + Note that V.graph.scheduler can be None when codegening triton template + kernels. + """ + scheduler = V.graph.scheduler + if not scheduler: + return + fused_node_names = OrderedSet( + scheduler.name_to_buf[buf].defining_op_name() + for buf in self.store_buffer_names + if buf in scheduler.name_to_buf + ) + names_to_remove: OrderedSet[str] = OrderedSet() + for name in self.store_buffer_names: + if ( + name not in self.must_keep_buffers + and name not in self.args.input_buffers + and scheduler.can_buffer_be_removed_through_fusion( + name, fused_node_names + ) + ): + self.num_store -= 1 + names_to_remove.add(name) + + for name in names_to_remove: + if name in self.args.inplace_buffers: + buf = self.args.inplace_buffers[name] + if isinstance(buf, RemovedArg): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + self.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name: str) -> None: + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + self.args.output_buffers[name] = REMOVED + self.removed_buffers.add(name) + + def remove_inplace_buffer(self, name: str) -> None: + log.debug("removing_inplace_buffer(%r)", name) + self.args.inplace_buffers[name] = REMOVED + self.removed_buffers.add(name) + + def rename_indexing( + self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr] + ) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if symbol_is_type( + x, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.UNBACKED_FLOAT, + ), + ) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable: + return CSEVariable(*args, **kwargs) + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return self.args.arg_name(node.get_name()) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + +@functools.cache +def jinja2_env() -> Any: + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class KernelTemplate: + """ + Base class for defining kernel templates. + + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def indent_except_first( + source: str, num_indents: int, indents_spacing: int = 4 + ) -> str: + lines = source.splitlines(True) + if len(lines) > 1: + lines[1:] = [ + (" " * indents_spacing * num_indents) + line for line in lines[1:] + ] + return "".join(lines) + + @staticmethod + def _template_from_string(source: str) -> Any: + env = jinja2_env() + if env is None: + return None + env.filters["indent_except_first"] = KernelTemplate.indent_except_first + from jinja2 import TemplateSyntaxError + + try: + return env.from_string(source) + except TemplateSyntaxError as e: + + class DetailedTemplateSyntaxError(TemplateSyntaxError): + def __init__(self, original_error: TemplateSyntaxError) -> None: + super().__init__( + # pyrefly: ignore [bad-argument-type] + original_error.message, + original_error.lineno, + original_error.name, + original_error.filename, + ) + self.original_error = original_error + + def __str__(self) -> str: + error_info = f"Error in template at line {self.lineno}\n" + error_info += f"Error message: {self.message}\n" + if hasattr(self.original_error, "source"): + # pyrefly: ignore [missing-attribute] + lines = self.original_error.source.split("\n") + error_info += "Context:\n" + start = max(0, self.lineno - 2) + end = min(len(lines), self.lineno + 2) + for i in range(start, end): + if i == self.lineno - 1: + error_info += f"{i + 1}: --> {lines[i]}\n" + if hasattr(self.original_error, "column"): + error_info += ( + " " + + " " * (self.original_error.column - 1) + + "^\n" + ) + else: + error_info += f"{i + 1}: {lines[i]}\n" + return error_info + + raise DetailedTemplateSyntaxError(e) from e + + @staticmethod + def _fake_get_dtype( + fake_outs: Union[list[Buffer], Buffer], + ) -> Callable[[str], torch.dtype]: + _get_dtype_real = V.graph.get_dtype + if isinstance(fake_outs, (list, tuple)): + lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs} + else: + lookup = {fake_outs.get_name(): fake_outs.get_dtype()} + + def get_dtype(name: str) -> torch.dtype: + result = lookup.get(name) + if result is not None: + return result + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str, hash: Optional[str] = None) -> None: + self.name = name + self._hash = hash + + @property + def uid(self) -> str: + """ + entry point to override for templates to ensure a uid e.g. through a prefix + + the purpose of this is that every KernelTemplate/ExternKernelChoice is unique + in the system, but reproducible e.g. restarting pytorch should yield the same id + """ + # TODO(coconutruben): add some central registration to assert on global uniqueness + return self.name + + @property + def src_hash(self) -> Union[str, None]: + """ + source hash for a Template. + + Templates can optionally provide a src hash to make it easier to cache/validate that + a template has not changed from one version to another. Override this if that detection + is different for your specific Template + """ + return self._hash + + def choice_or_none(self, **kwargs: Any) -> Optional[ChoiceCaller]: + """ + Maybe generates a new ChoiceCaller and returns it, or None if generation fails. + + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + temp_choices: list[Any] = [] + result = self.maybe_append_choice(temp_choices, **kwargs) + if result is None and len(temp_choices) == 1: + return temp_choices[0] + return None + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choices.append(self.generate(**kwargs)) + return None + except NotImplementedError as e: + log.info( # noqa: G200 + "Cannot Append Choice: %s. KernelTemplate type is %s", + e, + type(self), + stack_info=log.getEffectiveLevel() < logging.INFO, + ) + return e + + def generate(self, **kwargs: Any) -> ChoiceCaller: + """ + Generates a ChoiceCaller instance from the given arguments. + """ + + raise NotImplementedError + + +class CSEProxy(DefaultHandler): + """A ops handler that proxies calls to `kernel` and its + handler and returns `CSEVariable`s with correct shape and dtype. + """ + + name = "CSEProxy" + + def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): + super().__init__() + from ..bounds import ValueRangeAnalysis + + self.vr_analysis = ValueRangeAnalysis() + self.kernel = kernel + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + bounds = self._bound_variable(name, *args, **kwargs) + + value = getattr(self.parent_handler, name)(*args, **kwargs) + dtype_handler = DtypePropagationOpsHandler() + shape_handler = ShapePropagationOpsHandler() + + backend = get_current_backend() + + shape_op = getattr(shape_handler, name) + output_dtype = None + output_shape = None + + if name == "masked" and backend == "triton": + output_dtype = value.dtype + output_shape = value.shape + elif name == "masked" and backend == "cpp": + output_dtype = V.interpreter.current_node.meta.get( + OptimizationContext.key, None + ).dtype + # TODO: fix me + output_shape = None + elif backend in ("triton", "cpp", "mps"): + dtype_op = getattr(dtype_handler, name) + output_dtype = dtype_op(*args, **kwargs) + output_shape = shape_op(*args, **kwargs) + + if backend in ("triton", "cpp"): + # maybe there are some exceptions on mps? + assert output_dtype is not None + + output_idx = 0 + + def do_cse(v: Union[str, CSEVariable]) -> CSEVariable: + # we tree_map over the output, so we need to fetch corresponding dtype + nonlocal output_idx + var_dtype: Optional[torch.dtype] = ( + output_dtype[output_idx] + if isinstance(output_dtype, (list, tuple)) + else output_dtype + ) + var_shape: BlockShapeType = ( + output_shape[output_idx] # type: ignore[assignment] + if isinstance(output_shape, (list, tuple)) + and len(output_shape) > 0 + and isinstance(output_shape[0], (list, tuple)) + else output_shape + ) + output_idx += 1 + + # some cpp op implementations don't set the dtype + if isinstance(v, CSEVariable): + if backend == "cpp" and v.dtype is None: + v.dtype = var_dtype + if v.shape is None: + v.shape = var_shape + + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + shape=output_shape, + ) + + csevar.update_on_args(name, args, kwargs) + + if ( + config.test_configs.runtime_triton_dtype_assert + or config.test_configs.static_cpp_dtype_assert + ): + assert var_dtype is not None + check_dtype(V.kernel.compute, csevar, var_dtype) + + if config.test_configs.runtime_triton_shape_assert: + assert output_shape is not None + check_shape(V.kernel.compute, csevar, output_shape) + + if config.runtime_triton_nan_asserts: + check_nan(V.kernel.compute, csevar) + + return csevar + + return pytree.tree_map(do_cse, value) + + def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]: + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from ..bounds import ValueRangeAnalysis + from ..select_algorithm import TritonTemplateKernel + from .cuda.cuda_kernel import CUDATemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + if isinstance(V.kernel, CUDATemplateKernel): + return ValueRanges.unknown() + + if isinstance(V.interpreter, NullHandler): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.kernel.node_to_bounds is not None: + assert isinstance(self.kernel.node_to_bounds, dict), type( + self.kernel.node_to_bounds + ) + return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any(s in fx_node.target for s in ("set_indirect", "reduction", "scan")): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. + + # If there is no FX bound but we know how to compute one we do so + assert not kwargs + + def arg_to_bound(x: Any) -> Any: + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(self.vr_analysis, name)(*arg_bounds) + return ValueRanges.unknown() + + def indirect_indexing( + self, + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + if isinstance(size, int): + size = sympy.Integer(size) + assert isinstance(size, sympy.Expr), (type(size), size) + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.kernel.cse.generate( + self.kernel.compute, + stm, + bounds=new_bounds, + dtype=var.dtype, + shape=var.shape, + ) + + sympy_var = self.parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.kernel.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + return self.kernel.check_bounds(expr, size, lower, upper) + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + if name in self.kernel.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.kernel.indirect_load(name, index) + store_cache = self.kernel.cse.store_cache + if name in store_cache: + return store_cache[name] + out = self.kernel.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. + if out.use_count == 1: + self.kernel.num_load += 1 + return out + + def _update_store_cache(self, name: str, value: CSEVariable) -> None: + self.kernel.cse.store_cache[name] = value + if self.kernel.current_node and name in V.graph.name_to_buffer: + buf = self.kernel.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.kernel.cse.store_cache[other_name] = value + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.kernel.store_buffer_names.add(name) + if mode is None: + self._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + self.kernel.store(name, index, value, mode=mode) + self.kernel.num_store += 1 + + def device_assert_async(self, cond: CSEVariable, msg: str) -> None: + self.kernel.device_assert_async(cond, msg) + + # pyrefly: ignore [bad-override] + def partial_accumulate(self, *args: Any) -> None: + self.kernel.partial_accumulate(*args) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + self.kernel.store_buffer_names.add(name) + self._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + self.kernel.num_store += 1 + return self.kernel.store_reduction(name, index, value) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + self.kernel.num_reduction += 1 + return self.kernel.reduction(dtype, src_dtype, reduction_type, value) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], + tuple[CSEVariable, ...], + ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + return self.kernel.scan(dtypes, combine_fn, values) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + return self.kernel.sort(dtypes, values, stable, descending) + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to be bucketized. + boundaries: a tuple containing + (a) the name of the boundaries tensor (which must be sorted, unless + the sorting tensor is present), + (b) the length of the tensor in the last dimension (i.e. the length of + one set of boundaries), + (c) the number of elements in the underlying storage (i.e. the length + of the flattened tensor, ignoring striding), and + (d) the stride of the tensor in the last dimension. + boundary_indices: indices into a flattened version of the boundaries + tensor, of the same size and shape as "values". Each index points to + the first element in the set of boundaries to be used for the + corresponding value. + indexing_dtype: the dtype to use when indexing into the boundaries + tensor. This must be int64 or int32. This additionally specifies the + dtype of the return value. + right: see "Details" below. + sorter: an optional tuple containing + (a) the name of an optional sorting tensor, used to access unsorted + boundaries without reordering the boundaries tensor, and + (b) the stride of the tensor in the last dimension. + The values in the sorting tensor are used as indices into the *last* + dimension of the boundaries tensor, with all other indices matching. + The size of the sorting and boundaries tensors must be equivalent. + sorter_indices: must be present if the sorting array is present; see + "boundary_indices" for the equivalent definition for the boundaries + tensor. + + Output: + ------- + The buckets each value belongs in, within a given set of boundaries. 0 + indicates a position before the first boundary, and len(boundaries_set) + represents a position after the last boundary. + + Details: + -------- + Given a value and a set of boundaries, calculate the bucket that each + value belongs to. This works differently in 1-D and N-D cases. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True + return = [[ 0, 1, 1, 1], [1, 3, 3, 4]]. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True + return = [[ 0, 1, 1, 1], [0, 1, 1, 2]] + + Note that in the N-D boundaries case, the shape of "values" and + "boundaries" must match in every dimension _except_ the last. + + When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]]. + When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]). + + Boundaries must be non-decreasing, or a sorter must be provided which + would re-index offsets in a non-decreasing order (e.g. the second output + of torch.sort(offsets)). Otherwise, the result is undefined. + """ + return self.kernel.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c45cd32981418fe1121c47c78aaac35b0e65b2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,5826 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import math +import operator +import re +import sys +import warnings +from collections.abc import Callable, Sequence +from enum import Enum +from typing import Any, cast, Optional, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._prims_common import is_float_dtype, is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT + +from ..._dynamo.utils import counters +from .. import config, cpp_builder, cpu_vec_isa, ir, metrics +from ..debug import set_kernel_post_grad_provenance_tracing +from ..loop_body import LoopBody +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + ExternKernelSchedulerNode, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_bounds_index_expr, + get_fused_kernel_name, + has_free_symbols, + is_multi_outputs_template, + is_welford_reduction, + parallel_num_threads, + Placeholder, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from ..virtualized import NullKernelHandler, ops, OpsValue, V +from .common import ( + BackendFeature, + BracesBuffer, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) +from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, + cexpr, + cexpr_index, + codegen_rand, + CppCSEVariable, + DTYPE_TO_CPP, + get_promote_dtype, + INDEX_TYPE, + LocalBufferContext, + may_unify_binary_op_mask_type, + promote_args, + template_fusion_with_epilogues_supported, + unify_mask_base_type, + value_to_cpp, +) + + +_IS_WINDOWS = sys.platform == "win32" + + +@functools.cache +def get_export_declaration(): + return "__declspec(dllexport)" if _IS_WINDOWS else "" + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +NATIVE_OMP_RTYPES = OrderedSet(["+", "*", "^", "||", "min", "max"]) +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = OrderedSet( + [ + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + "argmin", + "argmax", + "any", + ] +) + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "std::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + +VECTORIZABLE_DTYPES: list[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, +] + + +def reduction_init(reduction_type, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in ("max", "argmax", "min", "argmin"): + cdtype = DTYPE_TO_CPP[dtype] + if dtype == torch.bool and reduction_type in ("argmin", "argmax"): + cdtype = DTYPE_TO_CPP[torch.float] + min_var = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + max_var = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + init_var = min_var if reduction_type in ("max", "argmax") else max_var + return ( + init_var + if reduction_type in ("max", "min") + else f"IndexValue<{cdtype}>{{0, {init_var}}}" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + if reduction_type in ("argmin", "argmax"): + if dtype == torch.bool: + scalar_type = DTYPE_TO_CPP[torch.float] + return f"IndexValue<{scalar_type}>" + return scalar_type + + +def reduction_combine( + reduction_type, + var, + next_value, + helper_val=None, + index: Optional[sympy.Symbol] = None, + src_dtype=None, +): + is_bool = src_dtype == torch.bool + if reduction_type == "sum": + if helper_val: + return f"cascade_sum_combine({next_value}, &{helper_val})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + if helper_val: + return f"welford_combine({var}, {next_value}, &{helper_val})" + else: + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + if reduction_type in ("argmin", "argmax"): + if ( + hasattr(next_value, "dtype") + and next_value.dtype == torch.bool + and not next_value.is_vec + ): + if index is not None: + return f"{reduction_type}_combine({var}, static_cast({next_value}), {index})" + else: + return ( + f"{reduction_type}_combine({var}, static_cast({next_value}))" + ) + if index is not None: + return f"{reduction_type}_combine({var}, {next_value}, {index})" + else: + return f"{reduction_type}_combine({var}, {next_value})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in ("argmin", "argmax"): + return f"{acc}.index" + return acc + + +def move_code_under_inner_loop( + code: IndentedBuffer, + iter_var: sympy.Expr, + new_iter_var: str, + loop_start: sympy.Expr, + loop_end: sympy.Expr, +) -> BracesBuffer: + r""" + f(iter_var) is transformed to f(new_iter_var) under the inner loop + \/ + for (new_iter_var = loop_start; new_iter_var < loop_end; new_iter_var++) { + f(new_iter_var) + } + Please be careful while using this function, + as the variable defined in f(iter_var) will be invalid outside the for loop. + For example: + auto tmp0 = in_ptr[x0]; -> + for (new_x0 = start; new_x0 < end; new_x0++){ + auto tmp0 = in_ptr[new_x0]; + } + The tmp0 is invalid outside the loop. + """ + transformed_code = BracesBuffer() + with contextlib.ExitStack() as stack: + transformed_code.writeline( + f"for ({INDEX_TYPE} {new_iter_var} = {cexpr_index(loop_start)};" + + f"{new_iter_var} < {cexpr_index(loop_end)}; {new_iter_var}++)" + ) + stack.enter_context(transformed_code.indent()) + for _, line in enumerate(code._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + deferred_name = None + if isinstance(line, DeferredLine): + deferred_name = line.name + line = line.line + new_line = re.sub(r"\b" + f"{iter_var}" + r"\b", f"{new_iter_var}", line) + if deferred_name: + new_line = DeferredLine(deferred_name, new_line) # type: ignore[assignment] + transformed_code.writeline(new_line) + return transformed_code + + +def reduction_prefix_array( + acc_var: Union[str, CSEVariable], + acc_type: str, + reduction_type: str, + dtype: torch.dtype, + len: Union[str, int], + init_fn, +): + """ + MSVC don't support dynamic array(VLA). So we use std::unique_ptr here. + Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler + MSVC is the only one compiler without VLA. support. Since MSVC can't get good performance here. + We just use unique_ptr make it works on MSVC. + For other compilers, we continue to use VLA to get best performance. + """ + code_buffer = IndentedBuffer() + acc_decl = ( + f"auto {acc_var}_arr = std::make_unique<{acc_type}[]>({len});" + if cpp_builder.is_msvc_cl() + else f"{acc_type} {acc_var}_arr[{len}];" + ) + code_buffer.writeline(f"{acc_decl}") + code_buffer.writelines( + [ + f"for (int i = 0; i < {len}; i++)", + "{", + f" {acc_var}_arr[i] = {init_fn(reduction_type, dtype)};", + "}", + ], + ) + return code_buffer + + +def replace_acc_name(buffer: IndentedBuffer, name: str, new_name: str): + for i, line in enumerate(buffer._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + if isinstance(line, DeferredLine): + line.line = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line.line) + else: + buffer._lines[i] = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line) + + +def replace_cascade_sum_with_add(buffer: IndentedBuffer): + """ + Replaces `acc = cascade_sum_combine(value, ...)` with `acc = acc + value;` + """ + + pattern = r"(.*?)\s*=\s*cascade_sum_combine\(([^,]+),.*?\);" + for i, line in enumerate(buffer._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + content = line.line if isinstance(line, DeferredLine) else line + match = re.search(pattern, content) + if match: + acc, value = match.groups() + new_content = re.sub(pattern, f"{acc} = {acc} + {value};", content) + if isinstance(line, DeferredLine): + line.line = new_content + else: + buffer._lines[i] = new_content + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + if not index.has(var): + # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu + # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation. + # in this case, there is no dependencies between index and var. + return sympy.S.Zero + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor", integer=True) + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus", integer=True) + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) + + +@dataclasses.dataclass +class ParallelDepth: + """ + A class representing parallel depth. + Includes the starting depth of parallelism and the depth of parallelism. + """ + + parallel_depth: int + start_depth: int + + +class OuterLoopFusedSchedulerNode(FusedSchedulerNode): + @classmethod + def fuse( # type: ignore[override] + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth + ): + assert node1.scheduler is node2.scheduler + assert all( + type(node) + in ( + OuterLoopFusedSchedulerNode, + SchedulerNode, + FusedSchedulerNode, + ) + for node in (node1, node2) + ) + if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return cls( + node1.scheduler, + # pyrefly: ignore [bad-argument-type] + ( + list(node1.get_outer_nodes()) + if type(node1) is OuterLoopFusedSchedulerNode + else [ + node1, + ] + ) + + ( + list(node2.get_outer_nodes()) + if type(node2) is OuterLoopFusedSchedulerNode + else [ + node2, + ] + ), + outer_loop_fusion_depth, + ) + else: + return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth) # type: ignore[list-item] + + def __init__( + self, + scheduler: "Scheduler", + outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]], + outer_loop_fusion_depth, + ): + self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = ( + outer_fused_nodes + ) + self.outer_loop_fusion_depth = outer_loop_fusion_depth + flatten_snodes = [] + for _node in self.outer_fused_nodes: + assert isinstance(_node, (SchedulerNode, FusedSchedulerNode)) + flatten_snodes.extend(list(_node.get_nodes())) + super().__init__(scheduler, flatten_snodes) # type: ignore[arg-type] + + def get_outer_nodes(self): + return self.outer_fused_nodes + + def check_outer_fusion_loop_level_attr( + self, cpp_kernel_proxy_list, outer_loop_fusion_depth + ): + # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth. + # In the fusion stage, we only examine nodes with same vars and reduce. + # However, for nodes with same vars and reduce, the loops may still have different tile splits. + # For example (test_expr_vec_non_contiguous in test_cpu_repro.py): + # * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level. + # If the check failed, we should fall back to standard loop codegen. + def _inner( + left_loop_nest: LoopNest, + right_loop_nest: LoopNest, + loop_fusion_depth: int, + current_checking_depth: int, + ) -> bool: + assert left_loop_nest.loops + assert right_loop_nest.loops + left_loop_level = left_loop_nest.loops[current_checking_depth] + right_loop_level = right_loop_nest.loops[current_checking_depth] + # Check if same loop level attr + outer_loops_attr_compare_list = [ + "var", + "size", + "offset", + "steps", + ] + if not ( + all( + getattr(left_loop_level, attr_compare) + == getattr(right_loop_level, attr_compare) + for attr_compare in outer_loops_attr_compare_list + ) + ): + return False + + assert loop_fusion_depth >= 1 + if (loop_fusion_depth := loop_fusion_depth - 1) > 0: + # Check next loop level attr + current_checking_depth = current_checking_depth + 1 + assert current_checking_depth < len(left_loop_nest.loops) + assert current_checking_depth < len(right_loop_nest.loops) + if not _inner( + left_loop_nest, + right_loop_nest, + loop_fusion_depth, + current_checking_depth, + ): + return False + + return True + + for idx in range(len(cpp_kernel_proxy_list) - 1): + left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest + right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest + if not _inner( + left_loop_nest, + right_loop_nest, + outer_loop_fusion_depth, + 0, + ): + return False + + for cpp_kernel_proxy in cpp_kernel_proxy_list: + outer_ranges = functools.reduce( + operator.mul, + cpp_kernel_proxy.ranges[:outer_loop_fusion_depth], + ) + # When the range of the first inner loop is much larger than the range of + # all outer loops, do not fuse outer loop and fallback to standard loop codegen, + # so that the inner loops with larger range have a chance to be parallelized. + # We set a conservative threshold here: + # First inner loop range / all outer loops range > 300. + if ( + len(cpp_kernel_proxy.ranges) > outer_loop_fusion_depth + and isinstance(outer_ranges, sympy.Integer) + and isinstance( + cpp_kernel_proxy.ranges[outer_loop_fusion_depth], + sympy.Integer, + ) + and outer_ranges * 300 + < cpp_kernel_proxy.ranges[outer_loop_fusion_depth] + ): + return False + + return True + + def merge_outer_fusion_kernels( + self, + cpp_kernel_proxy_list, + ): + kernel_group = cpp_kernel_proxy_list[0].kernel_group + outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group) + outer_loop_fused_kernel.inner = [ + proxy.loop_nest.from_loop_level(self.outer_loop_fusion_depth) + for proxy in cpp_kernel_proxy_list + ] + outer_fused_proxy = cpp_kernel_proxy_list[0] + outer_fused_proxy.loop_nest.kernel = outer_loop_fused_kernel + outer_fused_proxy.loop_nest.loops = outer_fused_proxy.loop_nest.loops[ + : self.outer_loop_fusion_depth + ] + return outer_fused_proxy + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +def decltype_promoted(*args): + assert not any(isinstance(arg, CppCSEVariable) and arg.is_vec for arg in args), ( + "Promotion of vector types is not supported" + ) + + if (dt := get_promote_dtype(args)) is not None: + return DTYPE_TO_CPP[dt] + else: + return f"decltype({args[0]})" + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"{decltype_promoted(a, b)}({a} + {b})" + + @staticmethod + def sub(a, b): + return f"{decltype_promoted(a, b)}({a} - {b})" + + @staticmethod + def mul(a, b): + return f"{decltype_promoted(a, b)}({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): + assert isinstance(x, CppCSEVariable) + if src_dtype is None: + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in DTYPE_LOWP_FP and src_dtype == torch.float: + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + """ + On windows std::signbit only support float type. + Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170 + """ + return ( + f"std::signbit(static_cast({x}))" + if _IS_WINDOWS + else f"std::signbit({x})" + ) + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): + return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape) + mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape) + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.put(cache_key, cse_var) + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def log2(x): + return f"std::log2({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? {b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + idx_str = cexpr(V.kernel.rename_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + code = BracesBuffer() + code.writeline("[&]()") + with code.indent(): + scalar_t = DTYPE_TO_CPP[a.dtype] + code.writeline( + f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT;" + ) + code.writeline( + f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))" + ) + with code.indent(): + code.writeline(f"return decltype({a})(0);") + code.writeline( + f"return decltype({a})(static_cast>({a}) << {b});" + ) + code.writeline("()") + return code + + @staticmethod + def bitwise_right_shift(a, b): + code = BracesBuffer() + code.writeline("[&]()") + with code.indent(): + scalar_t = DTYPE_TO_CPP[a.dtype] + code.writeline( + f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT - std::is_signed_v<{scalar_t}>;" + ) + code.writeline( + f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))" + ) + with code.indent(): + code.writeline(f"return decltype({a})({a} >> max_shift);") + code.writeline(f"return decltype({a})({a} >> {b});") + code.writeline("()") + return code + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + def partial_accumulate( + self, + name: str, + reduction_type: str, + value: CSEVariable, + extra_meta: dict[str, Any], + ) -> None: + raise NotImplementedError + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. + # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. + def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, (int, sympy.Expr)) + or (isinstance(arg, CppCSEVariable) and not arg.is_vec) + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + new_args = [] + for arg in args: + if isinstance(arg, (int, sympy.Expr)): + if isinstance(arg, sympy.Expr) and not arg.is_number: + arg = ops.index_expr(arg, torch.int64) + else: + arg = ops.constant(arg, torch.int64) + arg = arg.value if isinstance(arg, OpsValue) else arg + new_args.append(arg) + + # DType Promotion + if vectors: + # We have saw several data type mismatch issues related with index_expr in + # the lowering phase of torch.int8. torch.int32, torch.int64. + # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu + # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu + # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu + if len(new_args) == 2: + new_args = promote_args(new_args) + elif func is CppVecOverrides.where: + new_args[1:] = promote_args(new_args[1:]) + + # Broadcast scalar args to vector + if scalars and vectors: + assert isinstance(V.kernel, CppVecKernel) + new_args = [ + ( + V.kernel.broadcast(new_arg) + if ( + isinstance(new_arg, CppCSEVariable) + and not new_arg.is_vec + and func + not in [ + CppVecOverrides.rand, + CppVecOverrides.randn, + CppVecOverrides.randint64, + ] + ) + else new_arg + ) + for new_arg in new_args + ] + + if vectors: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr(scalar_ops, func.__name__) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) is staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for a better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})" + + @staticmethod + def ne(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + + @staticmethod + def lt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})" + + @staticmethod + def gt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})" + + @staticmethod + def le(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})" + + @staticmethod + def ge(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + + @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod + def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"~{a}" + + @staticmethod + def logical_or(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} ^ {b}" + + @staticmethod + def bitwise_and(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def load_seed(name, offset): + assert isinstance(V.kernel, CppVecKernel) + return f"{V.kernel.load(name, offset)}" + + @staticmethod + def rand(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = ( + f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);" + ) + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randn(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);" + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randint64(seed, offset, low, high): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});" + return codegen_rand(offset, code, rand_function, torch.int64) + + @staticmethod + def remainder(a, b): + assert a.dtype == b.dtype, ( + "remainder vec implementation expect the same inputs' dtype." + ) + return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + if config.cpp.use_decompose_tanh: + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return ( + f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + ) + else: + return f"{a}.tanh()" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def log2(x): + return f"{x}.log2()" + + @staticmethod + def nextafter(x, y): + return f"{x}.nextafter({y})" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + return f"{x}.asinh()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + if is_float_dtype(a.dtype): + assert a.dtype == b.dtype, ( + "div_floor_floating_vec implementation expect the same inputs' dtype." + ) + return f"div_floor_floating_vec({a}, {b})" + else: + assert all(is_integer_dtype(item.dtype) for item in [a, b]) + # a and b are integer type + _t = f"decltype({a})" + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + _t = f"decltype({b})" + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(V.kernel, CppVecKernel) + if b.dtype == torch.bool: + assert c.dtype == torch.bool + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + else: + return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True): + assert dtype in [ + torch.bool, + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], f"{__name__} does not support {dtype}" + assert isinstance(x, CppCSEVariable) + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in DTYPE_LOWP_FP and src_dtype == torch.float: + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + dtype = result.dtype + body_code = f"{var}()" + + def maskify_or_vecify(code): + return ( + f"{V.kernel._get_mask_type()}::from({code})" + if dtype == torch.bool + else f"{V.kernel._get_vec_type(dtype)}({code})" + ) + + if result.is_vec: + body_code_vec = body_code + else: + body_code_vec = maskify_or_vecify(body_code) + other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) + # loading bool as VecMask + other_code_vec = maskify_or_vecify(other_code) + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec: + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if ({new_mask}.all_zero())") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + # Create cse variable to reuse kernel.overrides.where + body_vec_var = V.kernel.cse.generate( + V.kernel.compute, + body_code_vec, + ) + other_vec_var = V.kernel.cse.generate( + V.kernel.compute, + other_code_vec, + ) + assert isinstance(body_vec_var, CppCSEVariable), body_vec_var + assert isinstance(other_vec_var, CppCSEVariable), other_vec_var + body_vec_var.dtype = dtype + other_vec_var.dtype = dtype + overrides: type[Union[CppOverrides, CppVecOverrides]] = ( + # pyrefly: ignore [bad-assignment] + V.kernel.overrides + ) # type: ignore[has-type] + code.writeline( + f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + result.is_vec = True + elif result.is_vec: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. + csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = V.kernel._try_get_const_stride(index, tiling_var) + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + elif stride is not None: + idx = V.kernel.cse.generate( + V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr) + ) + value = ops.to_dtype(idx, dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment] + None, index, dtype, V.kernel.compute + ) + # pyrefly: ignore [missing-attribute] + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): + return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline( + f"at::vec::Vectorized {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN {exponent};" + ) + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" + ) + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.put(cache_key, cse_var) + return mantissa, exponent + + @classmethod + def _scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) + with code.indent(): + for argidx, arg in enumerate(args): + if isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" + ) + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" + if output_mask: + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + elif n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" + else: + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + vec_vars = vars(CppVecOverrides) + for name, method in vars(CppOverrides).items(): + if isinstance(method, staticmethod) and name not in vec_vars: + func = cls._scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + + +CppVecOverrides._initialize_pointwise_overrides("cppvec") +CppVecOverrides._initialize_scalarize() + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + +class CppKernel(Kernel): + """ + Base class for C++ kernel code generation in PyTorch Inductor. + This class is responsible for generating C++ code from the intermediate representation. + + Args: + args: Kernel arguments used for code generation + num_threads: Number of threads for parallel execution + """ + + overrides = CppOverrides # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + # Indicate when this kernel is active, for example + # {x0, {24, 26}} -> this kernel is active when x0 >= 24 and x0 < 26 + self.active_ranges: dict[sympy.Expr, tuple[sympy.Expr, ...]] = {} + # Indicate this kernel will be moved under the inner for-loop + # See move_code_under_inner_loop + self.inner_itervars: list[sympy.Symbol] = [] + self.call_ranges: Optional[tuple[sympy.Expr, ...]] = None + self.ranges: list[sympy.Expr] = [] + self.itervars: list[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + # We need this because when we run "reduction" nodes here, we lack + # "loop" information to decide whether we need a scalar init or an array init + # in the reduction prefix. Meanwhile, we have other information like + # reduction types and dtype to generate the reduction prefix. We record the information + # with a callable lambda function, and when we have enough information to finalize + # the reduction prefix, we can invoke the functions here with additional information. + self.reduction_prefix_generators: list[Callable] = [] # type: ignore[type-arg] + self.reduction_suffix = IndentedBuffer() + self.parallel_reduction_prefix = IndentedBuffer() + self.parallel_reduction_suffix = IndentedBuffer() + self.local_reduction_init = IndentedBuffer() + self.local_reduction_stores = IndentedBuffer() + self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() + self.non_parallel_reduction_suffix = IndentedBuffer() + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.welford_helper_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="welford_helper" + ) + self.cascade_helper_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="cascade_helper" + ) + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads # num_threads the kernel specialized for + self.reduction_omp_dec: dict[tuple[str, str], str] = {} + self.reduction_var_names: list[str] = [] + + def _gen_parallel_reduction_buffers( + self, + acc, + acc_type, + reduction_type, + dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + ): + if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: + self.parallel_reduction_prefix.writeline( + "int max_threads = omp_get_max_threads();" + ) + acc_local = f"{acc}_local" + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + acc_local_in_array = f"{acc}_arr[tid]" + self.local_reduction_init.writeline( + f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};" + ) + self.parallel_reduction_prefix.splice( + reduction_prefix_array( + acc, + acc_type, + reduction_type, + dtype, + num_threads, + reduction_init_fn, + ) + ) + self.local_reduction_stores.writeline(f"{acc_local_in_array} = {acc_local};") + self.parallel_reduction_suffix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};", + "}", + ], + ) + + def update_stores_with_parallel_reduction(self): + for var_name in self.reduction_var_names: + replace_acc_name(self.stores, var_name, f"{var_name}_local") + + def gen_body(self, code: Optional[BracesBuffer] = None): + assert code is None + code = BracesBuffer() + with contextlib.ExitStack() as stack: + if hasattr(self, "codegen_inner_loops"): + code.splice(self.preloads) + self.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(self.loads) + code.splice(self.compute) + code.splice(self.stores) + if hasattr(self, "codegen_inner_loops"): + code.splice(self.poststores) + + if self.inner_itervars: + for idx in self.inner_itervars: + start, end = self.active_ranges[idx] + code = move_code_under_inner_loop(code, idx, f"{idx}_tail", start, end) + return code + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + # pyrefly: ignore [bad-assignment] + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. + """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def var_ranges(self): + return dict(zip(self.itervars, self.ranges)) + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + indirect = free_symbol_is_type(expr, SymT.TMP) + if indirect: + # indexing in compute + csevar = ops.index_expr(expr, torch.int64).value + buffer = V.kernel.compute + else: + # indexing in loads + prior_compute = V.kernel.compute + try: + V.kernel.compute = self.loads + csevar = ops.index_expr(expr, torch.int64).value + finally: + V.kernel.compute = prior_compute + buffer = self.loads + + size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None + + line = self.indirect_assert( + csevar, "0" if lower else None, size_str, self._load_mask + ) + self.cse.generate(buffer, line, assignment=False) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + csevar = self.cse.generate(self.loads, line, dtype=V.graph.get_dtype(name)) + csevar.update_on_args("load", (self, name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def device_assert_async(self, cond, msg): + self.compute.writeline( + f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));' + ) + + def _gen_reduction_prefix( + self, + acc: Union[CSEVariable, str], + acc_type: str, + rtype: str, + dtype: torch.dtype, + init_fn, + ): + # Generate reduction prefix + # If size is None, we will define and initialize a single reduction variable + # => float tmp_acc0 = 0; + # Otherwise, we will define and initialize a reduction array + # => float tmp_acc0_arr[size]; + # => for (int i = 0; i < size; i++) tmp_acc0_arr[i] = 0; + def inner(size: Optional[int] = None): + if size is None: + return f"{acc_type} {acc} = {init_fn(rtype, dtype)};" + else: + return reduction_prefix_array( + acc, + acc_type, + rtype, + dtype, + size, + init_fn, + ) + + return inner + + def finalize_reduction_prefix(self, size: Optional[int] = None): + for gen_fn in self.reduction_prefix_generators: + self.reduction_prefix.splice(gen_fn(size)) + + def need_use_acc_helper(self, reduction_type, dtype, use_scalar): + # Check if we need accumulate helper for the reduction operation. + # using accumulate helper generates the necessary code to improve precision for + # sum and welford + # Note: using helper has non-negligible impact on performance + + if reduction_type == "welford_reduce": + return True + + # TODO add supports for more data types when needed + if reduction_type == "sum" and dtype == torch.float: + assert self.call_ranges is not None + reduction_size = functools.reduce( + operator.mul, self.call_ranges[self.reduction_depth :] + ) + + # chunk size to balance accuracy and performance + chunk_size = 4096 + + # use acc helper If cannot get size_hint + try: + reduction_size_hint = V.graph.sizevars.size_hint(reduction_size) + except Exception: + return True + + if reduction_size_hint > chunk_size: + # use helper if the reduction size is too large + V.graph.sizevars.check_lt(chunk_size, reduction_size) + return True + else: + V.graph.sizevars.check_leq(reduction_size, chunk_size) + return False + + def _acc_helper_init( + self, + reduction_type, + helper_val, + helper_range, + dtype, + num_threads=None, + use_scalar=False, + ): + num_range_thread = ( + CeilDiv(helper_range, num_threads) if num_threads else helper_range + ) + num_range_thread_expr = cexpr_index(num_range_thread) + assert reduction_type in ["welford_reduce", "sum"] + chunk_size = 4096 + num_chunks = CeilDiv(num_range_thread, chunk_size) + helper_type = ( + "WelfordHelper" + if reduction_type == "welford_reduce" + else "CascadeSumHelper" + ) + if use_scalar: + h_type = DTYPE_TO_CPP[dtype] + else: + h_type = ( + self._get_vec_type(dtype) + if hasattr(self, "_get_vec_type") + else DTYPE_TO_CPP[dtype] + ) + helper_init_line = ( + f"{helper_type}<{h_type}, {chunk_size}> {helper_val}" + f"(" + f"{num_range_thread_expr}" + f");" + ) + if reduction_type == "sum": + return helper_init_line + if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1: + # When the number of chunks <= 1, there is no need to use cascade summation to improve + # reduction accuracy. We can initialize a static WelfordHelper to improve performance. + return f"static {helper_init_line}" + else: + return helper_init_line + + def _use_acc_helper( + self, reduction_type, acc, helper_val, helper_range, dtype, use_scalar=False + ): + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + self.non_parallel_reduction_prefix.writeline( + self._acc_helper_init( + reduction_type, helper_val, helper_range, dtype, None, use_scalar + ) + ) + self.local_reduction_init.writeline( + self._acc_helper_init( + reduction_type, helper_val, helper_range, dtype, num_threads, use_scalar + ) + ) + result = acc if use_scalar else f"{acc}_vec" + if reduction_type == "welford_reduce": + self.non_parallel_reduction_suffix.writeline( + f"{result} = welford_combine({result}, &{helper_val});" + ) + self.local_reduction_stores.writeline( + f"{result}_local = welford_combine({result}_local, &{helper_val});" + ) + else: + self.non_parallel_reduction_suffix.writeline( + f"{result} = cascade_sum_final(&{helper_val});" + ) + self.local_reduction_stores.writeline( + f"{result}_local = cascade_sum_final(&{helper_val});" + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in ("argmax", "argmin") + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.reduction_var_names.append(f"{acc}") + self.is_reduction = True + init_dtype = src_dtype if argmax_or_argmin else dtype + acc_type = reduction_acc_type(reduction_type, init_dtype) + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc, acc_type, reduction_type, init_dtype, reduction_init + ) + ) + + if self.need_use_acc_helper(reduction_type, dtype, True): + # use cascade_helper for vec kernel + reduction_size = functools.reduce( + operator.mul, self.ranges[self.reduction_depth :] + ) + # use welford_helper/cascade_helper for vec kernel + if reduction_type == "welford_reduce": + helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + else: + helper_val = self.cascade_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + # rename the helper variable to distinguish it from vectorized version + scalar_helper_val = f"scalar_{helper_val}" + self._use_acc_helper( + reduction_type, + acc, + scalar_helper_val, + reduction_size, + dtype, + use_scalar=True, + ) + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, scalar_helper_val)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index=index)};" + ) + + self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), ( + f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + ) + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(self.ranges)) + ] + # pyrefly: ignore [bad-assignment] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + assert self.call_ranges is not None + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + assert isinstance(self, CppKernelProxy) + threads = parallel_num_threads() + assert self.call_ranges is not None + if isinstance(loop_nest.kernel, OuterLoopFusedKernel): + par_depth = loop_nest.kernel.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + else: + par_depth = self.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + + is_reduction_loop = ( + loop_nest.loops is not None + and loop_nest.loops[par_depth.start_depth].is_reduction + ) + with contextlib.ExitStack() as stack: + if par_depth.parallel_depth: + if is_reduction_loop: + # need to close the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_kernel(_loop_nest: LoopNest): + def is_parallel_reduction(): + assert _loop_nest.loops + root = _loop_nest.loops[par_depth.start_depth] + return root.is_reduction and root.parallel + + kernel = _loop_nest.get_kernel() + if isinstance(kernel, OuterLoopFusedKernel): + for _loop_nest in kernel.inner: + gen_loop_nest(_loop_nest) + else: + assert isinstance(kernel, CppKernelProxy) + if _loop_nest.loops is not None and is_parallel_reduction(): + kernel.update_stores_with_parallel_reduction() + with contextlib.ExitStack() as stack: + stack.enter_context(code.indent()) + kernel.gen_body(code) + + def get_reduction_prefix_suffix(kernel, parallel=False, is_suffix=False): + if is_suffix: + suffix = kernel.reduction_suffix + if parallel: + suffix = kernel.parallel_reduction_suffix + suffix + else: + suffix = kernel.non_parallel_reduction_suffix + suffix + return suffix + else: + prefix = kernel.reduction_prefix + if parallel: + prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix + return prefix + + def gen_loop_with_reduction( + _loop_nest: LoopNest, depth: int = 0, in_reduction=False + ): + kernel = _loop_nest.get_kernel() + assert _loop_nest.loops + loop = _loop_nest.loops[depth] + with contextlib.ExitStack() as stack_outer: + if loop.is_reduction and not in_reduction: + reduction_prefix = get_reduction_prefix_suffix( + kernel, loop.parallel, is_suffix=False + ) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if is_reduction_loop and loop.parallel: + worksharing.parallel(threads) + if kernel.local_reduction_init: + assert kernel.local_reduction_stores + code.splice(kernel.local_reduction_init) + + gen_loop_at(_loop_nest, depth) + + if is_reduction_loop and loop.parallel: + if kernel.local_reduction_stores: + code.splice(kernel.local_reduction_stores) + worksharing.close() + if loop.is_reduction and not in_reduction: + code.splice( + get_reduction_prefix_suffix( + kernel, loop.parallel, is_suffix=True + ) + ) + + def gen_loop_at(_loop_nest: LoopNest, depth: int = 0): + with contextlib.ExitStack() as stack: + assert _loop_nest.loops + loop = _loop_nest.loops[depth] + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + gen_loop_nest(_loop_nest, depth + 1, loop.is_reduction) + + def gen_loop_nest( + _loop_nest: LoopNest, + depth: int = 0, + in_reduction: bool = False, + ): + if _loop_nest.loops is None or depth == len(_loop_nest.loops): # type: ignore[arg-type] + gen_kernel(_loop_nest) + else: + gen_loop_with_reduction(_loop_nest, depth, in_reduction) + + stack.enter_context(code.indent()) + + if ( + isinstance(loop_nest.kernel, OuterLoopFusedKernel) + and isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + # Allocate local buffer + local_buffers = V.local_buffer_context.local_buffers + for local_buffer in local_buffers.values(): + # For dynamic size, rename s to ks + local_buf_size = sympy_product( + [ + self.rename_indexing(size_val) + for size_val in local_buffer.get_layout().size + ] + ) + local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] + allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})" + local_buffer_name = local_buffer.get_name() + code.splice( + f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};" + ) + code.splice( + f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();" + ) + gen_loop_nest(loop_nest) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNest.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, max_parallel_depth, threads): + assert self.call_ranges is not None + ranges = self.call_ranges[ + max_parallel_depth.start_depth : ( + max_parallel_depth.start_depth + max_parallel_depth.parallel_depth + ) + ] + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. + if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return ParallelDepth( + parallel_depth=depth, start_depth=max_parallel_depth.start_depth + ) + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + def get_to_dtype_expr(self, src, dtype, src_dtype): + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})" + + def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype): + expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype) + self.cse.put(expr, dst) + + def codegen_conditions( + self, + code: BracesBuffer, + prefix: Optional[str] = None, + var: Optional[sympy.Symbol] = None, + ): + if prefix is None: + prefix = "" + if not self.active_ranges: + return True + conditions = [] + + def gen(start, end, var): + if start == end: + return False + var_id = None + for i, _var in enumerate(self.itervars): + if var == _var: + var_id = i + break + if ( + type(self) is CppKernel + and var_id + and start == 0 + and end == self.ranges[var_id] + ): + end = 1 + # pyrefly: ignore [bad-argument-type] + conditions.append(f"{var} >= {cexpr_index(start)}") + # pyrefly: ignore [bad-argument-type] + conditions.append(f"{var} < {cexpr_index(end)}") + return True + + if var is not None: + assert var in self.active_ranges + start, end = self.active_ranges[var] + if not gen(start, end, var): + return False + else: + for _var, _range in self.active_ranges.items(): + start, end = _range + if not gen(start, end, _var): + return False + joined_conditions = " && ".join(conditions) + if joined_conditions: + code.writeline(f"if({prefix}({joined_conditions}))") + return True + else: + return False + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_idx, + tail_size=None, + ): + super().__init__(args, num_threads) + self.vec_isa = cpu_vec_isa.pick_vec_isa() + assert self.vec_isa + assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly" + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + self.tail_size = tail_size + self.num_elems = tail_size if tail_size else tiling_factor + + def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol): + if self.index_indirect_depends_on(index, itervar): + return None + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + return None + stride = stride_at_vec_range(index, itervar, self.tiling_factor) + return stride if stride.is_number else None + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_raw_num_vectors(self, dtype: torch.dtype) -> float: + # This utility function is used to check if the vector lanes has been + # fully utilized. For example, uint8 will only use 1/4 of the vector lanes. + return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str: + if dtype == torch.bool: + return "" + num_vectors = self._get_num_vectors(dtype) + return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str: + assert mask.dtype == torch.bool, repr(mask) + num_vectors = self._get_num_vectors(dtype) + return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()" + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + cpp_type = DTYPE_TO_CPP[dtype] + num_vectors = self._get_num_vectors(dtype) + load_mask_str = None + if load_mask: + if not load_mask.is_vec: + # TODO: avoid hard-code torch.float + load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})" + else: + load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}" + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype == torch.bool: + # TODO: should we consider load mask here? + line = f"{self._get_mask_type()}::from({loadbuf}, {cexpr_index(self.num_elems)})" + else: + line = ( + f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" + ) + return line + + def _load_or_store_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, + ) -> Optional[CppCSEVariable]: + """ + Load or store a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. `transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. + :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided + :return: a CppCSEVariable that represents the loaded vector or None if it is a store. + """ + assert not store_value or var is not None, "store var must be provided" + if accu_store: + assert store_value + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.num_elems * (4 // dtype.itemsize) + else: + return self.num_elems + + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) + result_declare = ( + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" + ) + code.writeline(result_declare) + if store_value: + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + load_mask = None + if self._load_mask is not None: + assert not store_value, "unexpected store with load mask" + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = f"{self._load_mask}.is_masked({itervar_inner})" + else: + load_mask = f"{self._load_mask} != 0" + if cpp_builder.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + + f"{itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + index_c = cexpr_index(index) + for indirect_var in replacements: + index_c = re.sub( + r"\b" + f"{indirect_var}" + r"\b", + replacements[indirect_var], + index_c, + ) + rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + if store_value: + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") + else: + code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + if not store_value: + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + if store_value: + code.writeline(";") + buffer.splice(code) + return None + else: + csevar = self.cse.generate(buffer, code, dtype=dtype) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = self._try_get_const_stride(index, tiling_var) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + elif stride == 1: + # load contiguously + line = self._get_vec_load_line(var, index, dtype, self._load_mask) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line, dtype=dtype) # type: ignore[assignment] + else: + csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (self, name, index), {}) + csevar.is_vec = True + return csevar + + def _get_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + accu_store: bool = False, + ): + """ + Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles + both contiguous and non-contiguous store cases. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. + """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + var_expr = f"{var} + {cexpr_index(index)}" + stride = self._try_get_const_stride(index, tiling_var) + code = IndentedBuffer() + if stride == 1: + if accu_store: + load = ( + f"{self._get_vec_type(dtype)}::loadu({var_expr})" + if dtype == torch.float and self.tail_size is None + else f"{self._get_vec_type(dtype)}::loadu({var_expr}, {cexpr_index(self.num_elems)})" + ) + value = f"({value} + {load})" + if dtype == torch.float and self.tail_size is None: + code.writeline(f"{value}.store({var_expr});") + else: + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) + else: + self._load_or_store_non_contiguous( + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store + ) + return code + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + var = self.args.output(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert isinstance(index, CppCSEVariable) and index.is_vec + if self.tail_size: + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value}, {cexpr_index(self.tail_size)});" + else: + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") + + def reduction(self, dtype, src_dtype, reduction_type, value): + """ + Perform vectorized reduction operation. + + This method handles vectorized reduction for different reduction types. + It manages special cases for low-precision floating point types and + employs precision improvement techniques for certain reduction operations. + + Args: + dtype: The output data type for the reduction result + src_dtype: The source data type of the input value + reduction_type: Type of reduction operation (sum, min, max, etc.) + value: The input value to reduce + + Returns: + The result of the reduction operation + """ + # Note: For argmax and argmin on bool type, we always convert bool to float. + # Fix issue: https://github.com/pytorch/pytorch/issues/143568 + assert reduction_type in VECTORIZABLE_RTYPES + argmax_or_argmin = reduction_type in ("argmax", "argmin") + horizontal_reduction = self.tiling_idx >= self.reduction_depth + init_dtype = src_dtype if argmax_or_argmin else dtype + assert isinstance(value, CppCSEVariable), value + + if not value.is_vec: + value = self.broadcast(value) + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + vec_ns = "at::vec" + vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" + acc_type = reduction_acc_type(reduction_type, init_dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype) + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + assert isinstance(acc, CppCSEVariable) + acc_vec = f"{acc}_vec" + masked_acc = f"masked_{acc}" + masked_acc_vec = f"masked_{acc_vec}" + self.reduction_var_names += [f"{acc}", acc_vec, masked_acc_vec] + self.is_reduction = True + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc, acc_type, reduction_type, init_dtype, reduction_init + ) + ) + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + self.reduction_init_vec, + ) + ) + + use_acc_helper = self.need_use_acc_helper(reduction_type, dtype, False) + if use_acc_helper: + # use masked acc_vec for tail vec kernel + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + masked_acc_vec, + acc_type_vec, + reduction_type, + dtype, + self.reduction_init_vec, + ) + ) + + # use welford_helper/cascade_helper for vec kernel + assert self.reduction_depth is not None + reduction_size = functools.reduce( + operator.mul, self.ranges[self.reduction_depth :] + ) + if reduction_type == "welford_reduce": + helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + else: + helper_val = self.cascade_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + masked_helper_val = f"masked_{helper_val}" + helper_vec_range = ( + ( + FloorDiv(reduction_size, self.ranges[self.tiling_idx]) + * FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) + if self.tiling_idx >= self.reduction_depth + else reduction_size + ) + if FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) + else sympy.Integer(0) + ) + masked_helper_vec_range = ( + ( + FloorDiv(reduction_size, self.ranges[self.tiling_idx]) + if self.tiling_idx >= self.reduction_depth + else reduction_size + ) + if self.ranges[self.tiling_idx] % self.tiling_factor + else sympy.Integer(0) + ) + # scalar helper for scalar welford_reduce/sum is also needed when vec kernel is included + scalar_helper_val = f"scalar_{helper_val}" + self._use_acc_helper( + reduction_type, + acc, + scalar_helper_val, + reduction_size, + dtype, + use_scalar=True, + ) + self._use_acc_helper( + reduction_type, acc, helper_val, helper_vec_range, dtype + ) + self._use_acc_helper( + reduction_type, + masked_acc, + masked_helper_val, + masked_helper_vec_range, + dtype, + ) + + # use masked acc_vec for tail vec kernel + acc_vec_ = masked_acc_vec if self.tail_size else acc_vec + helper_val_ = masked_helper_val if self.tail_size else helper_val + if reduction_type == "sum": + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};" + ) + else: + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + kwargs = { + "next_value": value, + "index": index, + "horizontal_reduction": horizontal_reduction, + "src_dtype": src_dtype, + } + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, **kwargs)};" + ) + self._gen_parallel_reduction_buffers( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + self._gen_parallel_reduction_buffers( + acc, + acc_type, + reduction_type, + init_dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + ) + if use_acc_helper: + # use masked acc_vec for tail vec kernel + self._gen_parallel_reduction_buffers( + masked_acc_vec, + acc_type_vec, + reduction_type, + dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + tmpvar: Union[str, CSEVariable] + is_bool = dtype == torch.bool + if horizontal_reduction: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert self._get_num_vectors(dtype) in [ + 1, + 2, + ], "Welford reduction does not support VectorizedN (N>2)" + next_value = f"welford_vec_reduce_all({acc_vec})" + masked_next_value = f"welford_vec_reduce_all({masked_acc_vec})" + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};" + ) + elif argmax_or_argmin: + next_value = f"{reduction_type}_vec_reduce_all({acc_vec})" + elif is_bool: + if reduction_type in ( + "any", + "sum", + "max", + ): + next_value = f"!{acc_vec}.all_zero()" + else: + assert reduction_type == "min" + next_value = f"{acc_vec}.all_masked()" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + is_bool = dtype == torch.bool + # we are using at::vec::VecMask for bool + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" + result_vec = f"{acc_vec}" + if use_acc_helper: + assert reduction_type == "sum" + result_vec = f"{acc_vec} + {masked_acc_vec}" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {result_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + if is_welford_reduction(reduction_type): + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" + ) + elif use_acc_helper: + assert reduction_type == "sum" + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {tmpvar} + {masked_tmpvar};" + ) + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + if out_dtype.is_floating_point and out_dtype != torch.double: + dtype = torch.float + else: + dtype = out_dtype + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) + code = IndentedBuffer() + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + code.writeline( + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});" + ) + else: + # Vertical reduction + if out_dtype != dtype: + converted_value = ( + f"{DTYPE_TO_CPP[out_dtype].replace('::', '_')}_{value}" + ) + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) + code.writeline(f"auto {converted_value} = {convert};") + value = converted_value + code.splice(self._get_store_line(value, var, index, out_dtype)) + self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"{self._get_mask_type()}::from({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable: + assert not index.is_vec + assert index.dtype is not None + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = index.dtype + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + if reduction_type in ("argmin", "argmax"): + cdtype = DTYPE_TO_CPP[scalar_type] + acc_type = self.reduction_acc_type_vec(reduction_type, dtype) + if reduction_type == "argmin": + val = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + else: + val = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + return f"{acc_type}({val})" + + if reduction_type == "any": + return f"{self._get_mask_type()}::from(0)" + + scalar_init = reduction_init(reduction_type, dtype) + vec_init = f"{vec_type}({scalar_init})" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "sum") + return f"{self._get_mask_type()}::from({scalar_init})" + return vec_init + + def reduction_acc_type_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + if reduction_type in ("argmin", "argmax"): + n_src = self._get_num_vectors(scalar_type) + n_idx = self._get_num_vectors(torch.int64) + if dtype == torch.bool: + return f"IndexValueVec<{DTYPE_TO_CPP[torch.float]}, {n_src}, {n_idx}>" + return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "any", "sum") + return f"{self._get_mask_type()}" + return vec_type + + def reduction_combine_vec( + self, + reduction_type, + var, + next_value, + helper_val=None, + index: Optional[sympy.Symbol] = None, + horizontal_reduction: Optional[bool] = None, + src_dtype: Optional[torch.dtype] = torch.float32, + ): + is_bool = src_dtype == torch.bool + if reduction_type == "max": + if self.tail_size: + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} | {next_value}" + if is_bool + else f"at::vec::maximum({var}, {next_value})" + ) + elif reduction_type == "min": + if self.tail_size: + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} & {next_value}" + if is_bool + else f"at::vec::minimum({var}, {next_value})" + ) + elif reduction_type == "sum": + if helper_val: + if self.tail_size: + return f"cascade_sum_combine({next_value}, {cexpr_index(self.tail_size)}, &{helper_val})" + else: + return f"cascade_sum_combine({next_value}, &{helper_val})" + else: + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + elif reduction_type == "prod": + if self.tail_size: + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + if self.tail_size: + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + if helper_val: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{helper_val})" + else: + return f"welford_combine({var}, {next_value}, &{helper_val})" + else: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + if self.tail_size: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + elif reduction_type in ("argmin", "argmax"): + assert src_dtype is not None + cdtype = DTYPE_TO_CPP[src_dtype] + if src_dtype == torch.bool: + cdtype = DTYPE_TO_CPP[torch.float] + n_src = self._get_num_vectors(src_dtype) + n_idx = self._get_num_vectors(torch.int64) + t_extra = "" + arg_extra = "" + if index is not None: + assert horizontal_reduction is not None + t_extra = f", {str(horizontal_reduction).lower()}" + arg_extra = f", {index}" + if self.tail_size: + return ( + f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" + ) + else: + return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" + elif reduction_type == "any": + if isinstance(next_value, CppCSEVariable): + assert next_value.dtype == torch.bool + (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,)) + if self.tail_size: + return f"any_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} | {next_value}" + else: + raise NotImplementedError + + def indirect_assert(self, var, lower, upper, mask=None): + assert isinstance(var, CppCSEVariable) + assert var.dtype is not None + if not var.is_vec: + if isinstance(mask, CppCSEVariable) and mask.is_vec: + mask = f"({mask}).all_masked()" + return super().indirect_assert(var, lower, upper, mask) + lower_scalar = lower + upper_scalar = upper + if lower: + lower = f"{self._get_vec_type(var.dtype)}({lower})" + if upper: + upper = f"{self._get_vec_type(var.dtype)}({upper})" + if lower and upper: + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower_scalar} <= {var} < {upper_scalar}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = f"{lower_scalar} <= {var}" + else: + assert upper + cond = f"{var} < {upper}" + cond_print = f"{var} < {upper_scalar}" + cond = f"{self._get_mask_type(var.dtype)}({cond})" + if mask: + if not mask.is_vec: + mask = f"{self._get_mask_type(var.dtype)}({mask})" + # We need not check when the mask is False + cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) + cond = f"({cond}).all_masked()" + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def get_to_dtype_expr(self, src, dtype, src_dtype): + assert isinstance(src, CppCSEVariable) + if not src.is_vec: + return super().get_to_dtype_expr(src, dtype, src_dtype) + src_cpp_type = DTYPE_TO_CPP[src_dtype] + src_num_vectors = self._get_num_vectors(src_dtype) + dst_cpp_type = DTYPE_TO_CPP[dtype] + dst_num_vectors = self._get_num_vectors(dtype) + expr = f"({src})" + if src_dtype != torch.bool and dtype == torch.bool: + expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})" + elif src_dtype == torch.bool and dtype != torch.bool: + expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()" + elif src_dtype != dtype: + if src_num_vectors == dst_num_vectors == 1: + expr = f"at::vec::convert<{dst_cpp_type}>({src})" + else: + expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})" + return expr + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. + + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... + """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_indices, + inner_tail_size=None, + outer_tail_size=None, + ): + super().__init__( + args, + num_threads, + tiling_factor, + tiling_indices[1], + inner_tail_size, + ) + self.tiling_indices = tiling_indices + self.inner_tail_size = inner_tail_size + self.outer_tail_size = outer_tail_size + self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor + self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor + self.inner_is_tiling_idx = True + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store( + self, name, var, index, is_store, store_mode=None + ): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{cexpr_index(self.num_elems)}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + atomic_add = "true" if (is_store and (store_mode == "atomic_add")) else "false" + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): + load_or_store = ( + f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{atomic_add}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" + ) + else: + load_or_store = ( + f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)},{atomic_add}>" + f"({src}, {ld_src}, {dst}, {ld_dst});" + ) + if is_store: + tile_var = self.cse.newvar() + elif not self.cse.contains(load_or_store): + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.get(load_or_store) + + if need_define: + cpp_dtype = DTYPE_TO_CPP[dtype] + # tiling_factor might be smaller than the alignment of cpp_dtype, such as + # with a vector that only holds 4 elements due to NEON 128-bit vectors and + # cpp_dtype being a 64-bit integer. + alignas = f"alignas(std::max(std::size_t({factor}), alignof({cpp_dtype})))" + define_line = f"{alignas} {cpp_dtype} {tile_var}[{factor}*{factor}];" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line, dtype=dtype) + csevar.update_on_args("load", (self, name, index), {}) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True, store_mode=mode + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [ + torch.uint8, + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + if self.inner_is_tiling_idx: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" + ) + else: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + if self.tiling_idx == self.tiling_indices[0]: + self.tail_size = self.outer_tail_size + self.num_elems = self.outer_num_elems + self.inner_is_tiling_idx = False + else: + self.tail_size = self.inner_tail_size + self.num_elems = self.inner_num_elems + self.inner_is_tiling_idx = True + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +def get_loop_body_lowp_fp(_body: LoopBody) -> tuple[Optional[torch.dtype], bool]: + """ + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to float. + Otherwise returns None and True. + """ + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + + _lowp_fp_type: Optional[torch.dtype] = None + _use_fp32 = False + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + _use_fp32 = True + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") + else: + _lowp_fp_type = opt_ctx.dtype + else: + _use_fp32 = True + + return _lowp_fp_type, _use_fp32 + + +class TilingSelect: + """ + Implement the heuristic to select the tiling factors and tiling indices. + In the future, we can implement advanced heuristic in a subclass. + """ + + def select_tiling( + self, + fn_list, + var_sizes_list, + ) -> tuple[list[int], list[int]]: + # TODO(jgong5): support alternative tiling factors and data types + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] + dtype = torch.float + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] + if _lowp_fp_dtype and all( + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) + for loop_body in loop_bodies[1:] + ): + dtype = _lowp_fp_dtype + + tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + tiling_indices = self._select_tiling_indices( + fn_list, var_sizes_list, tiling_factor + ) + + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + + if config.cpp.enable_tiling_heuristics: + + def _try_get_stride( + index, + itervars, + tiling_factor, + tiling_indices, + ): + itervar = itervars[tiling_indices[0]] + stride = stride_at_vec_range(index, itervar, tiling_factor) + return stride if stride.is_number else None + + def _update_negative_op_count( + node_name, non_contig_indexing_op_counter + ): + if node_name not in non_contig_indexing_op_counter: + non_contig_indexing_op_counter[node_name] = 1 + else: + non_contig_indexing_op_counter[node_name] += 1 + + def _is_valid_indices( + itervars, + tiling_indices, + ): + return ( + len(tiling_indices) == 1 + and len(itervars) > 0 + and ( + tiling_indices[0] + if tiling_indices[0] >= 0 + else tiling_indices[0] + len(itervars) + ) + < len(itervars) + ) + + itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(call_ranges)) + ] + reduction_depth = len(group) + vars, reduction_vars = ( + itervars[:reduction_depth], + itervars[reduction_depth:], + ) + op_counter: dict[str, int] = {} + # ops may cause overhead with vectorization, like non-contiguous + # index_expr, load, store + non_contig_indexing_op_counter: dict[str, int] = {} + for _body in loop_bodies: + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.target in ["index_expr", "load", "store"]: + # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 + index = sub_block.body.indexing_from_args( + (vars, reduction_vars) + )[_node.args[arg_idx].args[0]] + if _is_valid_indices(itervars, tiling_indices): + stride = _try_get_stride( + index, itervars, tiling_factor, tiling_indices + ) + if ( + stride is None + if _node.target == "index_expr" + else stride not in [0, 1] + ): + _update_negative_op_count( + _node.target, non_contig_indexing_op_counter + ) + if isinstance(_node.target, str) and not ( + _node.target.startswith("masked_subblock") + or _node.target + in ["ops", "output", "constant", "get_index"] + ): + if _node.target not in op_counter: + op_counter[_node.target] = 1 + else: + op_counter[_node.target] += 1 + + op_num = sum(op_counter.values()) + non_contig_indexing_op_num = sum( + non_contig_indexing_op_counter.values() + ) + ratio_threshold = 0.12 + quantity_threshold = 35 + if non_contig_indexing_op_num >= quantity_threshold or ( + op_num > 0 + and non_contig_indexing_op_num / op_num >= ratio_threshold + ): + # Too many non-contiguous load/store/index_expr which hurts the + # vectorization performance. Disable vectorization when exceeding + # the thresholds. + return [], [] + + if ( + not reduction_group + and group + and len(tiling_indices) == 1 + and not has_free_symbols( + [ + group[tiling_indices[0]], + ] + ) + and group[tiling_indices[0]] < tiling_factor / 4 + and op_num < 10 + ): + # We found that when the number of elements in the inner loop range is + # relatively small(< tiling_factor / 4) and the number of operations is + # not large(< 10), vectorization is not efficient. + # And found that `#pragma GCC ivdep` has better performance than + # `#pragma omp simd simdlen(8)` for these cases. + return [], [] + + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.check_lt(call_range, factor_lowp) # type: ignore[arg-type] + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + def _select_tiling_indices( + self, + fn_list, + var_sizes_list, + tiling_factor, + ): + all_index = [] + for fn, var_sizes in zip(fn_list, var_sizes_list): + rw = dependencies.extract_read_writes(fn, *var_sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = OrderedSet[int]() + contig_vars_list = [] + non_contig_stride_const = OrderedSet[int]() + non_contig_stride_other = OrderedSet[int]() + for index in all_index: + for var in index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + num_itervars = len(group) + len(reduction_group) + if len(contig_vars) == 0: + # no contiguous vars + return [num_itervars - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == num_itervars - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + +class CppKernelProxy(CppKernel): + # Subclass CppKernel, CppVecKernel, etc., to customize code generation. + # Override CppOverrides or CppVecOverrides to emit custom ops. + # Earlier, this meant copying codegen_functions() to use your subclasses. + # Now, use kernel_cls and vec_kernel_cls class attributes instead. + # This lets CppKernelProxy subclasses inject custom behavior cleanly. + # No need to duplicate codegen_functions() just to swap kernel classes. + kernel_cls: type[CppKernel] = CppKernel + vec_kernel_cls: type[CppVecKernel] = CppVecKernel + tile2d_kernel_cls: type[CppTile2DKernel] = CppTile2DKernel + + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.kernels: list[CppKernel] = [] + + def data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, LoopBody): + return True + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) + + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): + def add_to_dtype(sub_graph: torch.fx.Graph): + def get_input_dtype(node: torch.fx.Node) -> Optional[torch.dtype]: + """Get input dtype for nodes that may consumes lowp fp dt""" + if node.target == "store": + return V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + elif node.target == "to_dtype_bitcast": + return node.args[-1] # type: ignore[return-value] + elif node.target == "to_dtype": + if len(node.args) > 3: + return node.args[3] # type: ignore[return-value] + else: + return node.kwargs.get("src_dtype", None) # type: ignore[return-value] + else: + return None + + def get_output_dtype(node: torch.fx.Node) -> Optional[torch.dtype]: + """Get output dtype for nodes that may produce lowp fp dt""" + if node.target == "load": + assert len(node.args) == 3 + return V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + elif node.target in ["to_dtype", "constant", "index_expr"]: + return node.args[-1] # type: ignore[return-value] + elif node.target == "to_dtype_bitcast": + return node.args[2] # type: ignore[return-value] + else: + return None + + def is_lowp_fp_source(node: torch.fx.Node, dt: torch.dtype): + """Check if the given node produces output with expected low precision floating point data type.""" + assert dt in DTYPE_LOWP_FP + return get_output_dtype(node) == dt + + def is_lowp_fp_sink(node: torch.fx.Node, dt: torch.dtype): + """Check if the given node accept input with expected low precision floating point data type.""" + assert dt in DTYPE_LOWP_FP + if input_dtype := get_input_dtype(node): + return input_dtype == dt + elif node.target == "to_dtype": + # The `src_dtype` of a `to_dtype` node might miss, in which case the node accept any input dtype. + return True + else: + return False + + def is_lowp_fp_source_no_promote(node: torch.fx.Node, dt: torch.dtype): + """Check if the node is a lowp fp sources which are all directly fed to ops that accepts lowp fp input + thus no need to promote to float + """ + return is_lowp_fp_source(node, dt) and all( + is_lowp_fp_sink(user, dt) for user in node.users + ) + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if ( + _node.target in ["load", "index_expr"] + and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP + ): + # No need to promote to float if all users are ops that accepts lowp fp input + # pyrefly: ignore [bad-argument-type] + if all(is_lowp_fp_sink(user, dt) for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + _node.replace_all_uses_with( + to_type_node, lambda n: n is not to_type_node + ) + # pyrefly: ignore [bad-assignment] + metrics.cpp_to_dtype_count += 1 + elif ( + _node.target == "store" + and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP + ): + ops, name, _, value_var, _ = _node.args + # pyrefly: ignore [bad-argument-type] + if is_lowp_fp_source_no_promote(value_var, dt): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + # pyrefly: ignore [bad-assignment] + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "constant" and _node.args[-1] in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + (ops, value, dt) = _node.args + if all(is_lowp_fp_sink(user, dt) for user in _node.users): # type: ignore[arg-type] + continue + _node.args = (ops, value, torch.float) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + (ops, x, dt) = _node.args + if all(is_lowp_fp_sink(user, dt) for user in _node.users): # type: ignore[arg-type] + continue + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. + # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step): + # 1) Eliminate the redundant to_dtype node if we have a pattern as follows: + # graph(): + # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float)) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16)) + # Regarding the first to_dtype, it is redundant because + # the second to_type also converts to the torch.bfloat16/torch.float16. + # Hence, we remove the first to_type. + to_lowp_fp_legalized_nodes.append(_node) + _node.args = (ops, x, torch.float) + elif _node.target == "to_dtype_bitcast": + (ops, value_var, dtype, src_dtype) = _node.args + + # to_dtype_bitcast act as a lowp fp sink: + # c10::bit_cast requires the source and target have the same bitwidth. Because the input tensor's + # dtype could be promoted, e.g. from float16 to float, we have to cast the tensor to its original + # source dtype before invoking bit_cast. + if src_dtype in DTYPE_LOWP_FP: + # No need to promote to float if it is a user of a lowp fp sources + # which are all directly fed to ops that accepts lowp fp input + if not is_lowp_fp_source_no_promote(value_var, src_dtype): + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, src_dtype) + ) + _node.replace_input_with(value_var, to_type_node) + # pyrefly: ignore [bad-assignment] + metrics.cpp_to_dtype_count += 1 + + # to_dtype_bitcast act as a lowp fp source: + # We also need to convert the bit-casted tensor back to float to make sure we keep using higher + # precision values for the rest of the computation. + if dtype in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + if not ( + all(is_lowp_fp_sink(user, dtype) for user in _node.users) + ): + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + _node.replace_all_uses_with( + to_type_node, lambda n: n is not to_type_node + ) + # pyrefly: ignore [bad-assignment] + metrics.cpp_to_dtype_count += 1 + + def eliminate_to_dtype(sub_graph: torch.fx.Graph): + def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph): + # Eliminate the redundant to_dtype node. Let's consider a pattern as follows: + # graph(): + # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {}) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {}) + # Regarding the first to_dtype, it is redundant because the second to_type also converts to the + # torch.float. Hence, we remove the first to_type + def _used_by_to(to_node: torch.fx.Node): + return all(usr.target == "to_dtype" for usr in to_node.users) + + all_to_nodes = [ + node for node in sub_graph.nodes if node.target == "to_dtype" + ] + all_to_nodes_and_users = [ + {node: node.users} for node in all_to_nodes if _used_by_to(node) + ] + for node_users in all_to_nodes_and_users: + for node, users in node_users.items(): + if node in sub_graph.nodes and ( + all(usr.args[-1] == node.args[-1] for usr in users) + or ( + node in to_lowp_fp_legalized_nodes + and all( + usr.args[-1] in DTYPE_LOWP_FP for usr in users + ) + ) + ): + val_node = node.all_input_nodes[-1] + node.replace_all_uses_with(val_node) + sub_graph.erase_node(node) + + # For debug mode, the graph of LoopBody will attach a new GraphModule as + # owning_module for debugging while the release mode will not. The lint will + # check whether the graph has owning_module to decide if it needs to check + # call_module. LoopBody might contain get_index as a module call. But it + # is just a function. Hence, it cannot pass the lint check for debug mode. + # We bypass the check if the owning_module is None. Eventually, we should call + # get_index via call_function but not call_module. + if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + def legalize_lowp_fp_dtype(self, nodes): + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, LoopBody) + body: LoopBody = _node._body + if not body.is_memory_copy(): + self.legalize_lowp_fp_dtype_loopbody(body) + + def codegen_functions(self, fn_list, var_sizes_list): + assert len(fn_list) == len(var_sizes_list) + kernel_group = self.kernel_group + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + # pyrefly: ignore [bad-assignment] + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for fn, var_sizes in zip(fn_list, var_sizes_list): + if var_sizes in [ + (group, reduction_group), + (tuple(itertools.chain(group, reduction_group)), ()), + ]: + assert not in_suffix + fn(vars, reduction_vars) + else: + in_suffix = True + assert var_sizes == ( + group, + (), + ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + fn(vars, ()) + + scalar_kernel = codegen_kernel(self.kernel_cls) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNest.build(scalar_kernel) + + if not self.picked_vec_isa or not self.itervars: + self.kernels = [scalar_kernel] + self.aggregate_reduction_buffers(False, None) + self.loop_nest.set_kernel(self) + return + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_select = TilingSelect() + tiling_factors, tiling_indices = tiling_select.select_tiling( + fn_list, var_sizes_list + ) + assert len(tiling_factors) == len(tiling_indices) + _inner_loop_reduction_outer_not = False + _outer_loop = None + if tiling_indices: + inner_loop_reduction = False + outer_loop_level = tiling_indices[0] + inner_loop_level = outer_loop_level + 1 + if len(self.loop_nest.loops) > inner_loop_level: + inner_loop_reduction = self.loop_nest.loops[ + inner_loop_level + ].is_reduction + outer_loop_reduction = self.loop_nest.loops[ + outer_loop_level + ].is_reduction + _inner_loop_reduction_outer_not = ( + inner_loop_reduction and not outer_loop_reduction + ) + + if len(tiling_indices) == 1: + # pyrefly: ignore [bad-assignment] + metrics.generated_cpp_vec_kernel_count += 1 + loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0]) + vec_kernel = codegen_kernel( + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] + ) + tail_size = loop.size - loop.tiled_size + vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)} + if config.cpp.enable_loop_tail_vec: + tail_kernel = codegen_kernel( + self.vec_kernel_cls, + tiling_factors[0], + tiling_indices[0], + tail_size, + ) + else: + tail_kernel = scalar_kernel + scalar_kernel.inner_itervars = [loop.var] + tail_kernel.active_ranges = {loop.var: (loop.tiled_size, loop.size)} + self.kernels = [vec_kernel, tail_kernel] + _outer_loop = loop + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + + # pyrefly: ignore [bad-assignment] + metrics.generated_cpp_vec_kernel_count += 2 + outer_loop = self.loop_nest.tile( + tiling_indices[0], factor=tiling_factors[0] + ) + outer_ranges = { + "main": (0, outer_loop.tiled_size), + "tail": (outer_loop.tiled_size, outer_loop.size), + } + outer_tail_size = outer_loop.size - outer_loop.tiled_size + inner_loop = self.loop_nest.tile( + tiling_indices[1], factor=tiling_factors[0] + ) + inner_ranges = { + "main": (0, inner_loop.tiled_size), + "tail": (inner_loop.tiled_size, inner_loop.size), + } + inner_tail_size = inner_loop.size - inner_loop.tiled_size + tile2d_kernel = codegen_kernel( + self.tile2d_kernel_cls, + tiling_factors[0], + tiling_indices, + ) + tile2d_kernel.active_ranges = { + outer_loop.var: outer_ranges["main"], + inner_loop.var: inner_ranges["main"], + } + tail_kernel = [] + if config.cpp.enable_loop_tail_vec: + for outer_r, inner_r in ( + ("main", "tail"), + ("tail", "main"), + ("tail", "tail"), + ): + _inner_tail_size = ( + inner_tail_size if inner_r == "tail" else None + ) + _outer_tail_size = ( + outer_tail_size if outer_r == "tail" else None + ) + kernel = codegen_kernel( + self.tile2d_kernel_cls, + tiling_factors[0], + tiling_indices, + _inner_tail_size, + _outer_tail_size, + ) + kernel.active_ranges = { + outer_loop.var: outer_ranges[outer_r], + inner_loop.var: inner_ranges[inner_r], + } + tail_kernel.append(kernel) + else: + vec_kernel = codegen_kernel( + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] + ) + vec_kernel.active_ranges = { + outer_loop.var: outer_ranges["main"], + inner_loop.var: inner_ranges["tail"], + } + vec_kernel.inner_itervars = [inner_loop.var] + tail_kernel.append(vec_kernel) + scalar_kernel.active_ranges = { + outer_loop.var: outer_ranges["tail"], + inner_loop.var: (0, inner_loop.size), + } + scalar_kernel.inner_itervars = [inner_loop.var, outer_loop.var] + tail_kernel.append(scalar_kernel) + self.kernels = [tile2d_kernel] + tail_kernel + _outer_loop = outer_loop + else: + self.kernels = [scalar_kernel] + self.aggregate_reduction_buffers( + _inner_loop_reduction_outer_not, _outer_loop + ) + self.loop_nest.set_kernel(self) + + def codegen_loop_bodies(self, loop_bodies, var_sizes_list): + for body in loop_bodies: + self.legalize_lowp_fp_dtype_loopbody(body) + DataTypePropagation.propagate_loopbody(body) + self.codegen_functions(loop_bodies, var_sizes_list) + + def codegen_nodes(self, nodes: list[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + assert len(nodes) >= 1 + + def fn(node, *index_vars): + node.decide_inplace_update() + node.mark_run() + if isinstance(V.kernel, NullKernelHandler): + return node._body(*index_vars) + else: + return node.codegen(index_vars) + + fn_list = [functools.partial(fn, node) for node in nodes] + + if ( + isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + + def wrap_fn(fn): + wrapped_fn = V.local_buffer_context.localize_function( + fn, + ) + wrapped_fn.original_fn = fn + return wrapped_fn + + fn_list = [wrap_fn(fn) for fn in fn_list] + + var_sizes_list = [node.group[1] for node in nodes] + self.codegen_functions(fn_list, var_sizes_list) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + def update_stores_with_parallel_reduction(self): + for kernel in self.kernels: + kernel.update_stores_with_parallel_reduction() + + def gen_body(self, code: Optional[BracesBuffer] = None): + assert code is not None + if_prefix = "C10_LIKELY" + for kernel in self.kernels: + with contextlib.ExitStack() as stack: + if kernel.codegen_conditions(code, if_prefix): + if_prefix = "C10_UNLIKELY" + stack.enter_context(code.indent()) + code.splice(kernel.gen_body()) + + def aggregate_reduction_buffers( + self, inner_loop_reduction_outer_not: bool, outer_loop: Optional["LoopLevel"] + ): + """ + CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. + Here, we decide how to aggregate them together and place new reduction buffers + under CppKernelProxy. + """ + + def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): + assert len(self.kernels) >= 2 + main_loop_kernel = self.kernels[0] + tail_loop_kernel = self.kernels[-1] + assert isinstance(main_loop_kernel, self.vec_kernel_cls) + + # Prefix + if type(tail_loop_kernel) is self.kernel_cls: + # if tail loop kernel is a scalar kernel, we need to extend tmp_acc -> tmp_acc_arr[] to + # hold the temporary inner loop acc result for outer tail loop + tail_loop_kernel.finalize_reduction_prefix( + main_loop_kernel.tiling_factor + ) + main_loop_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice( + tail_loop_kernel.reduction_prefix + + main_loop_kernel.reduction_prefix + ) + else: + main_loop_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice(main_loop_kernel.reduction_prefix) + + # Suffix + suffix_buf = BracesBuffer() + with contextlib.ExitStack() as stack: + if main_loop_kernel.codegen_conditions( + suffix_buf, "C10_LIKELY", outer_loop.var + ): + stack.enter_context(suffix_buf.indent()) + suffix_buf.splice(main_loop_kernel.reduction_suffix) + with contextlib.ExitStack() as stack: + if tail_loop_kernel.codegen_conditions( + suffix_buf, "C10_UNLIKELY", outer_loop.var + ): + stack.enter_context(suffix_buf.indent()) + if type(tail_loop_kernel) is self.kernel_cls: + reduction_vars = tail_loop_kernel.reduction_var_names + for name in reduction_vars: + new_name = f"{name}_arr[{outer_loop.var}_tail - {cexpr_index(outer_loop.tiled_size)}]" + replace_acc_name(tail_loop_kernel.stores, name, new_name) + replace_acc_name( + tail_loop_kernel.reduction_suffix, name, new_name + ) + # If tail loop kernel is a scalar kernel, use direct sum instead of cascade_sum_combine + # as the reduction vars are extended: tmp_acc -> tmp_acc_arr[]. + replace_cascade_sum_with_add(tail_loop_kernel.stores) + suffix_buf.splice( + move_code_under_inner_loop( + tail_loop_kernel.reduction_suffix, + outer_loop.var, + f"{outer_loop.var}_tail", + outer_loop.tiled_size, + outer_loop.size, + ) + ) + else: + suffix_buf.splice(tail_loop_kernel.reduction_suffix) + self.reduction_suffix = suffix_buf + + main_kernel = self.kernels[0] + if inner_loop_reduction_outer_not: + assert outer_loop + aggregate_reduction_prefix_suffix(outer_loop) + else: + main_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice(main_kernel.reduction_prefix) + self.reduction_suffix.splice(main_kernel.reduction_suffix) + self.parallel_reduction_prefix.splice(main_kernel.parallel_reduction_prefix) + self.parallel_reduction_suffix.splice(main_kernel.parallel_reduction_suffix) + self.local_reduction_init.splice(main_kernel.local_reduction_init) + self.local_reduction_stores.splice(main_kernel.local_reduction_stores) + self.non_parallel_reduction_prefix.splice( + main_kernel.non_parallel_reduction_prefix + ) + self.non_parallel_reduction_suffix.splice( + main_kernel.non_parallel_reduction_suffix + ) + + +class OuterLoopFusedKernel(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.inner: list[LoopNest] = [] + + def decide_parallel_depth(self, max_parallel_depth, threads): + kernels_parallel_depth = [] + nested_kernels: list[CppKernel] = [ + loop_nest.get_kernel() for loop_nest in self.inner + ] + # TODO(leslie-fang-intel): only enable parallel within all outer loop levels. + for kernel in nested_kernels: + # For any ScalarKernel, VecKernel, or Tile2DKernel, + # they should all have the same call_ranges + call_ranges = kernel.call_ranges + assert call_ranges is not None + kernels_parallel_depth.append( + kernel.decide_parallel_depth( + ParallelDepth( + parallel_depth=( + len(call_ranges) - max_parallel_depth.start_depth + ), + start_depth=max_parallel_depth.start_depth, + ), + threads, + ).parallel_depth + ) + return ParallelDepth( + parallel_depth=min( + max_parallel_depth.parallel_depth, max(kernels_parallel_depth) + ), + start_depth=max_parallel_depth.start_depth, + ) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # Subclass CppKernelProxy to customize codegen without copying codegen_node(). + # Use kernel_proxy_cls to inject custom proxies in CppScheduling subclasses. + # Avoid duplicating codegen_node() just to swap in a custom kernel proxy class. + kernel_proxy_cls: type[CppKernelProxy] = CppKernelProxy + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. + MAX_FUSED_KERNEL_ARGS_NUM = 500 + backend_features = OrderedSet( + [ + BackendFeature.INPLACE_BUFFERS, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + return cls.backend_features + + def __init__(self, scheduler): + super().__init__(scheduler) + if scheduler: + self.reset_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def reset_kernel_group(self): + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + elif node1.is_template(): + assert not node2.is_template() + return FusedSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0, node.snodes + var_ranges = None + indexing_exprs = OrderedSet[Any]() + for snode in node.snodes: + v, exprs = get_indexing_ranges_exprs(snode) + if var_ranges is None: + var_ranges = v + assert var_ranges == v, (var_ranges, v, node.snodes) + indexing_exprs.update(exprs) + return var_ranges, list(indexing_exprs) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + ref_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=ref_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + + if vars1 == vars2: + return FusedSchedulerNode.fuse(node1, node2) + + # recompute ref_node if its ranges are also changed + node_to_recomp_indexing_constraints = get_indexing_ranges_exprs( + node_to_recomp + ) + if isinstance(ref_node, SchedulerNode): + ref_node.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + else: + assert isinstance(ref_node, FusedSchedulerNode) + for snode in ref_node.snodes: + assert isinstance(snode, SchedulerNode) + snode.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + ref_node = FusedSchedulerNode(ref_node.scheduler, ref_node.snodes) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + return FusedSchedulerNode.fuse(node1, node2) + elif self.can_fuse_vertical_outer_loop(node1, node2): + return OuterLoopFusedSchedulerNode.fuse( + node1, node2, self._get_outer_loop_fusion_depth(node1, node2) + ) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? + return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. (s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + + assert isinstance(node_to_recomp, SchedulerNode) + if isinstance(node_to_recomp.node, ir.TemplateBuffer): + return False + assert isinstance(node_to_recomp.node, ir.ComputedBuffer) + # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges + # but without variable name + ranges2 = node_to_recomp.node.data.get_size() + ranges1 = None + if isinstance(ref_node, FusedSchedulerNode): + ranges_set = OrderedSet[tuple[Any, ...]]() + for snode in ref_node.snodes: + if isinstance(snode.node, ir.TemplateBuffer): + break + assert isinstance(snode.node, ir.ComputedBuffer) + ranges_set.add(tuple(snode.node.data.get_size())) + + if len(ranges_set) != 1: + return False + + ranges1 = list(next(iter(ranges_set))) + else: + assert isinstance(ref_node, SchedulerNode) + assert isinstance(ref_node.node, ir.ComputedBuffer) + ranges1 = ref_node.node.data.get_size() # type: ignore[assignment] + + if ranges1 != ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + if any( + isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2) + ): + return False + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def can_fuse_multi_outputs_template( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if template_buf := node1.get_template_node(): + return ( + isinstance(template_buf.layout, ir.MultiOutputLayout) + and isinstance(node2.node, ir.MultiOutput) + and len(node2.node.inputs) == 1 + and node2.node.inputs[0].get_name() == template_buf.name # type: ignore[union-attr] + ) + return False + + def _get_outer_loop_fusion_depth(self, node1, node2): + DISABLE_OUTER_LOOP_FUSION = 0 + if not all( + type(node) + in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode) + for node in (node1, node2) + ): + return DISABLE_OUTER_LOOP_FUSION + + _node1 = ( + node1.get_outer_nodes()[-1] + if isinstance(node1, OuterLoopFusedSchedulerNode) + else node1 + ) + assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode)) + _node2 = ( + node2.get_outer_nodes()[0] + if isinstance(node2, OuterLoopFusedSchedulerNode) + else node2 + ) + assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode)) + + _, (vars1, reduce1) = _node1.group + _, (vars2, reduce2) = _node2.group + if vars1 == () and vars2 == () and reduce1 != () and reduce2 != (): + # Reduction only + return DISABLE_OUTER_LOOP_FUSION + if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return ( + node1.outer_loop_fusion_depth + if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth + else DISABLE_OUTER_LOOP_FUSION + ) + outer_loop_fusion_depth = min(len(vars1), len(vars2)) + if ( + outer_loop_fusion_depth >= 1 + and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth] + ): + if any( + type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2) + ): + _compare_node = ( + node1 if type(node1) is OuterLoopFusedSchedulerNode else node2 + ) + if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth: + # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + else: + return DISABLE_OUTER_LOOP_FUSION + else: + # First 2 nodes to generate OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + return DISABLE_OUTER_LOOP_FUSION + + def can_fuse_vertical_outer_loop(self, node1, node2): + return ( + not node1.is_template() + and not node2.is_template() + and node1.get_operation_names() & node2.ancestors + and not ( + self._can_fuse_horizontal_impl(node1, node2) + and not node1.is_reduction() + ) + and self._get_outer_loop_fusion_depth(node1, node2) >= 1 + ) + + def get_fusion_pair_priority(self, node1, node2): + if self.can_fuse_vertical_outer_loop(node1, node2): + # Outer loop fusion with lower priority + return 1 + else: + return 0 + + def can_fuse_vertical(self, node1, node2): + if node2.is_template(): + # TODO(jgong5): support pre-op fusion with template + return False + if node1.is_template(): + template_fusion_supported, _ = template_fusion_with_epilogues_supported( + node1, [node2] + ) + return not node2.is_reduction() and template_fusion_supported + return ( + self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() + ) or self.can_fuse_vertical_outer_loop(node1, node2) + + def try_loop_split(self, nodes: list[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. The indexing_exprs of all nodes contain only one (or more, but all the same) division, + where the divisor is an integer and not too small (the divisor > 8), the dividend is + one of the iter_vars, and this var, i.e. the dimension that needs to be split, is + contiguous in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. + """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + num_div = 0 + div_expr_ = None + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + if not isinstance(expr, sympy.Expr): + continue + for div_expr in expr.find(FloorDiv): + if ( + any(div_expr.has(var) for var in original_body.iter_vars) + and div_expr != div_expr_ + ): + div_expr_ = div_expr + num_div += 1 + if num_div > 1: + return nodes + if ( + isinstance(div_expr.args[1], sympy.core.numbers.Integer) + and div_expr.args[0] in original_body.iter_vars + and name is not None + and all( + stride_at_vec_range(expr_, div_expr.args[0]) in (0, 1) + for name_, expr_ in original_body.indexing_exprs.items() + if name_ != name + ) + and div_expr.args[1] > 8 + ): + split_var = div_expr.args[0] + split_number = div_expr.args[1] + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + + def codegen_outer_loop_node( + self, + node: OuterLoopFusedSchedulerNode, + ): + """ + Generate the code for the outer loop fused scheduler node. + 1. Codegen with fused outer loop: depends on the analysis of + the outer loop fused scheduler node, with or without the local buffer. + 2. If failed, fallback to standard codegen. + """ + kernel_group = self.kernel_group + generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count + cpp_kernel_proxy_list: list[self.kernel_proxy_cls] = [] # type: ignore[name-defined] + nodes_list: list[list[SchedulerNode]] = [] + assert isinstance(node, OuterLoopFusedSchedulerNode) + + def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): + """ + Codegen code with fused outer loop and local Buffer. + """ + assert isinstance(node, OuterLoopFusedSchedulerNode) + cpp_kernel_proxy_list.clear() + nodes_list.clear() + + def get_call_ranges(node: BaseSchedulerNode): + assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) + nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + return call_ranges + + local_buffers: list[ir.Buffer] = [] + # Map local buffer name to a list of global buffers + local_to_global_buffers: dict[str, list[ir.Buffer]] = {} + if all( + len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 + for _node in node.get_outer_nodes() + ): + # Ref to the typical case of local buffer in + # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950 + # where the buffer is with size of last dim and contiguous. + # Only support this typical case at first. + visited_scheduler_nodes: OrderedSet[str] = OrderedSet() + for scheduler_node in node.get_nodes(): + # all users inside same OuterLoopFusedSchedulerNode + assert isinstance(scheduler_node, SchedulerNode) + visited_scheduler_nodes.add(scheduler_node.get_name()) + if ( + scheduler_node.is_reduction() + or len(scheduler_node.get_outputs()) != 1 + ): + continue + + scheduler_buffer = scheduler_node.get_outputs()[0] + if all( + user.node in node.get_nodes() for user in scheduler_buffer.users + ): + global_buffer = scheduler_buffer.node + assert isinstance(global_buffer, ir.ComputedBuffer) + global_buffer_layout = global_buffer.get_layout() + size_offset = node.outer_loop_fusion_depth - len( + get_call_ranges(scheduler_node) + ) + + def is_all_write_read_contiguous(): + contiguous_index_expr = 0 + stride = 1 + for var, range in reversed( + # pyrefly: ignore [missing-attribute] + scheduler_node._body.var_ranges.items() + ): + contiguous_index_expr += stride * var + stride *= range + # pyrefly: ignore [missing-attribute] + write_index_expr = scheduler_node._body.get_write_expr( + scheduler_buffer.get_name() + ) + + def is_contiguous_index(x): + return x == contiguous_index_expr + + return is_contiguous_index(write_index_expr) and all( + isinstance(user.node, SchedulerNode) + and is_contiguous_index( + user.node._body.get_read_expr( + scheduler_buffer.get_name() + ), + ) + for user in scheduler_buffer.users + ) + + if not ( + global_buffer_layout.is_contiguous() + and is_all_write_read_contiguous() + ): + continue + # Local Buffer is a view of global buffer + local_buffer_stride: list[int] = [] + stride = global_buffer_layout.stride[-1] + local_buffer_size = get_call_ranges(scheduler_node)[ + size_offset: + ] + for sz in reversed(local_buffer_size): + local_buffer_stride.insert(0, stride) + stride *= sz + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + local_buffer_size, + local_buffer_stride, + ) + + def try_share_local_buffer(local_buffer_layout, local_buffers): + for local_buf in local_buffers: + if local_buffer_layout == local_buf.layout and all( + all( + user.node.get_name() in visited_scheduler_nodes + for user in V.graph.scheduler.name_to_buf[ + global_buffer.name + ].users + ) + for global_buffer in local_to_global_buffers[ + local_buf.name + ] + if global_buffer.name is not None + ): + return local_buf + return None + + local_buf_prefix = "local_buffer_data" + # Share existing local buffer + local_buffer_used = try_share_local_buffer( + local_buffer_layout, local_buffers + ) + if not local_buffer_used: + # Create new local buffer + local_buffer_used = ir.Buffer( + name=f"{local_buf_prefix}_{len(local_buffers)}", + layout=local_buffer_layout, + ) + local_buffers.append(local_buffer_used) + local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index] + # pyrefly: ignore [index-error] + local_to_global_buffers[local_buffer_used.name].append( + global_buffer, + ) + + with LocalBufferContext(kernel_group.args) as scope: + if len(local_buffers) > 0: + for local_buffer in local_buffers: + assert local_buffer.name is not None + scope.add_local_buffer( + local_buffer, local_to_global_buffers[local_buffer.name] + ) + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] + cpp_kernel_proxy_list.append(cpp_kernel_proxy) + nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] + + if not node.check_outer_fusion_loop_level_attr( + cpp_kernel_proxy_list, node.outer_loop_fusion_depth + ): + for removed_buffer in scope.removed_buffers: + # Restore the removed buffers by this context before + # fallback to codegen without using Local Buffer + V.graph.removed_buffers.remove(removed_buffer) + return False + metrics.cpp_outer_loop_fused_inner_counts.append( + metrics.CppOuterLoopFusedCount( + len(cpp_kernel_proxy_list), + local_buffer_number=len(scope.local_buffers), + ) + ) + outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( + cpp_kernel_proxy_list, + ) + kernel_group.finalize_kernel( + outer_fusion_cpp_kernel_proxy, + [*itertools.chain.from_iterable(nodes_list)], + ) + + return True + + if not try_outer_loop_fusion_with_local_buf(node): + # Reset generated_cpp_vec_kernel_count to codegen again + metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count + cpp_kernel_proxy_list.clear() + nodes_list.clear() + # Similar as comment in + # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + with torch._inductor.config.patch(inplace_buffers=False): + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + _nodes: list[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(_nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) + + def codegen_node( + self, + node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], + ): + """ + Turn an set of pre-fused nodes into a C++ kernel. + """ + kernel_group = self.kernel_group + + if isinstance(node, OuterLoopFusedSchedulerNode): + self.codegen_outer_loop_node(node) + else: + nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + assert not prologue_nodes + + # remove MultiOutput from epilogue_nodes + epilogue_nodes = [ + epilogue_node + for epilogue_node in epilogue_nodes + if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode)) + ] + # The counter cpp_templated_kernel_counter is used for verifying if a + # a templated kernel was successfully compiled in a UT + counters["inductor"]["cpp_templated_kernel_counter"] += 1 + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template(template_node), ( + "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Optional[ir.Operation]] = [ + n.node for n in epilogue_nodes + ] + assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) + + def template_buffer_has_other_users( + template_buffer, outputs_by_name, epilogue_nodes + ): + if not epilogue_nodes: + return False + + assert template_buffer.get_name() in outputs_by_name + users = outputs_by_name[template_buffer.get_name()].users + return not all( + isinstance(user.node, BaseSchedulerNode) + and user.node.node in epilogue_nodes + for user in users + ) + + flag_template_buffer_has_other_users = template_buffer_has_other_users( + ctb, template_node.outputs_by_name, epilogue_ir_nodes + ) + kernel, render = ctb.make_kernel_render( # type: ignore[misc] + ctb, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_ir_nodes, + ) + with kernel: + if not is_multi_outputs_template(template_node.node): + template_node.mark_run() # type: ignore[attr-defined] + for node in epilogue_nodes: + node.mark_run() # type: ignore[attr-defined] + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + + if is_multi_outputs_template(template_node.node): + # For multi outputs template, allocate buffers for each output after the epilogue + # codegen to which determines if the buffer has been removed. + assert len(template_node.outputs) == 1, ( + "Multi outputs template should be with 1 output template buffer of MultiOutputLayout" + ) + for user in template_node.outputs[0].users: + assert isinstance(user.node, ExternKernelSchedulerNode), ( + "Multi outputs template should be with ExternKernelSchedulerNode" + ) + assert isinstance(user.node.node, ir.MultiOutput), ( + "Multi outputs template has multi users with MultiOutput" + ) + user.node.mark_run() + + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() + + def _get_scheduled_num_args(self): + return self.kernel_group.get_num_args() + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def define_kernel(self, src_code, nodes, kernel_args=None): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(nodes, config.cpp.descriptive_names) + if config.cpp.descriptive_names + else "" + ) + kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) + wrapper.src_to_kernel[src_code] = kernel_name + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "//") + + # Get the lines in the source code representing the function definition, + # excluding the first line including cpp_prefix.h. + first_char = src_code.rfind('extern "C"') + last_char = src_code.find(")", first_char) + if _IS_WINDOWS: + # get_export_declaration introduced one more ')' in Windows + last_char = src_code.find(")", last_char + 1) + kernel_definition = f"{src_code[first_char : last_char + 1]};\n" + + compile_wrapper = IndentedBuffer() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() + if not V.graph.cpp_wrapper: + compile_wrapper.writeline( + f"async_compile.cpp_pybinding({arg_types!r}, r'''" + ) + compile_wrapper.splice(src_code, strip=True) + if not V.graph.cpp_wrapper: + compile_wrapper.writeline("''')") + wrapper.define_kernel( + kernel_name, + compile_wrapper.getvalue(), + gpu=False, + cpp_definition=kernel_definition, + ) + return kernel_name + + def flush(self): + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_begin() + V.graph.wrapper_code.write_kernel_context_guard( + kernel_name, + self.kernel_group.scheduled_nodes, # type: ignore[arg-type] + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_end() + + self.reset_kernel_group() + self._set_flush_status(False) + + def codegen_comment(self, node_schedule, kernel_name=None): + # below add provenance tracing info for cpu CppKernel types + wrapper = V.graph.wrapper_code + debug_handle = set_kernel_post_grad_provenance_tracing( + node_schedule, # type: ignore[arg-type] + # pyrefly: ignore [bad-argument-type] + kernel_name, + ) + wrapper.write_provenance_debug_handle(kernel_name, debug_handle) + + +class KernelGroup: + def __init__(self): + super().__init__() + self.args = KernelArgs() + self.loops_code = BracesBuffer() + self.ws = WorkSharing(self.loops_code) + self.stack = contextlib.ExitStack() + self.stack.enter_context(self.ws) + self.scheduled_nodes = [] + + def new_kernel(self, cls, *args): + return cls(self.args, parallel_num_threads(), *args) + + def finalize_kernel(self, new_kernel, nodes): + self.scheduled_nodes += nodes + code = self.loops_code + ws = self.ws + new_kernel.codegen_loops(code, ws) + + def get_num_args(self): + arg_defs, _call_args, _arg_types = self.args.cpp_argdefs() + args_num = len(arg_defs) + return args_num + + def codegen_group(self, name=None) -> str: + self.stack.close() + if not self.scheduled_nodes: + return "" + code = BracesBuffer() + # 1. Include header files + # TODO: support kernel profile on other platforms + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + code.writelines(["#include "]) + code.writeline("#include ") + + # 2. Function definition + kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name + kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name + arg_defs, _, _ = self.args.cpp_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + func_export_decl = get_export_declaration() + inline_attr = ( + "C10_ALWAYS_INLINE_ATTRIBUTE" if config.cpp.force_inline_kernel else "" + ) + code.writeline( + f'extern "C" {func_export_decl} void {inline_attr} {kernel_decl_name}({arg_defs})' + ) + + # 3. Function body + with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + ( + "torch::aot_inductor::RAIIAtenRecordFunctionHandle " + f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);' + ) + ] + ) + for old, new in self.args.aliases(): + code.writeline(f"auto {old} = {new};") + code.splice(self.loops_code) + return code.getvalue() + + def call_kernel(self, wrapper, kernel_name): + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call( + kernel_name, + call_args, + triton=False, + arg_types=arg_types, + ) + + +class WorkSharing: + def __init__(self, code): + self.code = code + self.in_parallel = False + self.num_threads = None + self.stack = contextlib.ExitStack() + + def parallel(self, threads): + if self.in_parallel and threads != self.num_threads: + # wrong number of threads + self.close() + if not self.in_parallel: + self.num_threads = threads + self.in_parallel = True + if config.cpp.dynamic_threads: + self.code.writeline("#pragma omp parallel") + else: + self.code.writeline(f"#pragma omp parallel num_threads({threads})") + self.stack.enter_context(self.code.indent()) + self.code.writeline( + "int tid = omp_get_thread_num();", + ) + + def single(self): + if self.in_parallel: + self.code.writeline("#pragma omp single") + return self.in_parallel + + def close(self): + self.stack.close() + self.in_parallel = False + + def __enter__(self): + self.stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stack.__exit__(exc_type, exc_val, exc_tb) + + +@dataclasses.dataclass +class LoopLevel: + var: Optional[sympy.Expr] = None + size: Optional[sympy.Expr] = None + offset: sympy.Expr = sympy.S.Zero + # Note [tiled_size] + # We may do loop-tiling at this loop level. + # When var is in [offset, tiled_size), we will perform the vectorization kernel. + # When var is in [tiled_size, size), we will perform the scalar or masked vectorization kernel. + # for (var = offset; var < size; var += steps) { + # if (var >= offset && var < tiled_size) vec_loop_body(); + # if (var >= tiled_size && var < size) scalar_or_maskvec_loop_body(); + # } + tiled_size: sympy.Expr = sympy.S.Zero + steps: sympy.Expr = sympy.S.One + parallel: int = 0 + simd_omp: bool = False + simd_vec: bool = False + collapsed: bool = False + is_reduction: bool = False + + def __post_init__(self): + # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check + # vectorization ISA is a time-consuming and one-shot operation. It leads + # to taking a longer time to import `codegen.cpp` package because the + # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while + # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the + # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation + # overhead to the Triton backend. Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def tile(self, factor): + sympy_factor = sympy.Integer(factor) + loop = LoopLevel(self.var, self.size) + loop.steps = sympy_factor + loop.simd_vec = True + loop.tiled_size = FloorDiv(loop.size, sympy_factor) * sympy_factor + loop.parallel = self.parallel + loop.collapsed = False + loop.is_reduction = self.is_reduction + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = "#pragma omp for" + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}" + elif not self.is_reduction and cpp_builder.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNest: + """ + A loop-nest-like structure. It is built with the `build` method + as a loop nest and then will perform loop-tiling at some depth. + + A typical case is for vectorization, where we typically do loop-tiling + at the innermost loop level. A more complicated case is when we do + 2D tiling at both the innermost and outer levels. + """ + + loops: Optional[list[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + loops: Optional[list[LoopLevel]] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size) + if not loops: + loops = [loop] + else: + loops.append(loop) + if loop_idx >= reduction_depth: + loop.is_reduction = kernel.is_reduction + + loop_nest = LoopNest(loops) + return loop_nest + + def __bool__(self): + return bool(self.loops) + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: All reduction or non-reduction levels. + When the range of the first inner loop beyond the maximum parallel depth is much + larger than the range of all outer loops within the maximum parallel depth, + change the starting depth of parallelism to the first inner loop and recalculate + the maximum parallel depth. + """ + if self.loops is None: + return ParallelDepth(parallel_depth=0, start_depth=0) + + start_depth = 0 + max_depth = 0 + is_reduction = self.loops[0].is_reduction + num_steps = sympy.Integer(1) + for loop in self.loops: + if loop.is_reduction != is_reduction: + break + num_steps = num_steps * FloorDiv(loop.size, loop.steps) + max_depth += 1 + + def get_simd_vec_depth(loops): + # Return the first loop level which is simd_vec + for i, loop in enumerate(loops): + if loop.simd_vec: + return i + return None + + simd_vec_depth = get_simd_vec_depth(self.loops) + + def has_scalar_kernel(loop_nest: LoopNest): + assert isinstance(loop_nest.kernel, CppKernelProxy) + return any( + not isinstance(kernel, CppVecKernel) + for kernel in loop_nest.kernel.kernels + ) + + # When the number of steps of the first inner loop is much larger than the number of steps of + # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. + if ( + max_depth < len(self.loops) + and isinstance(num_steps, sympy.Integer) + and isinstance(self.loops[max_depth].size, sympy.Integer) + and num_steps * 300 + < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps) + and not ( + # Disable parallel reduction under the vec loop + simd_vec_depth is not None + and max_depth > simd_vec_depth + and self.loops[max_depth].is_reduction + and has_scalar_kernel(self) + ) + ): + start_depth = max_depth + max_depth = 0 + is_reduction = self.loops[start_depth].is_reduction + for i in range(start_depth, len(self.loops)): + if self.loops[i].is_reduction != is_reduction: + break + max_depth += 1 + return ParallelDepth(parallel_depth=max_depth, start_depth=start_depth) + + def mark_parallel(self, par_depth): + assert par_depth.parallel_depth <= self.max_parallel_depth().parallel_depth, ( + "Parallel depth cannot exceed the maximal allowed parallel depth" + ) + assert self.loops is not None + assert len(self.loops) >= par_depth.parallel_depth + loop = self.loops[par_depth.start_depth] + loop.parallel = par_depth.parallel_depth + if loop.is_reduction: + # pyrefly: ignore [bad-assignment] + metrics.parallel_reduction_count += 1 + for i in range(par_depth.start_depth + 1, par_depth.parallel_depth): + self.loops[i].collapsed = True + + def tile(self, depth, factor): + """ + Do loop-tiling at the `depth` level with `factor`. + for (x0 = 0; x0 < x0_end; x0++) + -> + for (x0 = 0; x0 < x0_end; x0 += factor) + See details in Note [tiled_size]. + """ + assert self.loops + self.loops[depth] = self.loops[depth].tile(factor) + return self.loops[depth] + + def get_kernel(self) -> CppKernel: + assert self.kernel + return self.kernel + + def set_kernel(self, kernel): + self.kernel = kernel + + def from_loop_level(self, level: int): + assert self.loops + assert len(self.loops) >= level + loops = None if level == len(self.loops) else self.loops[level:] + return LoopNest(loops, self.kernel) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a7c2ef1640690bf751e08b1c1e4d33a3c147b4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py @@ -0,0 +1,263 @@ +# mypy: allow-untyped-defs +import contextlib +import itertools +from collections.abc import Callable +from typing import Any, Optional +from unittest.mock import patch + +import sympy + +from .. import ir +from ..select_algorithm import PartialRender +from ..virtualized import V +from .common import ArgName +from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE +from .cpp_micro_gemm import LayoutType +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking + + +# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition +GEMM_SINGLE_THREAD_MM_STUB = r""" +{{kernel.def_kernel( + inputs={"X": X, "W": W}, + outputs={"Y": Y_2d}, + aliases=aliases, + function_name=kernel_name+"_single_thread_mm", + extra_sizevars=BY_sizevars + [b_index], + placeholder="")}}""" + +GEMM_THREADED_MM_STUB = r""" +{{kernel.def_kernel( + inputs={"X": X, "W": W}, + outputs={"Y": Y_2d}, + aliases=aliases, + function_name=kernel_name+"_threaded_mm", + extra_sizevars=BY_sizevars + [b_index], + placeholder="")}}""" + +BMM_TEMPLATE = r""" +{{ template.codegen_microkernel_def() }} +{{ template.codegen_single_thread_gemm() }} +{{ template.codegen_multi_thread_gemm() }} + +extern "C" +{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}} +{ + const int64_t B = {{kernel.size(BY_2d, 0)}}; + {%- if num_threads > 1 %} + constexpr int64_t num_threads = {{num_threads}}; + int64_t B_single_thread_block = (B / num_threads) * num_threads; + + #pragma omp parallel for num_threads({{num_threads}}) + {%- else %} + int64_t B_single_thread_block = B; + {%- endif %} + for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { + {{template.get_gemm_function_call( + kernel, + kernel_name+"_single_thread_mm", + "", + b_index="b_start", + )}} + } + for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { + {{template.get_gemm_function_call( + kernel, + kernel_name+"_threaded_mm", + "", + b_index="b_start", + )}} + } +} +""" + + +class CppBmmTemplate(CppGemmTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + should_block_weights: bool = False, + name="bmm", + ): + """ + In order to simplify the implementation and increase code reuse, the BMM template implements + two versions of the GEMM kernel: a single-threaded version and a multi-threaded version. + GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls + for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel. + + We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM + template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template + without any changes to the GEMM template itself. + """ + super().__init__( + input_nodes, + layout, + num_threads, + register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + should_block_weights=should_block_weights, + name=name, + ) + self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True) + + @staticmethod + def get_padded_size(n, block_n, k, should_block_weight): + if should_block_weight: + # Tensor is constant or not contiguous, so we will pad and block + new_size, padded_n = CppGemmTemplate.get_padded_size( + n, block_n, k, should_block_weight + ) + # Add the new batch dimension + new_size.insert(0, -1) + return new_size, padded_n + else: + new_size = [-1, k, n] + return new_size, n + + @staticmethod + def check_if_block_weight(W, micro_gemm): + assert isinstance(W, ir.IRNode) + _, n = W.get_size()[-2:] + result = ( + not W.get_layout().is_contiguous() + or W.get_name() in V.graph.constants + or ( + n % micro_gemm.register_blocking.block_n != 0 + and micro_gemm.get_b_layout != LayoutType.NORMAL + ) + ) + return result + + def get_gemm_function_call( + self, + kernel: CppTemplateKernel, + function_name: str, + placeholder: str, + b_index: str, + ) -> str: + """ + Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition, + generate a function call for the GEMM kernel. + Args: + placeholder: The string to replace the function call with + b_index: The index for slicing the 3D batch tensors + """ + + def hook(): + arg_defs, call_args, _, _ = kernel.args.python_argdefs() + for i, buf in enumerate(call_args): + if buf == self.b_index: + arg_defs[i] = ArgName(b_index) + call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});" + return call + + assert placeholder not in kernel.render_hooks + kernel.render_hooks[placeholder] = hook + return placeholder + + def get_default_reindexers(self, epilogue_nodes): + def reindexer(args): + # if epilogue nodes exist, they have 3D ranges but args are 2D, so add 0 index + return [self.b_index] + args + + return [reindexer] * len(epilogue_nodes) + + def get_options( + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> dict[str, Any]: + options = super().get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + + BX, BW, BY = options["X"], options["W"], options["Y"] + options["BX"], options["BW"], options["BY"] = BX, BW, BY + options["BY_2d"] = options["Y_2d"] + for kword in ["X", "W", "GemmOut", "Y_2d"]: + options[kword] = kernel.select(options[kword], 0, self.b_index) + for kword in ["X", "W", "Y_2d"]: + options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype] + options["b_index"] = self.b_index + options["BY_sizevars"] = [ + s + for sym in itertools.chain(BY.get_size(), BY.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ] + options["kernel_name"] = kernel.kernel_name + + return options + + def render( # type: ignore[override, return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + options = self.get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + self.render_options = options + + with contextlib.ExitStack() as stack: + for buf in options["fake_buffers"]: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + result = self._template_from_string(BMM_TEMPLATE).render(**options) + + # Finalize the function definitions for the gemm routines + sub_mm_hooks = { + name: hook + for name, hook in kernel.render_hooks.items() + if "FOR_BMM" in name + } + result = PartialRender(result, sub_mm_hooks).finalize_all() + for name in sub_mm_hooks: + del kernel.render_hooks[name] + del kernel.args.sizevars[options["b_index"]] + return result + + def codegen_single_thread_gemm(self): + stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render( + self.render_options + ) + return stub + self._template_from_string(GEMM_TEMPLATE).render( + {**self.render_options, "num_threads": 1} + ) + + def codegen_multi_thread_gemm(self): + stub = self._template_from_string(GEMM_THREADED_MM_STUB).render( + self.render_options + ) + return stub + self._template_from_string(GEMM_TEMPLATE).render( + self.render_options + ) + + def codegen_gemm_stub_def(self): + return "" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ceecf7f7c9ea8081660c21a8ddf96254c98a68 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -0,0 +1,1090 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import re +from typing import Optional +from unittest.mock import patch + +import sympy + +import torch +import torch.utils + +from ...utils._ordered_set import OrderedSet +from .. import ir +from ..ir import TensorBox +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import parallel_num_threads +from ..virtualized import V +from .cpp_template import CppTemplate +from .cpp_utils import GemmBlocking + + +log = logging.getLogger(__name__) + +# TODO: reuse cpp codegen to generate below pointwise/reduction kernels +SOFTMAX_FUSIONS = r""" +// 1) out = exp(a - val) +// 2) val = sum(out) +template +inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + at::native::_store(out + i, tmp2); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +inline void {{kernel_name}}_mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + at::native::_store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + max = std::max( + tmp_max, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max)); +} + +template +static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { + TORCH_CHECK(ptr2 == nullptr); + return ptr; +} + +template , int> = 0> +static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) { + return ptr2; +} + +template +inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (; d < size; d++) { + data[d] = val; + } +} + +// out = a * scale +template +inline void {{kernel_name}}_mul_scale_kernel( + scalar_t* a, + scalar_t scale, + int64_t size) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + at::native::_store(a + i, tmp1); + } + for (int64_t i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + a[i] = tmp1; + } +} + +""" + +BRGEMM_PACK_FUNCTIONS = r""" +template +inline void {{kernel_name}}_copy_value_with_pad( + const scalar_t* value_ptr, + scalar_t* dst_ptr, + int64_t rows, + int64_t cols, + int64_t prows, + int64_t pcols, + int64_t ldi) { + auto vec_size = at::vec::Vectorized::size(); + int64_t i = 0; + for (; i < rows; i++) { + int64_t j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + auto zero_vec = at::vec::Vectorized(0); + int64_t pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + zero_vec.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + // row padding + for (; i < prows; i++) { + auto zero_vec = at::vec::Vectorized(0); + int64_t j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero_vec.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + zero_vec.store(dst_ptr + i * pcols + j, pcols - j); + } + + } +} +""" + +MICRO_GEMM_TEMPLATE = r""" +GEMM_DEFINE +""" + +ALLOCATE_BUFFER = r""" + int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4; + auto& {{buffer_name}}_allocator = *at::getCPUAllocator(); + auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize); + void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get(); + {{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr; +""" + +FLEX_ATTENTION_TEMPLATE = r""" +{{template.header().getvalue()}} +#include +#include +#include +{{template.codegen_micro_gemm(kernel.kernel_name)}} +{{template.codegen_softmax_fusion(kernel.kernel_name)}} +{{template.codegen_brgemm_pack_function(kernel.kernel_name)}} +{%- set kernel_args = {"query": query, "key": key, "value": value, + "kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices, + "full_kv_num_blocks": full_kv_num_blocks, "full_kv_indices": full_kv_indices } %} +{%- set kernel_args = template.update_kernel_args(kernel_args) %} + +extern "C" +{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}} +{ + {{ kernel.maybe_codegen_profile() }} + int64_t qBlockSize = {{qBlockSize}}; + int64_t kvBlockSize = {{kvBlockSize}}; + int64_t num_thread = {{num_thread}}; + + // dtypes of kernel and internal buffers + using scalar_t = {{kernel.dtype(query)}}; + constexpr bool is_reduced_type = c10::is_reduced_floating_point_v; + using accum_t = at::opmath_type<{{kernel.dtype(query)}}>; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = {{scale}}; + int64_t batchSize = {{kernel.size(query, 0)}}; + int64_t qSize = {{kernel.size(query, 1)}}; + int64_t num_head = {{kernel.size(query, 2)}}; + int64_t headSize = {{kernel.size(query, 3)}}; + int64_t batchSize_k = {{kernel.size(key, 0)}}; + int64_t num_head_k = {{kernel.size(key, 2)}}; + int64_t headSize_v = {{kernel.size(value, 3)}}; + bool is_broadcast_bs_kv = batchSize != batchSize_k; + bool is_broadcast_head_kv = num_head != num_head_k; + int64_t gqa_shards = num_head / num_head_k; + int64_t bs_shards = batchSize / batchSize_k; + + int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}}; + int64_t num_head_kvi = {{kernel.size(kv_indices, 1)}}; + int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}}; + bool is_broadcast_bs_kvi = batchSize != batchSize_kvi; + bool is_broadcast_head_kvi = num_head != num_head_kvi; + int64_t gqa_shards_kvi = num_head / num_head_kvi; + int64_t bs_shards_kvi = batchSize / batchSize_kvi; + + int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}}; + int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}}; + int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}}; + + int64_t num_kviStrideB = {{kernel.stride(kv_num_blocks, 0)}}; + int64_t num_kviStrideH = {{kernel.stride(kv_num_blocks, 1)}}; + +{%- if has_full_kv_block %} + int64_t full_kviStrideB = {{kernel.stride(full_kv_indices, 0)}}; + int64_t full_kviStrideH = {{kernel.stride(full_kv_indices, 1)}}; + int64_t full_kviStrideQ = {{kernel.stride(full_kv_indices, 2)}}; + + int64_t full_num_kviStrideB = {{kernel.stride(full_kv_num_blocks, 0)}}; + int64_t full_num_kviStrideH = {{kernel.stride(full_kv_num_blocks, 1)}}; + auto full_kv_indices_data = full_kv_indices; + auto full_kv_num_blocks_data = full_kv_num_blocks; +{%- endif %} + + auto kv_num_blocks_data = kv_num_blocks; + auto kv_indices_data = kv_indices; + + // Strides + int64_t qStrideB = {{kernel.stride(query, 0)}}; + int64_t qStrideM = {{kernel.stride(query, 1)}}; + int64_t qStrideH = {{kernel.stride(query, 2)}}; + int64_t kStrideB = {{kernel.stride(key, 0)}}; + int64_t kStrideN = {{kernel.stride(key, 1)}}; + int64_t kStrideH = {{kernel.stride(key, 2)}}; + int64_t vStrideB = {{kernel.stride(value, 0)}}; + int64_t vStrideN = {{kernel.stride(value, 1)}}; + int64_t vStrideH = {{kernel.stride(value, 2)}}; + int64_t oStrideB = {{kernel.stride(output, 0)}}; + int64_t oStrideM = {{kernel.stride(output, 2)}}; + int64_t oStrideH = {{kernel.stride(output, 1)}}; + + int64_t kvSize = {{kernel.size(key, 1)}}; + + int64_t qSplitSize = qBlockSize; + int64_t kvSplitSize = kvBlockSize; + + + qSplitSize = qSplitSize > qSize ? qSize : qSplitSize; + kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + + bool need_pack = false; + // Whether pack is needed for BFloat16/Half + if (is_reduced_type) { + // check platform ability + need_pack = std::is_same_v ? at::native::cpublas::could_pack(at::kBFloat16) + : at::native::cpublas::could_pack(at::kHalf); + } + if (need_pack) { + // When the number of gemm is greater than the number of pack, + // the pack overhead can be overlapped. + int64_t thresh_size = 64; + need_pack = kvSize >= thresh_size && qSize >= thresh_size; + if (need_pack) { + double pack_size = batchSize * num_head * kvSize * headSize; + double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread; + double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize; + need_pack = gemm_size_per_thread / pack_size >= 4; + } + } + // Pad is needed for packing when K is not even + bool headSize_even = headSize % 2 == 0; + int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize; + int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize; + int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail; + int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + + // Allocate per thread temp buf (accumulate type) + int64_t _size_per_thread = + /* qk */ qSplitSize * kvSplitSize + + /* qk_max */ qSplitSize + + /* qk_sum */ qSplitSize + + /* dst */ qSplitSize * headSize_v; + + // Inputs/outputs buffers + const scalar_t* q_data = query; + const scalar_t* k_data = key; + const scalar_t* v_data = value; + scalar_t* out_data = output; + + // Buffers to store accum results, padding query and transpose/packing key/value + {{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}} + {{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}} + {{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*eheadSize*kvSize")}} + {{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*kv_padding_size*headSize_v")}} + {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}} + {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}} + if (need_pack) { + // Pack K, V + at::parallel_for(0, batchSize_k * num_head_k * kvSlice, 1, [&](int64_t begin, int64_t end) { + int ompIdx = at::get_thread_num(); + int64_t i = 0, j = 0, l = 0, n = 0; + scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr; + at::native::data_index_init(begin, i, batchSize_k, j, num_head_k, l, kvSlice); + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + n = l * kvSplitSize; + int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); + auto k_addr = + k_data + i * kStrideB + j * kStrideH + n * kStrideN; + auto v_addr = + v_data + i * vStrideB + j * vStrideH + n * vStrideN; + // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize] + at::native::utils::transpose( + cur_kvSplitSize, + headSize, + /* src_ptr */ + reinterpret_cast(k_addr), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_ptr), + /* ld_dst */ cur_kvSplitSize); + + // Pack [headSize, cur_kvSplitSize] + at::vec::pack_vnni2( + /* src */ reinterpret_cast(transpose_ptr), + /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head_k * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize), + /* ld_src */ cur_kvSplitSize, + /* K */ headSize, + /* N */ cur_kvSplitSize); + + // Pack [cur_kvSplitSize, headSize_v] + at::vec::pack_vnni2( + /* src */ reinterpret_cast(v_addr), + /* dst */ reinterpret_cast(value_reorder_ptr + + i * num_head_k * kv_padding_size * headSize_v + + j * kv_padding_size * headSize_v + n * headSize_v), + /* ld_src */ vStrideN, + /* K */ cur_kvSplitSize, + /* N */ headSize_v); + // Move to the next query + at::native::data_index_step(i, batchSize_k, j, num_head_k, l, kvSlice); + } + }); + } + // Attention loop below + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread; + accum_t* qk_data = buf_ptr; + accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_sum_data = qk_max_data + qSplitSize; + accum_t* dst_data = qk_sum_data + qSplitSize; + scalar_t *qk_reduced_data = + is_reduced_type + ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize + : nullptr; + scalar_t* query_t_padding_ptr = (!headSize_even && need_pack) + ? query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; + + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i; + auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j; + auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB + + j_kvi * num_kviStrideH + k; + int kv_indice_num = *kv_logical_num_data; + std::vector kv_indice_list(kv_indice_num); + for(int kv_i = 0; kv_i < kv_indice_num; kv_i++){ + auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB + + j_kvi * kviStrideH + k*kviStrideQ + kv_i; + kv_indice_list[kv_i] = *kv_logical_data; + } + bool is_skip_kv = kv_indice_num > 0 ? false : true; +{%- if has_full_kv_block %} + auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB + + j_kvi * num_kviStrideH + k; + int full_kv_indice_num = *full_kv_logical_num_data; + std::vector full_kv_indice_list(full_kv_indice_num); + for(int kv_i = 0; kv_i < full_kv_indice_num; kv_i++){ + auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB + + j_kvi * full_kviStrideH + k*full_kviStrideQ + kv_i; + full_kv_indice_list[kv_i] = *full_kv_logical_data; + } + is_skip_kv = kv_indice_num + full_kv_indice_num > 0 ? false : true; +{%- endif %} + int64_t m = k * qSplitSize; + int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m); + if (!is_skip_kv){ + // Initialize max and sum + {{kernel.kernel_name}}_fill_stub(qk_max_data, + -std::numeric_limits::infinity(), cur_qSplitSize); + {{kernel.kernel_name}}_fill_stub(qk_sum_data, + static_cast(0), cur_qSplitSize); + + if (!headSize_even && need_pack) { + // Pad query if headSize is not even + {{kernel.kernel_name}}_copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + cur_qSplitSize, + headSize, + cur_qSplitSize, + eheadSize, + qStrideM + ); + } + } + +{%- if has_full_kv_block %} + for (int64_t n_idx = 0; n_idx < kv_indice_num + full_kv_indice_num ; n_idx += 1) { + auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize; +{%- else %} + for (int64_t n_idx = 0; n_idx < kv_indice_num ; n_idx += 1) { + auto n = kv_indice_list[n_idx]*kvSplitSize; +{%- endif %} + + auto cur_n = n/kvSplitSize; + int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); + int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize; + + // Calculate scale * q @ k.T + auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i; + auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j; + + if (!need_pack) { + auto k_addr = + k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN; + + {{kernel.kernel_name}}_kernel_micro_gemm_transpose_b(false)>( + q_data + i * qStrideB + j * qStrideH + + m * qStrideM, + k_addr, + qk_data, + cur_qSplitSize, + cur_kvSplitSize, + headSize, + qStrideM, + kStrideN, + cur_kvSplitSize); + + } else { + at::native::cpublas::brgemm( + cur_qSplitSize, + cur_kvSplitSize, + eheadSize, + headSize_even ? qStrideM : eheadSize, + cur_kvSplitSize, + cur_kvSplitSize, + false, + !headSize_even + ? query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i_kv * num_head_k * eheadSize * kvSize + + j_kv * eheadSize * kvSize + n * eheadSize, + qk_data, + need_pack); + } + + {{kernel.kernel_name}}_mul_scale_kernel(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize); + +{%- if score_mod and mask_mod %} + // TODO: reduce the number of calls of q_idx and kv_idx initialization + std::vector q_idx(cur_qSplitSize); + for (int64_t i = 0; i < cur_qSplitSize; ++i) { + q_idx[i] = m + i; + } + + std::vector kv_idx(cur_kvSplitSize); + for (int64_t i = 0; i < cur_kvSplitSize; ++i) { + kv_idx[i] = n + i; + } + + std::vector b_idx = {i}; + std::vector h_idx = {j}; + + accum_t* in_ptr0 = qk_data; + + auto in_ptr1 = b_idx.data(); + auto in_ptr2 = h_idx.data(); + auto in_ptr3 = q_idx.data(); + auto in_ptr4 = kv_idx.data(); + + // apply score mod function + { + {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }} + accum_t* out_ptr{{score_buf_idx}} = in_ptr0; + {{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }} + } + + if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){ + // Apply block mask, fill unused with -inf + { + {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }} + accum_t* out_ptr{{mask_buf_idx}} = in_ptr0; + {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }} + } + } + +{%- endif %} + // Update coefficients with Softmax + accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0; + for (int64_t row = 0; row < cur_qSplitSize; ++row) { + // apply scaling factor and max per row in fusion + {{kernel.kernel_name}}_mul_reduce_max_fusion_kernel( + qk_data + row * cur_kvSplitSize, + static_cast(1), + cur_kvSplitSize, + qk_data + row * cur_kvSplitSize, + tmp_max); + tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max; + if (tmp_max == -std::numeric_limits::infinity()) { + // to avoid `nan = exp2f(-inf - (-inf))` + {{kernel.kernel_name}}_fill_stub( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, + static_cast(0), cur_kvSplitSize); + } else { + tmp_sum = tmp_max; + // qk <- exp(qk - max) and sum per row + {{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel( + qk_data + row * cur_kvSplitSize, cur_kvSplitSize, + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, + tmp_sum); + // exp_tmp <- exp(max[row] - max) + exp_tmp = std::exp(qk_max_data[row] - tmp_max); + // sum[row] <- sum + exp_tmp * sum[row] + qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row]; + // max[row] <- max + qk_max_data[row] = tmp_max; + // dst <- dst * exp_tmp + if (n_idx > 0) { + at::vec::map( + [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, + dst_data + row * headSize_v, + dst_data + row * headSize_v, + headSize_v); + } + } + if (need_pack && cur_kvSplitSize % 2 != 0) { + // Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1] + *(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0); + } + } + // Calculate Softmax(q @ k.T) @ v + if (!need_pack) { + auto v_addr = + v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN; + // Fallback Half brgemm is slower than micro gemm + if (!std::is_same_v) { + at::native::cpublas::brgemm( + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v, + n_idx > 0, + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + need_pack); + } else { + if (n_idx > 0) { + {{kernel.kernel_name}}_kernel_micro_gemm(true)>( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v); + } else { + {{kernel.kernel_name}}_kernel_micro_gemm(false)>( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v); + } + } + } else { + int64_t psize = n / kvSplitSize * ekvSplitSize; + at::native::cpublas::brgemm( + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + headSize_v, + headSize_v, + n_idx > 0, + qk_reduced_data, + value_reorder_ptr + + i_kv * num_head_k * kv_padding_size * headSize_v + + j_kv * kv_padding_size * headSize_v + psize * headSize_v, + dst_data, + need_pack); + } + } + + // dst <- dst / sum[row] + // reorder MHA output with strides + for (int64_t row = 0; row < cur_qSplitSize; ++row) { + // Row sums for full masked out rows are 0, we set them to 1 + // in order to avoid NaNs in the output and instead set fully + // masked out rows to 0 + qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; + qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row]; + accum_t sum_reciprocal = 1 / qk_sum_data[row]; + at::vec::map( + [sum_reciprocal, is_skip_kv](Vec x) { return is_skip_kv ? Vec(0.0) : x * Vec(sum_reciprocal); }, + out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, + dst_data + row * headSize_v, + headSize_v); + } + + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + + at::native::cpublas::brgemm_release(need_pack); + + }); +} +""" + + +class CppFlexAttentionTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + scale, + score_mod, + mask_mod, + kv_block_size, + q_block_size, + has_other_buffer, + no_full_kv_block, + fake_buffers, + len_score_other, + len_mask_other, + kernel_input_name_to_buffer, + block_vars, + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.float16] + super().__init__("flex_attention", input_nodes, layout, parallel_num_threads()) + self.scale = scale + self.score_mod = score_mod + self.mask_mod = mask_mod + self.score_buf_name = ( + V.graph.register_buffer(self.score_mod) if self.score_mod else None + ) + self.mask_buf_name = ( + V.graph.register_buffer(self.mask_mod) if self.mask_mod else None + ) + + def get_idx(buf_name): + match = re.search(r"\d+", buf_name) + assert match, f"incorrect score buf name: {buf_name}" + return match.group() + + self.score_buf_idx = ( + get_idx(self.score_buf_name) if self.score_buf_name else None + ) + self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None + self.kv_block_size = kv_block_size + self.q_block_size = q_block_size + self.has_other_buffer = has_other_buffer + self.no_full_kv_block = no_full_kv_block + self.other_buffer_input_offset = 2 + if self.no_full_kv_block: + self.other_buffer_input_offset = 0 + self.fake_buffers = fake_buffers + self.len_score_other = len_score_other + self.len_mask_other = len_mask_other + self.kernel_input_name_to_buffer = kernel_input_name_to_buffer + self.block_vars = block_vars + self.extra_sizevars = list( + OrderedSet( + val + for val in self.kernel_input_name_to_buffer.values() + if isinstance(val, sympy.Symbol) + ) + ) + self.other_buf_start_idx = 5 + self.score_mod_other_buffers = ( + self.input_nodes[ + self.other_buf_start_idx + + self.other_buffer_input_offset : self.other_buf_start_idx + + self.other_buffer_input_offset + + self.len_score_other + ] + if self.has_other_buffer + else None + ) + self.mask_mod_other_buffers = ( + self.input_nodes[ + self.other_buf_start_idx + + self.other_buffer_input_offset + + self.len_score_other : + ] + if self.has_other_buffer + else None + ) + self.other_ptr_data = {} # type: ignore[var-annotated] + + def update_kernel_args(self, kernel_args): + kernel_args.update( + { + key: value + for key, value in self.kernel_input_name_to_buffer.items() + if not isinstance(value, sympy.Symbol) + } + ) + return kernel_args + + def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args): + kernel_input_name_to_buffer_name = { + key: value if isinstance(value, sympy.Symbol) else value.get_name() + for key, value in self.kernel_input_name_to_buffer.items() + } + + def get_arg(name): + return kernel_input_name_to_buffer_name.get(name) + + def get_arg_name(name): + if isinstance(get_arg(name), sympy.Symbol): + return kernel_args.sizevars.get(get_arg(name)) + return kernel_args.input_buffers.get(get_arg(name)) + + if not self.has_other_buffer: + return "" + + if start_offset == -1: + start_offset = self.len_score_other + + length = getattr(self, len_attr) + for i in range(length): + pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}" + buffer_key = f"{buf_list}_{i}" + if pointer not in self.other_ptr_data: + self.other_ptr_data[pointer] = ( + get_arg_name(buffer_key), + get_arg(buffer_key), + ) + + return "\n".join( + f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items() + ) + + def modification(self, subgraph_buffer, output_name, output_idx): + assert isinstance(subgraph_buffer, ir.ComputedBuffer) + subgraph_buffer_data = subgraph_buffer.data + from ..loop_body import LoopBody + from ..utils import sympy_index_symbol_with_prefix, SymT + from ..virtualized import V + from .cpp import CppKernelProxy, KernelGroup, ParallelDepth + + kernel_group = KernelGroup() + kernel_input_args = { + "score": "in_ptr0", + "b": "in_ptr1", + "h": "in_ptr2", + "q_idx": "in_ptr3", + "kv_idx": "in_ptr4", + } + if self.has_other_buffer: + kernel_input_args.update( + {arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()} + ) + + kernel_output_args = {output_name: f"out_ptr{output_idx}"} + + args = kernel_group.args + for name, inp in kernel_input_args.items(): + args.input_buffers[name] = inp + + for name, inp in kernel_output_args.items(): + args.output_buffers[name] = inp + + for name in self.extra_sizevars: + args.sizevars[name] = f"k{name}" + + kernel_group.args = args + + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + var_sizes = tuple(subgraph_buffer.get_size()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes) + } + + dst_layout = subgraph_buffer.get_layout() + output_index = dst_layout.make_indexer()([*var_ranges.keys()]) + + def fn(*args): + V.ops.store( + output_name, + output_index, + subgraph_buffer_data.make_loader()(args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys())), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + + from ..loop_body import MemoryUsageType + + assert all( + mem.buffer_name in kernel_group.args.input_buffers + for mem in body.memory_usage[MemoryUsageType.LOAD] + ), ( + "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers" + ) + + bodies.append(body) + var_sizes_list.append((var_sizes, ())) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + output_code = kernel_group.loops_code.getvalue() + + var_q_symbol, var_kv_symbol = self.block_vars + # See [Note] Handle the case where the split sizes are not statically known. + # We don't know the value of qBlockSize and rkvBlockSize during compilation time + # thus we've represented them by symbols. + # We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize" + # in the generated code thus they'll be filled with the real value during runtime. + if var_q_symbol in kernel_group.args.sizevars: + output_code = output_code.replace( + kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize" + ) + if var_kv_symbol in kernel_group.args.sizevars: + output_code = output_code.replace( + kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize" + ) + + return output_code + + @staticmethod + def add_choices( + choices, + input_nodes, + layout, + scale, + score_mod, + mask_mod, + kv_block_size, + q_block_size, + has_other_buffer, + no_full_kv_block, + fake_buffers, + len_score_other, + len_mask_other, + kernel_input_name_to_buffer, + block_vars, + ): + def preprocessor(input_nodes, layout): + return input_nodes, layout + + def postprocessor(output): + return output + + template = DataProcessorTemplateWrapper( + CppFlexAttentionTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=score_mod, + mask_mod=mask_mod, + kv_block_size=kv_block_size, + q_block_size=q_block_size, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len_score_other, + len_mask_other=len_mask_other, + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=block_vars, + ) + template.maybe_append_choice(choices) + return template + + def apply_score_mod(self, score, b, h, q_idx, kv_idx): + return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item() + + def render( # type: ignore[override,return] + self, + kernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + if epilogue_nodes is not None and epilogue_nodes != []: + raise NotImplementedError( + "Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate." + ) + # Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + # -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + # Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + # Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + + query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3]) + key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3]) + value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3]) + self.accumulate_dtype = torch.float + self.input_dtype = query.layout.dtype + + num_threads = parallel_num_threads() + assert isinstance(self.output_node, ir.IRNode) + buf_out: ir.IRNode = TensorBox.create(self.output_node) + if template_buffer_node is not None: + buf_out = template_buffer_node + options = dict( + query=query, + key=key, + value=value, + kv_num_blocks=self.input_nodes[3], + kv_indices=self.input_nodes[4], + full_kv_num_blocks=( + self.input_nodes[5] if not self.no_full_kv_block else None + ), + full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None, + score_mod_other_buffers=self.score_mod_other_buffers, + mask_mod_other_buffers=self.mask_mod_other_buffers, + scale=self.scale, + has_full_kv_block=not self.no_full_kv_block, + accumulate_dtype=self.accumulate_dtype, + query_dtype=self.input_dtype, + kvBlockSize=self.kv_block_size, + qBlockSize=self.q_block_size, + template=self, + output=buf_out, + kernel=kernel, + num_thread=num_threads, + score_mod=self.score_mod, + mask_mod=self.mask_mod, + score_buf_name=self.score_buf_name, + mask_buf_name=self.mask_buf_name, + score_buf_idx=self.score_buf_idx, + mask_buf_idx=self.mask_buf_idx, + ) + with contextlib.ExitStack() as stack: + for buf in self.fake_buffers: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options) + + def codegen_softmax_fusion(self, kernel_name: str): + # TODO: use inductor IR to rewrite those fusions + return self._template_from_string(SOFTMAX_FUSIONS).render( + dict(kernel_name=kernel_name) + ) + + def codegen_brgemm_pack_function(self, kernel_name: str): + # TODO: make them general for common bmm templates + return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render( + dict(kernel_name=kernel_name) + ) + + def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size): + return self._template_from_string(ALLOCATE_BUFFER).render( + dict( + buffer_name=buffer_name, + buffer_dtype=buffer_dtype, + buffer_size=buffer_size, + ) + ) + + def micro_gemm_define(self, kernel_name: str): + from torch._inductor.codegen.cpp_gemm_template import ( + CppTemplateKernel, + parallel_num_threads, + ) + from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec + from torch._inductor.virtualized import V + + micro_gemm_trans = CppMicroGemmFP32Vec( + kernel_name + "_kernel_micro_gemm_transpose_b", + self.input_dtype, + self.input_dtype, + self.accumulate_dtype, + self.accumulate_dtype, + GemmBlocking(1, 16, 1), + 1, + True, + True, + ) + + micro_gemm = CppMicroGemmFP32Vec( + kernel_name + "_kernel_micro_gemm", + self.input_dtype, + self.input_dtype, + self.accumulate_dtype, + self.accumulate_dtype, + GemmBlocking(1, 16, 1), + 1, + True, + False, + ) + + with V.set_graph_handler(V.graph): + kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads()) + code_trans = micro_gemm_trans.codegen_define(kernel) + code = micro_gemm.codegen_define(kernel) + return code + code_trans + + def codegen_micro_gemm(self, kernel_name: str): + micro_gemm = self.micro_gemm_define(kernel_name) + GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm) + return self._template_from_string(GEMM_SOURCE_CODE).render() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..8b15ef253a4d0bf61e7449fd77d15f7107997019 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,1819 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +from collections.abc import Callable +from functools import lru_cache +from typing import Any, cast, Optional, TypeVar, Union +from unittest.mock import patch + +import torch +import torch.utils +from torch.utils._ordered_set import OrderedSet + +from ..._dynamo.utils import counters +from .. import config, ir, lowering as L +from ..kernel.mm_common import mm_args +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import ( + has_free_symbols, + is_same_mkldnn_tensor, + is_same_tensor, + parallel_num_threads, +) +from ..virtualized import ops, V +from .cpp import get_export_declaration +from .cpp_micro_gemm import ( + CppMicroBrgemm, + CppMicroGemm, + CppMicroGemmAMX, + CppMicroGemmFP32Vec, + create_micro_gemm, + is_int8_woq_gemm_small_m_dim_corner_case, + LayoutType, +) +from .cpp_template import CppTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r""" + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{N}}; + constexpr int64_t K = {{K}}; + constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; + constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; +{%- if is_dynamic_M %} + const int64_t M = {{kernel.size(GemmOut, 0)}}; + const int64_t Mr_blocks = (M + Mr - 1) / Mr; +{%- else %} + constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; + constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; +{%- endif %} +""" + +GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r""" +{%- if is_dynamic_M %} + {%- if num_threads > 1 %} + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + {%- else %} + const auto Mt_blocks = Mr_blocks; + const auto Nt_blocks = Nr_blocks; + const auto Kt_blocks = Kr_blocks; + {%- endif %} + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); + const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; + const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; +{%- else %} + constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}}; + constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}}; + constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}}; + constexpr int64_t Mc_blocks = {{template.cache_blocking(num_threads).block_m}}; + constexpr int64_t Nc_blocks = {{template.cache_blocking(num_threads).block_n}}; + constexpr int64_t Kc_blocks = {{template.cache_blocking(num_threads).block_k}}; + constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; + constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; +{%- endif %} +{%- if is_woq_int4 %} + int64_t group_size = *q_group_size; +{%- endif %} + + // make sure all partitions are assigned + {{kernel.assert_function}}( + Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks, + "Not all partitions are assigned." + ); +""" + +GEMM_TEMPLATE_MULTI_THREADS_PARAMS = r""" +const int tid = omp_get_thread_num(); +const int64_t k_group_id = tid / num_Kt_blocks; +const int64_t k_slice_id = tid % num_Kt_blocks; +const int64_t n_group_id = k_group_id / num_Nt_blocks; +const int64_t n_slice_id = k_group_id % num_Nt_blocks; +const int64_t k_block_start = k_slice_id * Kt_blocks; +const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); +const int64_t n_block_start = n_slice_id * Nt_blocks; +const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); +const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); +const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); +const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; +""" + +GEMM_TEMPLATE_SINGLE_THREAD_PARAMS = r""" +constexpr int tid = 0; +constexpr int64_t k_group_id = 0; +constexpr int64_t k_slice_id = 0; +constexpr int64_t n_group_id = 0; +constexpr int64_t n_slice_id = 0; +constexpr int64_t m_block_start = 0; +constexpr int64_t n_block_start = 0; +constexpr int64_t n_block_end = Nr_blocks; +constexpr int64_t k_block_start = 0; +constexpr int64_t k_block_end = Kr_blocks; +{%- if is_dynamic_M %} +const int64_t num_Mc_blocks_per_thread = num_Mc_blocks; +const int64_t m_block_end = Mr_blocks; +{%- else %} +constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; +constexpr int64_t m_block_end = Mr_blocks; +{%- endif %} +""" + +GEMM_TEMPLATE_M_LOOP_PARAMS = r""" +const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; +const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; +const int64_t m_start = mc * Mr; +const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); +const int64_t m_size = m_end - m_start; +""" + +GEMM_TEMPLATE_N_LOOP_PARAMS = r""" +const int64_t n_start = nc * Nr; +const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); +const int64_t n_size = n_end - n_start; +// NB: assume we pad N, nc_block_end won't exceed padded N here. +const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); +""" + +GEMM_TEMPLATE_MICROKERNEL_DEF = r""" +{{template.header().getvalue()}} + +{{micro_gemm.codegen_define(kernel)}} +""" + +GEMM_TEMPLATE_STUB_DEF = r""" +{%- if x_scale is not none %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %} +{%- elif is_woq_int4 %} + {%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %} +{%- else %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp} %} +{%- endif %} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}} +""" + +GEMM_TEMPLATE = r""" +{{ template.codegen_gemm_stub_def() }} +{ + {{ kernel.maybe_codegen_profile() }} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W + ) }} + +{%- if maybe_k_slicing %} + std::unique_ptr[]> local_buf_ptrs; + if (num_Kt_blocks > 1) { + local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]); + } +{%- endif %} + +{%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + {{ template.codegen_multi_threads_params()|indent(8, false) }} +{%- else %} + { + {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }} +{%- endif %} + {{ micro_gemm.codegen_init(kernel) }} +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} +{%- endif %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + {{ template.codegen_m_loop_params()|indent(12, false) }} + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + {{ template.codegen_n_loop_params()|indent(16, false) }} +{%- if use_local_acc %} + {%- set acc = kernel.local_buffers[acc_buf_name] %} + {{ kernel.reinit_buffer_if_null(acc_buf_name) }} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- if template.should_block_weights and not is_woq_int4 %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} +{%- else %} + {%- if is_woq_int4 %} + {%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %} + {%- set tile_qparam = kernel.slice_nd( + qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %} + {%- else %} + {%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %} + {%- set tile_qparam = None %} + {%- endif %} +{%- endif %} + if (kc == k_block_start) { + {{ micro_gemm.codegen_call(kernel, + tile_X, + tile_W, + acc_slice, + accum=False, + qscale_and_zeros=tile_qparam)|indent(28, false) + }} + } else { + {{ micro_gemm.codegen_call(kernel, + tile_X, + tile_W, + acc_slice, + accum=True, + qscale_and_zeros=tile_qparam)|indent(28, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_Kt_blocks > 1) { + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset( + {{ kernel.release_buffer(acc_buf_name) }}); + } else +{%- endif %} + { +{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_Kt_blocks > 1) { + #pragma omp barrier + for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + // We slice M-dim and each thread in the k-slicing group works on a slice + const int64_t m_start_unsliced = mc * Mr; + const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); + const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; + const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks; + const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); + const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); + const int64_t m_size = m_end - m_start; + const int64_t m_offset = m_start - m_start_unsliced; + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + const int64_t n_start = nc * Nr; + const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); + const int64_t n_size = n_end - n_start; + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get(); + for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get(); + for (int64_t m = m_offset; m < m_offset + m_size; m++) { + #pragma omp simd + for (int64_t n = 0; n < n_size; n++) { + {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n]; + } + } + } + {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %} + {{ kernel.store_output( + tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- endif %} + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + +SMALL_M_GEMM_TEMPLATE = r""" +{{ template.codegen_gemm_stub_def() }} +{ + {{ kernel.maybe_codegen_profile() }} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W + ) }} + # pragma omp parallel + { + #pragma omp for nowait + for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) { + // Handle one output M * Nr block in each thread + int64_t n_start = nr_block_id * Nr; + int64_t n_end = (nr_block_id + 1) * Nr; +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }} + {%- set acc = kernel.local_buffers[acc_buf_name] %} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kr_block_id = 0; kr_block_id < Kr_blocks; kr_block_id++) { + // this loop is not parallelized + int64_t k_start = kr_block_id * Kr; + int64_t k_end = std::min((kr_block_id + 1) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + if C10_UNLIKELY(kr_block_id == 0) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }} + } else if C10_UNLIKELY(k_end == K) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }} + } + } +{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers + )|indent(20, false) }} + } + } +} +""" + + +def _is_int8_gemm(inputs): + return ( + isinstance(inputs[0], ir.IRNode) + and inputs[0].get_dtype() in [torch.uint8, torch.int8] + ) or ( + isinstance(inputs[0], torch.Tensor) + and inputs[0].dtype in [torch.uint8, torch.int8] + ) + + +def get_padded_n(n, block_n): + return (n + block_n - 1) // block_n * block_n + + +_T = TypeVar("_T", ir.IRNode, torch.Tensor) + + +def transpose_w(W: _T, trans_w: bool) -> _T: + """ + Transpose W based on the trans_w flag. + """ + if isinstance(W, ir.IRNode): + if trans_w: + if not isinstance(W, ir.TensorBox): + # pyrefly: ignore [bad-assignment] + W = ir.TensorBox(W) + W = L.permute(W, [1, 0]) + else: + if trans_w: + assert isinstance(W, torch.Tensor) + # pyrefly: ignore [bad-assignment] + W = W.transpose(0, 1) + # pyrefly: ignore [bad-return] + return W + + +def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]: + """ + Expand Bias to the same size of X. + """ + if B is not None: + if isinstance(B, ir.IRNode): + if not isinstance(B, ir.TensorBox): + # pyrefly: ignore [bad-assignment] + B = ir.TensorBox(B) + assert hasattr(X, "get_size") + # pyrefly: ignore [missing-attribute] + B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) + else: + assert isinstance(B, torch.Tensor) + assert isinstance(X, torch.Tensor) + # pyrefly: ignore [bad-assignment] + B = B.expand(X.shape[0], B.shape[-1]) + return B + + +def prune_tensors(input_nodes: list[ir.IRNode], new_input_nodes: list[ir.IRNode]): + """ + Prune unused tensors from `V.graph` since the GEMM Template use new packed weight. + """ + + def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor): + return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and ( + is_same_tensor(base_tensor, comp_tensor) + or is_same_mkldnn_tensor(base_tensor, comp_tensor) + ) + + def get_candidates(input_nodes, new_input_nodes): + # Only Constant Buffer like weight and bias might be changed in GEMM Template. + # The Inductor IR Node may changed, but still share the storage. For example: + # bias in bfloat16 case which only do the expand + return [ + node + for node in input_nodes + if ( + node not in new_input_nodes + and isinstance(node, (ir.TensorBox, ir.StorageBox)) + and node.get_name() in V.graph.constants + and not any( + ( + isinstance(new_node, (ir.TensorBox, ir.StorageBox)) + and new_node.get_name() in V.graph.constants + and share_storage( + V.graph.constants[node.get_name()], + V.graph.constants[new_node.get_name()], + ) + ) + for new_node in new_input_nodes + ) + ) + ] + + for candidate_node in get_candidates(input_nodes, new_input_nodes): + # By using the new packed weight for the GEMM template, we can prune the + # old weight if it has no other users. This saves memory but makes the FX graph + # non-retraceable. To support retracing, we can add a repack node to the + # FX graph. For example: + # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + candidate_tensor_users = 0 + candidate_tensor = V.graph.constants[candidate_node.get_name()] + for node in reversed(V.graph.graph.nodes): + # Case may happen when the candidate tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.target + ): # candidate tensor might already be deleted + comp_tensor = getattr(V.graph.module, node.target) + if isinstance(comp_tensor, torch.Tensor) and share_storage( + candidate_tensor, comp_tensor + ): + candidate_tensor_users += 1 + + for node in reversed(V.graph.graph.nodes): + # The get_attr node has only 1 user fx node + # The candidate tensor has been used by only 1 get_attr node + if ( + node.op == "get_attr" + and node.target == candidate_node.get_name() + and len(node.users) == 1 + and candidate_tensor_users == 1 + ): + del V.graph.constants[node.target] + delattr(V.graph.module, node.target) + delattr(V.graph.graph.owning_module, node.target) + counters["inductor"]["select_algorithm_weight_prune"] += 1 + + +def gen_2d_view_of_epilogue_buf( + Y: ir.Buffer, + template_buffer: ir.Buffer, + epilogue_nodes: list[ir.IRNode], + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], + default_reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], +) -> tuple[ + Union[ir.Buffer, ir.ReinterpretView], + list[Optional[Callable[[list[Any]], list[Any]]]], +]: + """ + The dimension and the indexing could be different between the GEMM output, i.e. `template_buffer`, which is + 2D with MxN) and the output from the template after epilogues, i.e. `Y`. In the GEMM template code, + we are not aware of the dimension and the indexing of the epilogues and always work on 2D tiles according to + the indexing of the GEMM output. + In this function, we return a 2D buffer (`Y_2d`) according to GEMM output (reinterpreted from `Y` if needed) and + build a reindexer that converts the indexing of `Y` into `Y_2d`. + """ + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + if ( + Y.get_size() == template_buffer.get_size() + and Y.get_stride() == template_buffer.get_stride() + ): + reindexers.extend(default_reindexers) + Y_2d = Y + else: + + def get_reindexer(epilogue_node, default_reindexer=None): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + if default_reindexer: + reshape_reindex = ir.fuse_reindexing(reshape_reindex, default_reindexer) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) + + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) # type: ignore[var-annotated] + return reindexer + + if default_reindexers is None: + default_reindexers = [None] * len(epilogue_nodes) + new_reindexers = [ + get_reindexer(epilogue_node, default_reindexer) + for epilogue_node, default_reindexer in zip( + epilogue_nodes, default_reindexers + ) + ] + reindexers.extend(new_reindexers) + if isinstance(Y, ir.BaseView): + storage = ir.StorageBox(Y.unwrap_view()) + else: + assert isinstance(Y, ir.Buffer) + storage = ir.StorageBox(Y) + Y_2d = ir.ReinterpretView(data=storage, layout=template_buffer.get_layout()) + return Y_2d, reindexers + + +class CppGemmTemplate(CppTemplate): + """ + GEMM Template for Inductor CPP Backend. + """ + + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + should_block_weights: bool = True, + name="packed_gemm", + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8] + super().__init__( + name, + input_nodes, + layout, + num_threads, + epilogue_creator=epilogue_creator, + ) + self.beta = beta + self.alpha = alpha + self.has_bias = has_bias + self.register_blocking = register_blocking + m, n = layout.size[-2:] + k = input_nodes[0].get_size()[-1] + self.m, self.n, self.k = m, n, k + self.padded_n = get_padded_n(n, self.register_blocking.block_n) + self.is_dynamic_M = has_free_symbols((m,)) + self.should_block_weights = should_block_weights + self.thread_blocking = self.make_thread_blocking_cache() + self.cache_blocking = self.make_cache_blocking_cache() + + def make_thread_blocking_cache(self): + cache = lru_cache()(self._thread_blocking) + + def thread_blocking(num_threads: int) -> GemmBlocking: + return cache(num_threads) + + return thread_blocking + + def _thread_blocking(self, num_threads: int) -> GemmBlocking: + """ + NOTE [Thread blocking in Cpp GEMM] + We use simple heuristics to decide the thread blocking: + 1. Make sure all threads are occupied as much as possible. + 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse. + 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing. + TODO(jgong5): allow tuning various blocking options + """ + + def get_factors(number): + factors = [] + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks): + thread_block_k = math.ceil(k_blocks / k_factor) + thread_block_n = math.ceil(n_blocks / n_factor) + thread_block_m = math.ceil(m_blocks / m_factor) + return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) + + assert not self.is_dynamic_M, ( + "Unable to determine thread blocking for dynamic M." + ) + register_blocking = self.register_blocking + m_blocks = math.ceil(self.m / register_blocking.block_m) + n_blocks = math.ceil(self.n / register_blocking.block_n) + k_blocks = math.ceil(self.k / register_blocking.block_k) + factors = get_factors(num_threads) + assert len(factors) > 0 + + if config.cpp.gemm_thread_factors is not None: + factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")] + assert len(factors) == 3 + assert math.prod(factors) == self.num_threads + return get_blocking( + factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks + ) + + # we favor square-sized thread blocks for good data reuse + def get_better_blocking(blocking, best_blocking): + if best_blocking is None: + best_blocking = blocking + else: + block_m_size = blocking.block_m * register_blocking.block_m + block_n_size = blocking.block_n * register_blocking.block_n + best_block_m_size = best_blocking.block_m * register_blocking.block_m + best_block_n_size = best_blocking.block_n * register_blocking.block_n + if blocking.block_k > best_blocking.block_k: + best_blocking = blocking + elif ( + blocking.block_k == best_blocking.block_k + and block_m_size + block_n_size + < best_block_m_size + best_block_n_size + ): + best_blocking = blocking + return best_blocking + + best_blocking = None + # check if we can have a thread-blocking to occupy all threads without k-slicing + for n_factor in factors: + m_factor = num_threads // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for k_factor in factors: + if k_blocks >= k_factor and ( + config.cpp.gemm_max_k_slices == 0 + or k_factor <= config.cpp.gemm_max_k_slices + ): + n_factors = get_factors(num_threads // k_factor) + for n_factor in n_factors: + m_factor = (num_threads // k_factor) // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, + n_factor, + k_factor, + m_blocks, + n_blocks, + k_blocks, + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for n_factor in factors: + m_factor = num_threads // n_factor + if n_blocks >= n_factor or m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + assert best_blocking is not None + return best_blocking + + def make_cache_blocking_cache(self): + cache = lru_cache()(self._cache_blocking) + + def cache_blocking(num_threads: int) -> GemmBlocking: + return cache(num_threads) + + return cache_blocking + + def _cache_blocking(self, num_threads: int) -> GemmBlocking: + def get_cache_blocking(register_blocking, thread_blocking): + Mr = register_blocking.block_m + Nr = register_blocking.block_n + Kr = register_blocking.block_k + + Mt_blocks = thread_blocking.block_m + Nt_blocks = thread_blocking.block_n + Kt_blocks = thread_blocking.block_k + + if config.cpp.gemm_cache_blocking is not None: + blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")] + assert len(blockings) == 3 + Mc_blocks, Nc_blocks, Kc_blocks = blockings + return ( + min(Mc_blocks, Mt_blocks), + min(Nc_blocks, Nt_blocks), + min(Kc_blocks, Kt_blocks), + ) + + # The ratios below are empirically determined to decide + # the effective sizes of L1 and L2. + # TODO: tune the factor here + L1_limit_factor = 0.8 + L2_limit_factor = 0.5 + + L1_cache_size = ( + torch._C._cpu._L1d_cache_size() + ) # per core cache size in Bytes + assert L1_cache_size > 0, ( + f"Expect L1_cache_size > 0 but got {L1_cache_size}" + ) + L1 = L1_cache_size * L1_limit_factor + + L2_cache_size = ( + torch._C._cpu._L2_cache_size() + ) # per core cache size in Bytes + assert L2_cache_size > 0, ( + f"Expect L2_cache_size > 0 but got {L2_cache_size}" + ) + L2 = L2_cache_size * L2_limit_factor + + def get_num_byte(dtype): + return torch.tensor([], dtype=dtype).element_size() + + dtype_A = self.input_nodes[0].get_dtype() + dtype_B = self.input_nodes[1].get_dtype() + num_byte_A = get_num_byte(dtype_A) + num_byte_B = get_num_byte(dtype_B) + if dtype_A is torch.bfloat16 and dtype_B is torch.int8 and Kr != 1: + # We will cache dequantized weights (BF16) in L1D for AMX micro-kernel. + # In this case, the choice of the micro-kernel being used can't be decoupled from + # the cache blocking. + # TODO: Decouple the choice of micro-kernel from cache blocking + num_byte_B *= num_byte_A + + # NOTE [CPP GEMM Cache Blocking Algorithm] + # Our overall strategy is to + # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc. + # Here, B is Kc x Nr where Nr is a single register block. We use L1 size to + # decide Kc. We want to make Mc large enough to better reuse B. + # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A + # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc. + + # Step 1: Decide Kc assuming B block is L1-reside. + size_cache_B = Kr * Kt_blocks * Nr * num_byte_B + + Kc_blocks = Kt_blocks + if size_cache_B > L1: + Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) + + if ( + config.cpp.use_small_dequant_buffer + and dtype_A is torch.bfloat16 + and Mt_blocks == 1 + ): + if dtype_B is torch.uint8: + # A16W4 + # Make a small dequant_B buffer for woq int4 [q_group_size, Nr] + # Since when Mt_blocks == 1, L1-reside B block can't be reused by A. + if Kc_blocks * Kr >= self.q_group_size(): + Kc_blocks = self.q_group_size() // Kr + + elif dtype_B is torch.int8: + # A16W8 + # Make A, B, C buffer in L1 + A_buf_size_div_K = self.m * num_byte_A + B_buf_size_div_K = Nr * num_byte_B + # assume acc in float32/int32 and Mc_blocks = Nc_blocks = 1 + C_buf_size = Mr * Nr * 4 + K_block_size = (L1 - C_buf_size) // ( + A_buf_size_div_K + B_buf_size_div_K + ) + if Kc_blocks * Kr >= K_block_size: + Kc_blocks = (K_block_size + Kr - 1) // Kr + + # Step 2: Decide Mc assuming A block is L2-reside. + min_Mc_ratio = 2 # TODO(jgong5): something to tune? + min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) + assert min_Mc_blocks >= 1 + Kt_bytes = Kt_blocks * Kr * num_byte_A + if min_Mc_blocks * Mr * Kt_bytes < L2: + # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt + # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks) + # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside + # in L1. + Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes))) + Nc_blocks = 1 + else: + # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse + # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2. + Mc_blocks = Mt_blocks + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32 + Kc_bytes = Kc_blocks * Kr * num_byte_A + if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2: + # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2, + # assuming Mc == Nc for good data reuse. + M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8 + if M_max < Mc_blocks * Mr: + Mc_blocks = math.floor(M_max / Mr) + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + + return Mc_blocks, Nc_blocks, Kc_blocks + + assert not self.is_dynamic_M, ( + "Unable to determine cache blocking for dynamic M." + ) + register_blocking = self.register_blocking + thread_blocking = self.thread_blocking(num_threads) + + return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking)) + + def log_blockings(self): + log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004 + if self.is_dynamic_M: + # thread and cache blockings are determined at runtime for dynamic shapes + return + log.debug( + f"Cache blocking: {self.cache_blocking(self.num_threads)}" # noqa: G004 + ) + thread_blocking = self.thread_blocking(self.num_threads) + log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004 + + def get_occupancy(): + m_blocks = math.ceil(self.m / self.register_blocking.block_m) + n_blocks = math.ceil(self.n / self.register_blocking.block_n) + k_blocks = math.ceil(self.k / self.register_blocking.block_k) + m = math.ceil(m_blocks / thread_blocking.block_m) + n = math.ceil(n_blocks / thread_blocking.block_n) + k = math.ceil(k_blocks / thread_blocking.block_k) + return (m, n, k) + + log.debug( + f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004 + ) + + def maybe_k_slicing(self): + if self.num_threads == 1: + return False + if self.is_dynamic_M: + # TODO(jgong5): perhaps use size hint to decide? + return True + register_blocking = self.register_blocking + k_blocks = math.ceil(self.k / register_blocking.block_k) + thread_blocking = self.thread_blocking(self.num_threads) + return k_blocks > thread_blocking.block_k + + @classmethod + def add_choices( + cls, + choices, + layout, + input_nodes, + beta=1, + alpha=1, + has_bias=False, + trans_w=False, + input_indices=None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, + ): + """ + Add choices for the GEMM template. + """ + # Fast path to save the epilogue calculation when x_scale/x_zp/w_scale are constant + use_int8_fast_compensation_path = _is_int8_gemm(input_nodes) and all( + ( + isinstance(input_nodes[idx], ir.TensorBox) + and isinstance(input_nodes[idx].data.data, ir.ConstantBuffer) + ) + for idx in [1, 2, 4] + ) + + if input_indices is None: + input_indices = list(range(len(input_nodes))) + + def reorder_and_filter(inputs, layout_or_out): + if has_bias: + assert len(input_indices) >= 3 + # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [ + inputs[x_idx], + inputs[w_idx], + inputs[inp_idx], + *[inputs[idx] for idx in input_indices[3:]], + ], layout_or_out + elif len(inputs) >= len(input_indices): + assert len(input_indices) >= 2 + return [inputs[idx] for idx in input_indices], layout_or_out + else: + # For when input is used for x and w, i.e. X@X.T or similar + # Assumes the first input is the only input + assert len(inputs) == 1 + return [inputs[0]] * len(input_indices), layout_or_out + + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + is_mkldnn_wgt = ( + new_inputs[1].get_name() in V.graph.constants + and V.graph.constants[new_inputs[1].get_name()].is_mkldnn + ) + if is_mkldnn_wgt: + # It shouldn't happen as viewing an mkldnn tensor, we can extend the + # implementation if it does. + assert not isinstance(new_inputs[1], ir.BaseView) + # Note that the layout of MKLDNN Tensor is with the wrong stride + view_size = new_inputs[1].layout.size + view_stride = new_inputs[1].layout.stride + view_offset = new_inputs[1].layout.offset + + def maybe_to_dense(inputs, layout_or_out): + new_inputs = list(inputs) + if isinstance(inputs[1], torch.Tensor): + W = inputs[1] + new_inputs[1] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes(inputs, layout_or_out): + new_inputs = list(inputs) + if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor): + if has_free_symbols(view_size): + # If batch size B is dynamic, we need to set the batch size and possibly stride + assert not has_free_symbols(view_size[1:]) + view_size[:] = V.graph.sizevars.size_hints(view_size) + view_stride[:] = V.graph.sizevars.size_hints(view_stride) + # With the assumptation that W is the storage of unwrap view + # thus view it back here + new_inputs[1] = new_inputs[1].as_strided( + view_size, view_stride, view_offset + ) + + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + W = new_inputs[1] + B = new_inputs[2] if has_bias else None + W = transpose_w(W, trans_w) + B = expand_bias(B, X) # type:ignore[arg-type] + new_inputs[1] = W + if B is not None: + new_inputs[2] = B + return new_inputs, layout_or_out + + # TODO(jgong5): decide proper number of threads per problem size + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) + m, n, k, *_ = mm_args( + new_inputs[0], + new_inputs[1], + mat2_transposed=cls.is_woq_int4(), + use_4x2_dim=cls.is_woq_int4(), + ) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[1].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + use_ref=not cls.is_woq_int4(), + q_group_size=cls.q_group_size(), + ) + assert micro_gemm is not None + pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm) + micro_gemm.use_local_vnni_blocking(not pre_block_weights) + only_one_input = ( + input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False + ) and not pre_block_weights # If weights are blocked, use the second input + + def preprocessor(inputs, layout): + new_inputs, new_layout = normalize_shapes( + *maybe_to_dense(*reorder_and_filter(inputs, layout)) + ) + if only_one_input and isinstance(new_inputs[0], torch.Tensor): + return new_inputs[1:], new_layout + return cls.prep_weight( + new_inputs, + new_layout, + # pyrefly: ignore [bad-argument-type] + micro_gemm, + pre_block_weights, + use_int8_fast_compensation_path, + ) + + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + + W_node = new_input_nodes[1] + if W_node.get_name() not in V.graph.constants: + return output + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, new_layout = normalize_shapes( + *maybe_to_dense(new_input_nodes, layout) + ) + new_input_nodes, _ = cls.prep_weight( + new_input_nodes, + new_layout, + # pyrefly: ignore [bad-argument-type] + micro_gemm, + pre_block_weights, + use_int8_fast_compensation_path, + skip_int8_compensation=True, + ) + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + new_input_nodes[1] = W_packed_constant + + # Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + + template = DataProcessorTemplateWrapper( + cls, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + should_block_weights=pre_block_weights, + name=micro_gemm.__class__.__name__, + ) + template.maybe_append_choice(choices) + return template + + @staticmethod + def get_padded_size(n, block_n, k, should_block_weight): + padded_n = get_padded_n(n, block_n) + # We assume that all GEMM weight tensors should be blocked and padded + new_size = [padded_n // block_n, k, block_n] + return new_size, padded_n + + @staticmethod + def _maybe_remove_storage_offset(node: ir.IRNode): + if node.get_layout().offset == 0: + return node + # node may be contiguous but still have a non-zero storage offset. + # GEMM_TEMPLATE emits code like: + # W.data_ptr[node.offset + ...] + # but runtime W.data_ptr (after normalize_shapes()) already includes this offset. + # To avoid double-offsetting, we remove the offset in the node also in the generated code. + # W.data_ptr[...] + return ir.ExternKernel.copy_input(node) + + @classmethod + def prep_weight( + cls, + inputs, + layout: ir.Layout, + micro_gemm: CppMicroGemm, + should_block_weight: bool, + use_int8_fast_compensation_path: bool = False, + skip_int8_compensation: bool = False, + ): + """ + NOTE Weight prep consists of 2 separate steps: + 1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n] + This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM. + For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway. + This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel, + and is worth the overhead of reshape and blocking. + + This blocking includes additional padding, when n is not a multiple of block_n. + This padding allows a more efficient microkernel implementation. For BMM, this is only done + if reshape would happen anyway, i.e. if the weight tensor is constant, is not contiguous, + or is using AMX VNNI layout. + 2. Packing the weight tensor into a VNNI-friendly shape. For constant input, + this is done at the same time as the weight blocking. + + At compile time, the constant weight tensors are blocked and packed. For non-constant tensors (e.g. BMM) + which will be blocked (non-contiguous or VNNI-layout tensors), the weight tensor is blocked and packed at runtime. + + CppBmmTemplate overrides the methods get_padded_size, and block_weight in order to accommodate + an additional dimension for the batch size and to determine if the weight tensor should be blocked. + """ + W = inputs[1] + new_inputs = list(inputs) + if cls.is_woq_int4(): + assert ( + len(W.get_size()) == 2 + if isinstance(W, ir.IRNode) + else len(W.shape) == 2 + ) + n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape + else: + k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:] + _, block_n, _ = micro_gemm.register_blocking + new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight) + padding = padded_n - n + + if should_block_weight and not cls.is_woq_int4(): + blocked_w = cls.block_weight(W, new_size, padding) + new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size) + elif should_block_weight: + assert cls.is_woq_int4() + new_inputs[1] = cls.block_weight(W, new_size, padding) + elif isinstance(W, ir.IRNode): + # Require W layout to be fixed & contiguous, happens inplace. + ir.ExternKernel.require_contiguous(W) + new_inputs[1] = cls._maybe_remove_storage_offset(W) + + if not skip_int8_compensation and _is_int8_gemm(new_inputs): + BCompensate = None + x_w_scale = None + + def _get_compensation_node(W, use_int8_fast_compensation_path): + BCompensate = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_BMatrixCompens"], + W.get_name() + "_BMatrixCompens", + ) + x_w_scale = None + if use_int8_fast_compensation_path: + x_w_scale = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_x_w_compens"], + W.get_name() + "_x_w_compens", + ) + return BCompensate, x_w_scale + + if use_int8_fast_compensation_path: + # new_inputs has been reordered: [x, w, optional[bias], x_scale, x_zp, w_scale, w_zp] + x_scale = new_inputs[-4] + x_zp = new_inputs[-3] + w_scale = new_inputs[-2] + if isinstance(W, ir.IRNode): + BCompensate, x_w_scale = _get_compensation_node( + W, use_int8_fast_compensation_path + ) + else: + # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + assert all( + isinstance(item, torch.Tensor) + for item in (x_scale, x_zp, w_scale) + ) + BCompensate = BCompensate * x_scale * w_scale * x_zp + x_w_scale = x_scale * w_scale + new_inputs.append(BCompensate) + new_inputs.append(x_w_scale) + else: + if isinstance(W, ir.IRNode): + BCompensate, _ = _get_compensation_node( + W, use_int8_fast_compensation_path + ) + else: + # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + new_inputs.append(BCompensate) + return new_inputs, layout + + @staticmethod + def check_if_block_weight(W, micro_gemm): + return True + + @classmethod + def block_weight(cls, W, new_size, padding): + # These are separated into two methods to allow subclasses to override them separately + if isinstance(W, ir.IRNode): + if W.get_name() in V.graph.constants: + # Create a new buffer, representing the constant blocked tensor + blocked_w = ir.Buffer( + name=W.get_name(), # Borrow the registered buffer name + layout=ir.FixedLayout( + W.get_device_or_error(), + W.get_dtype(), + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + 0, + ), + ) + else: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + permute_dims = list(range(len(new_size))) + permute_dims[-2], permute_dims[-3] = permute_dims[-3], permute_dims[-2] + permute_size = list(new_size) + permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2] + blocked_w = L.constant_pad_nd(W, (0, padding)) + blocked_w = L.permute( + L.view(blocked_w, permute_size), # type: ignore[arg-type] + permute_dims, + ) + else: + assert isinstance(W, torch.Tensor) + # Pad the weight tensor and reshape it into a 3D blocked shape + blocked_size = list(new_size) + blocked_size[-2], blocked_size[-3] = blocked_size[-3], blocked_size[-2] + blocked_w = ( + torch.nn.functional.pad(W, (0, padding)) # type: ignore[assignment] + .reshape(*blocked_size) + .transpose(-3, -2) + .contiguous() + ) + return blocked_w + + @classmethod + def pack_vnni_weight(cls, W, micro_gemm, new_size): + # WOQ INT4 weights are reordered in microkernel so do not pack them here + should_pack = ( + micro_gemm.get_b_layout() != LayoutType.NORMAL + and not micro_gemm.is_woq_int4() + ) + + # These are separated into two methods to allow subclasses to override them separately + if isinstance(W, ir.IRNode): + if isinstance(W, ir.Buffer) and W.get_name() in V.graph.constants: + return W + k = new_size[-2] + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + if should_pack: + permute_dims = list(range(len(new_size) + 1)) + permute_dims[-1], permute_dims[-2] = permute_dims[-2], permute_dims[-1] + vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + vnni_view_size = list(new_size) + vnni_view_size[-2] = k // vnni_size + vnni_view_size.insert(-1, vnni_size) + W = L.view( + L.permute(L.view(W, vnni_view_size), permute_dims), + new_size, + ) + W = ir.ExternKernel.realize_input(W) + W = ir.ExternKernel.require_contiguous(W) + return W + else: + k = new_size[-2] + # Apply VNNI packing to the weight tensor + if should_pack: + # TODO: Move VNNI weight packing for non-constant tensors into the template, + # to improve cache locality and avoid full-tensor copy. + layout_str = ( + "VNNI4" + if micro_gemm.get_b_layout() == LayoutType.VNNI4 + else "VNNI2" + ) + assert micro_gemm.get_b_layout() in [ + LayoutType.VNNI2, + LayoutType.VNNI4, + ], f"We only support {layout_str} for now" + vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + assert k % vnni_size == 0, ( + f"k should be divisible by vnni_size for {layout_str} layout" + ) + vnni_view_size = list(new_size) + vnni_view_size[-2] = k // vnni_size + vnni_view_size.insert(-1, vnni_size) + W = W.view(vnni_view_size).transpose(-1, -2).contiguous().view(new_size) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(W.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + W = W.as_strided(W.shape, new_stride) + return W + + def get_default_reindexers(self, epilogue_nodes): + return [None] * len(epilogue_nodes) + + def get_options( + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + ) -> dict[str, Any]: + assert len(self.input_nodes) >= 2 + + int8_gemm = self.input_nodes[0].get_dtype() in [torch.uint8, torch.int8] + x_scale = None + x_zp = None + w_scale = None + w_zp = None + inp = None + q_group_size_node = None + qscale_and_zeros = None + if int8_gemm: + X, W = self.input_nodes[0], self.input_nodes[1] + bias_idx = 2 if self.has_bias else 1 + inp = self.input_nodes[bias_idx] if self.has_bias else None + x_scale = self.input_nodes[bias_idx + 1] + x_zp = self.input_nodes[bias_idx + 2] + w_scale = self.input_nodes[bias_idx + 3] + w_zp = self.input_nodes[bias_idx + 4] + Y = self.output_node + elif self.is_woq_int4(): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + q_group_size_node = self.input_nodes[2] + qscale_and_zeros = self.input_nodes[3] + else: + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + inp = self.input_nodes[2] if self.has_bias else None + + template_buffer_has_other_users = None + + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + + assert flag_template_buffer_has_other_users is not None + template_buffer_has_other_users = flag_template_buffer_has_other_users + + template_buffer = Y + gemm_output_buffer = template_buffer + + epilogues: list[ir.IRNode] = [] + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] + epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = [] + fake_buffers: list[ir.Buffer] = [] + Y_aliases: OrderedSet[str] = OrderedSet() + + use_local_acc = ( + self.layout.dtype != torch.float + or template_buffer_has_other_users + or int8_gemm + or self.padded_n != self.n + or self.maybe_k_slicing() + or (epilogue_nodes and epilogue_nodes[-1].get_dtype() != self.layout.dtype) + ) + + # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template, + # but we'd better move it here to align with fp. + if inp is not None and self.beta != 0 and not int8_gemm: + # add an epilogue for bias add + def _bias_add_epilogue(buf): + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + epilogue_creators.append(_bias_add_epilogue) + + if self.epilogue_creator is not None: + epilogue_creators.append(self.epilogue_creator) + + # When the GEMM output buffer is localized but it has users other than the epilogue nodes, + # we need to copy the value in the GEMM output local buffer to a global buffer. + def need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + # The GEMM output buffer is a global buffer, thus copy is not needed. + if not use_local_acc: + return False + + # The possible value of template_buffer_has_other_users is (None, False, True) + # It is None when generating the gemm template during autotune and it will have value during scheduler codegen. + # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases: + # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune) + # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the + # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value). + if not template_buffer_has_other_users: + return False + + # When bias is not None or self.epilogue_creator is not None, + # there will be epilogue_creators after the GEMM. + # The GEMM output buffer is localized while + # the output buffer of the epilogue_creators is a global buffer. + if epilogue_creators: + return False + + return True + + if need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + + def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer): + dtype = self.layout.dtype + input_loader = input_buffer.make_loader() + + def copy_inner(index): + input = input_loader(index) + result = ops.to_dtype(input, dtype) + return result + + return ir.Pointwise( + device=input_buffer.get_device_or_error(), + dtype=self.layout.dtype, + inner_fn=copy_inner, + ranges=input_buffer.get_size(), + ) + + epilogue_creators.append(copy_from_local_to_global_buffer_epilogue) + + # NOTE [How CPP GEMM template epilogues are organized] + # gemm_output_buffer + # --> zero or more in-template epilogues (created by `epilogue_creators`) --> + # template_buffer + # --> zero or more out-of-template epilogues (`epilogue_nodes`) --> + # Y + if epilogue_creators: + assert isinstance(template_buffer, ir.IRNode) + gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + gemm_output_buffer = ir.Buffer( + name=gemm_output_name, + # pyrefly: ignore [missing-attribute] + layout=template_buffer.layout, + ) + current_input_buffer = gemm_output_buffer + for i, creator in enumerate(epilogue_creators): + if i == len(epilogue_creators) - 1: + buffer_name = template_buffer.get_name() + else: + buffer_name = f"{gemm_output_name}_epilogue_{i}" + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + # pyrefly: ignore [missing-attribute] + layout=template_buffer.layout, + data=creator(current_input_buffer), + ) + ) + fake_buffers.append(current_input_buffer) + Y_aliases.add(current_input_buffer.get_name()) + reindexers.append(None) + if i < len(epilogue_creators) - 1: + current_input_buffer = ir.Buffer( + name=buffer_name, + # pyrefly: ignore [missing-attribute] + layout=template_buffer.layout, + ) + + assert isinstance(Y, (ir.Buffer, ir.ReinterpretView)) + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + + if epilogue_nodes: + if not template_buffer_has_other_users: + assert isinstance(template_buffer, ir.IRNode) + Y_aliases.add(template_buffer.get_name()) + epilogues.extend(epilogue_nodes) + assert Y.get_numel() == epilogues[-1].get_numel() + Y = cast(ir.Buffer, epilogues[-1]) + assert isinstance(template_buffer, ir.Buffer) + Y_2d, reindexers = gen_2d_view_of_epilogue_buf( + Y, + template_buffer, + epilogue_nodes, + reindexers, + default_reindexers=self.get_default_reindexers(epilogue_nodes), + ) + + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X.get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X.get_dtype(), + # pyrefly: ignore [missing-attribute] + input2_dtype=W.get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + use_ref=not self.is_woq_int4(), + q_group_size=self.q_group_size(), + ) + assert micro_gemm is not None + micro_gemm.use_local_vnni_blocking(not self.should_block_weights) + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + if isinstance(micro_gemm, CppMicroBrgemm): + counters["inductor"]["cpp_micro_brgemm_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + GemmOut=gemm_output_buffer, + aliases={alias: Y.get_name() for alias in Y_aliases}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + epilogue_nodes=epilogues, + reindexers=reindexers, + Y_2d=Y_2d, + use_local_acc=use_local_acc, + maybe_k_slicing=self.maybe_k_slicing(), + x_scale=x_scale, + x_zp=x_zp, + w_scale=w_scale, + w_zp=w_zp, + acc_buf_dtype=torch.int32 if int8_gemm else torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + fake_buffers=fake_buffers, + is_woq_int4=self.is_woq_int4(), + q_group_size=q_group_size_node, + qscale_and_zeros=qscale_and_zeros, + ) + return options + + def is_int8_woq_gemm_small_m_dim( + self, + X: ir.ReinterpretView, + W: ir.ReinterpretView, + N, + K, + micro_gemm, + ): + """Use SMALL_M_GEMM_TEMPLATE""" + return ( + isinstance(micro_gemm, CppMicroGemmFP32Vec) + and is_int8_woq_gemm_small_m_dim_corner_case( + micro_gemm, X.get_size()[0], N, K + ) + and X.get_dtype() is torch.bfloat16 + and W.get_dtype() is torch.int8 + ) + + def render( # type: ignore[override, return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + options = self.get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + ) + self.render_options = options + + with contextlib.ExitStack() as stack: + for buf in options["fake_buffers"]: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim( + options["X"], + options["W"], + options["N"], + options["K"], + options["micro_gemm"], + ): + template_str = SMALL_M_GEMM_TEMPLATE + else: + template_str = GEMM_TEMPLATE + return self._template_from_string(template_str).render(**options) + + def codegen_blocks( + self, + num_threads, + N, + K, + micro_gemm, + is_dynamic_M, + kernel, + GemmOut, + config, + L1_cache_size, + L2_cache_size, + X, + W, + ): + options = dict( + num_threads=num_threads, + N=N, + K=K, + micro_gemm=micro_gemm, + is_dynamic_M=is_dynamic_M, + kernel=kernel, + GemmOut=GemmOut, + config=config, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + template=self, + X=X, + W=W, + is_woq_int4=self.is_woq_int4(), + ) + template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK + if not ( + not is_dynamic_M + and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm) + ): + template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED + return self._template_from_string(template_str).render(options) + + def codegen_microkernel_def(self): + return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render( + self.render_options + ) + + def codegen_gemm_stub_def(self): + microkernel = self.codegen_microkernel_def() + return microkernel + self._template_from_string(GEMM_TEMPLATE_STUB_DEF).render( + self.render_options + ) + + def codegen_multi_threads_params(self): + return self._template_from_string(GEMM_TEMPLATE_MULTI_THREADS_PARAMS).render() + + def codegen_single_thread_params(self, is_dynamic_M): + options = dict( + is_dynamic_M=is_dynamic_M, + ) + return self._template_from_string(GEMM_TEMPLATE_SINGLE_THREAD_PARAMS).render( + options + ) + + def codegen_m_loop_params(self): + return self._template_from_string(GEMM_TEMPLATE_M_LOOP_PARAMS).render() + + def codegen_n_loop_params(self): + return self._template_from_string(GEMM_TEMPLATE_N_LOOP_PARAMS).render() + + @classmethod + def is_woq_int4(cls): + return False + + @classmethod + def q_group_size(cls): + return None + + +class CppWoqInt4GemmTemplateMeta(type): + def __getitem__(cls, q_group_size): + class CppWoqInt4GemmTemplateInstance(CppGemmTemplate): + def __init__( + self, + *args, + **kwargs, + ) -> None: + super().__init__( + *args, + **kwargs, + ) + + @classmethod + def is_woq_int4(cls): + return True + + @classmethod + def q_group_size(cls): + return q_group_size + + @staticmethod + def check_if_block_weight(W, micro_gemm): + # For WOQ INT4, weight is already packed + # However, for AMX microkernel, we want to change the blocking of weight + from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx + + return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx) + + @classmethod + def block_weight(cls, W, new_size, padding): + # This method is called only if AMX microkernels are used. + # In this case, we unpack and repack weight so that block_n=32 + # the format of packed weight is described here: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 + if isinstance(W, ir.IRNode): + # in this case, we do nothing + ir.ExternKernel.require_contiguous(W) + blocked_w = W + else: + # in this case, we unpack and repack weight + assert isinstance(W, torch.Tensor) + assert W.dim() == 2 + N = W.size(0) + K = W.size(-1) * 2 + G = cls.q_group_size() + # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel + # so that the unpacking process is faster + x = torch.eye(K).bfloat16() + # Here we use scale=1 and qzero=8 because we want to unpack weight + # without dequantizing it. The qzero here is 8 instead of 0 because + # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95 + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .bfloat16() + .expand(K // G, N, 2) + .contiguous() + ) + # shape: [K, N] + unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + W, + G, + qscales_and_zeros, + ).to(torch.uint8) + block_n = 32 + # shape: [N // block_n, K, block_n] + w_blocked = ( + unpacked_w.view(K, N // block_n, block_n) + .permute(1, 0, 2) + .contiguous() + ) + # pack 2 int4 -> 1 int8 + # block_n: [a0, a1, ..., a15, b0, b1, ..., b15] + # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...] + # shape: [N // block_n, K, 2, block_n // 2] + w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2) + # shape: [N // block_n, K, block_n // 2] + w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | ( + w_blocked[:, :, 1, :] << 4 + ) + # shape: [N, K // 2] + blocked_w = w_blocked_packed.view(N, K // 2) + + return blocked_w + + return CppWoqInt4GemmTemplateInstance + + +class CppWoqInt4GemmTemplate(metaclass=CppWoqInt4GemmTemplateMeta): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..abea505b2d069a26c2d1ed181e217a88fb61d0d4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py @@ -0,0 +1,511 @@ +import contextlib +import logging +from collections.abc import Callable +from typing import Any, cast, Optional, TypeVar +from unittest.mock import patch + +import torch +import torch.utils +from torch.utils._ordered_set import OrderedSet + +from ..._dynamo.utils import counters +from .. import config, ir +from ..kernel.mm_common import mm_args +from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper +from ..utils import parallel_num_threads +from ..virtualized import V +from .cpp import get_export_declaration +from .cpp_gemm_template import ( + CppGemmTemplate, + expand_bias, + gen_2d_view_of_epilogue_buf, + prune_tensors, + transpose_w, +) +from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +{{micro_gemm.codegen_define(kernel)}} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}} +{ + {{kernel.maybe_codegen_profile()}} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0] + ) }} +{%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + {{ template.codegen_multi_threads_params()|indent(8, false) }} +{%- else %} + { + {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }} +{%- endif %} + {{ micro_gemm.codegen_init(kernel) }} +{%- set acc_buf_name_list=[] %} +{%- set acc_buf_name_prefix = "local_acc_buf_" %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} + {%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %} +{%- endfor %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + {{ template.codegen_m_loop_params()|indent(12, false) }} + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + {{ template.codegen_n_loop_params()|indent(16, false) }} +{%- set acc_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %} + {{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }} +{%- endfor %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %} +{%- endfor %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set tile_W_3d_list=[] %} +{%- set tile_W_list=[] %} +{%- set acc_slice_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_slice_list = acc_slice_list.append( + kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) + ) %} + {%- set tile_W_3d_list = tile_W_3d_list.append( + kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()]) + ) %} +{%- endfor %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_W_list = tile_W_list.append( + kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n]) + ) %} +{%- endfor %} + if (kc == k_block_start) { + {%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {{ micro_gemm.codegen_call( + kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False + )|indent(28, false) }} + {%- endfor %} + } else { + {%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {{ micro_gemm.codegen_call( + kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True + )|indent(28, false) }} + {%- endfor %} + } + } + } + { +{%- set tile_acc_list = [] %} +{%- set tile_Y_list = [] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_acc_list = tile_acc_list.append( + kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")]) + ) %} + {%- set tile_Y_list = tile_Y_list.append( + kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")]) + ) %} +{%- endfor %} + {{ kernel.store_outputs( + tile_Y_list, + tile_acc_list, + GemmOuts, + epilogue_nodes, + offsets=("m_start", "n_start"), + reindexers=reindexers, + multi_output_buffers=multi_output_buffers + )|indent(20, false) + }} + } + } + } + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + + +def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> list[ir.IRNode]: + act_deduplicated = [] + act_deduplicated_name: OrderedSet[str] = OrderedSet() + for act_idx in range(len(act_mapping.values())): + act = act_mapping[act_idx] + if act.get_name() not in act_deduplicated_name: + act_deduplicated.append(act) + act_deduplicated_name.add(act.get_name()) + return act_deduplicated + + +class CppGroupedGemmTemplate(CppGemmTemplate): + def __init__( + self, + input_nodes: list[ir.IRNode], + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta: int = 1, + alpha: int = 1, + has_bias: bool = False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, + gemm_grouped_num: int = 1, + ) -> None: + """ + Template for Group of GEMMs: + * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc) + for their A, B, and C matrices. + * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues. + * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues. + This behavior can be extended in the future if needed. + """ + super().__init__( + input_nodes, + layout, + num_threads, + register_blocking, + beta, + alpha, + has_bias, + epilogue_creator, + ) + self.act_mapping = act_mapping + self.gemm_grouped_num = gemm_grouped_num + # pyrefly: ignore [bad-override] + self.output_node: list[ir.Buffer] = [ + ir.Buffer(name="buf_out" + str(idx), layout=layout) + for idx in range(gemm_grouped_num) + ] + + @classmethod + # pyrefly: ignore [bad-override] + def add_choices( + cls, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[ir.IRNode], + beta: int = 1, + alpha: int = 1, + has_bias: tuple[bool, ...] = (False, False), + trans_w: bool = False, + input_indices: Optional[list[int]] = None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf + ) -> DataProcessorTemplateWrapper: + # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ... + gemm_grouped_num = len(has_bias) + assert act_mapping + act_deduplicated = get_deduplicated_act(act_mapping) + wgt_start_idx = len(act_deduplicated) + bias_start_idx = wgt_start_idx + gemm_grouped_num + input_indices = list(range(len(input_nodes))) + + _T = TypeVar("_T", ir.IRNode, torch.Tensor) + _U = TypeVar("_U", ir.Layout, torch.Tensor) + + def reorder_and_filter( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + assert input_indices is not None, "input_indices must be set" + return [inputs[idx] for idx in input_indices], layout_or_out + + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + + def maybe_to_dense( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_inputs = list(inputs) + for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + if isinstance(inputs[idx], torch.Tensor): + W = inputs[idx] + assert isinstance(W, torch.Tensor), "W must be a torch.Tensor" + # pyrefly: ignore [unsupported-operation] + new_inputs[idx] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_inputs: list[_T] = list(inputs) + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + new_input = new_inputs[wgt_idx] + new_inputs[wgt_idx] = transpose_w(new_input, trans_w) + for bias_idx in range(bias_start_idx, len(new_inputs)): + # pyrefly: ignore [bad-argument-type] + new_bias = expand_bias(new_inputs[bias_idx], X) + assert new_bias is not None + # pyrefly: ignore [unsupported-operation] + new_inputs[bias_idx] = new_bias + return new_inputs, layout_or_out + + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx]) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[wgt_start_idx].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + ) + assert micro_gemm is not None + _, block_n, _ = micro_gemm.register_blocking + new_size, padded_n = cls.get_padded_size( + n, block_n, k, should_block_weight=True + ) + padding = padded_n - n + + def pack_weight( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_W_list = [] + new_inputs = list(inputs) + W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] + for W in W_list: + blocked_w = cls.block_weight(W, new_size, padding) + new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)) + new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list + return new_inputs, layout_or_out + + def preprocessor( + inputs: list[_T], + layout: _U, + ) -> tuple[list[_T], _U]: + return pack_weight( + *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) + ) + + def postprocessor(output: _T) -> _T: + if isinstance(output, ir.TensorBox): + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + W_nodes = new_input_nodes[ + wgt_start_idx : wgt_start_idx + gemm_grouped_num + ] + W_tensor = [] + for W_node in W_nodes: + assert W_node.get_name() in V.graph.constants + # pyrefly: ignore [bad-argument-type] + W_tensor.append(V.graph.constants[W_node.get_name()]) + new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = ( + W_tensor # type: ignore[assignment] + ) + new_input_nodes, _ = pack_weight( + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + ) + # Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + W_packed = new_input_nodes[idx] + assert isinstance(W_packed, torch.Tensor) + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[idx] = ( + ir.InputsKernel.unwrap_storage_for_input(W_packed_constant) + ) + # pyrefly: ignore [bad-return] + return output + + template = DataProcessorTemplateWrapper( + CppGroupedGemmTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + act_mapping=act_mapping, + gemm_grouped_num=gemm_grouped_num, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override,return,no-untyped-def] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + assert self.act_mapping + act_deduplicated = get_deduplicated_act(self.act_mapping) + wgt_start_idx = len(act_deduplicated) + bias_start_idx = wgt_start_idx + self.gemm_grouped_num + X_list = list(self.act_mapping.values()) + W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num] + inp_list = [] + cur_idx = bias_start_idx + for inp_idx in range(self.gemm_grouped_num): + inp = None + # pyrefly: ignore [index-error] + if self.has_bias[inp_idx]: + inp = self.input_nodes[cur_idx] + cur_idx += 1 + inp_list.append(inp) + + Y_list = self.output_node + multi_output_buffers = None + if template_buffer_node is not None: + W_list = template_buffer_node.inputs[ + wgt_start_idx : wgt_start_idx + self.gemm_grouped_num + ] + assert isinstance(template_buffer_node.outputs, list) + Y_list = template_buffer_node.outputs + counters["inductor"]["cpp_grouped_gemm_template"] += 1 + multi_output_buffers = template_buffer_node.outputs + + template_buffer = Y_list[0] + fake_buffers: list[ir.Buffer] = [] + Y_2d_list = Y_list + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X_list[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X_list[0].get_dtype(), + # pyrefly: ignore [missing-attribute] + input2_dtype=W_list[0].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert micro_gemm is not None + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + epilogues: list[ir.IRNode] = [] + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] + gemm_output_buffers: list[ir.Buffer] = [] + for out_buf_idx in range(self.gemm_grouped_num): + gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str( + out_buf_idx + ) + gemm_output_buffers.append( + ir.Buffer(name=gemm_output_name, layout=template_buffer.layout) + ) + + assert not self.epilogue_creator, ( + "epilogue_creator is not supported yet in Grouped GEMM Template" + ) + + kernel_args: dict[str, Optional[ir.IRNode]] = {} + for x_idx in range(wgt_start_idx): + kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx] + for w_idx in range(self.gemm_grouped_num): + # pyrefly: ignore [unsupported-operation] + kernel_args["W" + str(w_idx)] = W_list[w_idx] + for inp_idx in range(self.gemm_grouped_num): + kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx] + + def _bias_add_epilogue(buf: ir.IRNode, inp: ir.IRNode) -> ir.Pointwise: + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + for gemm_idx, inp in enumerate(inp_list): + if inp: + buffer_name = Y_list[gemm_idx].get_name() + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=_bias_add_epilogue(gemm_output_buffers[gemm_idx], inp), + ) + ) + reindexers.append(None) + + if epilogue_nodes: + epilogues.extend(epilogue_nodes) + for epilogue_node in epilogue_nodes: + Y = cast(ir.Buffer, epilogue_node) + _, reindexers = gen_2d_view_of_epilogue_buf( + Y, + template_buffer, + [ + epilogue_node, + ], + reindexers, + default_reindexers=[ + None, + ], + ) + + options = dict( + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + aliases={}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + acc_buf_dtype=torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + epilogue_nodes=epilogues, + GemmOuts=gemm_output_buffers, + reindexers=reindexers, + kernel_args=kernel_args, + X_list=X_list, + W_list=W_list, + gemm_grouped_num=self.gemm_grouped_num, + Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)}, + Y_2d_list=Y_2d_list, + multi_output_buffers=multi_output_buffers, + ) + with contextlib.ExitStack() as stack: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers)) + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..39c026949fb13d541191b7462ad8f5666f09c098 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py @@ -0,0 +1,2232 @@ +# mypy: allow-untyped-defs +import dataclasses +import operator +import sys +from collections.abc import Callable +from enum import Enum +from typing import Optional + +import torch + +from .. import cpp_builder, ir +from ..cpu_vec_isa import ( + pick_vec_isa, + VecAMX, + VecAVX2, + VecAVX512, + VecAVX512VNNI, + VecISA, + VecNEON, + VecSVE256, +) +from ..utils import IndentedBuffer, parallel_num_threads +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp + + +class LayoutType(Enum): + NORMAL = 0 + VNNI2 = 1 + VNNI4 = 2 + + +_IS_WINDOWS = sys.platform == "win32" + + +def get_restrict_keyword() -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170 + return "__restrict" + else: + return "__restrict__" + + +class CppMicroGemm: + """ + A class that codegens a kernel that computes small-sized matrix multiplication. + + A micro GEMM kernel is responsible for register blocking, instruction selection, + and other CPU architecture-specific optimizations. + + The subclasses need to override `codegen_define` to define the kernel function + that is called by the code generated by `codegen_call`. + """ + + # TODO(jgong5): support constant shapes and lds as template args. + DECLARE_KERNEL = r""" +template +inline void {{kernel_name}}( +{%- if kernel_extra_args_declare %} + {{kernel_extra_args_declare}} +{%- endif %} + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) +""" + + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + ) -> None: + self.name = name + self.input_dtype = input_dtype + assert input2_dtype is not None + self.input2_dtype = input2_dtype + self.output_dtype = output_dtype + self.compute_dtype = compute_dtype + self.register_blocking = register_blocking + self.alpha = alpha + self.pack_vnni_B_locally = False + + def get_common_options(self): + if self.input_dtype in [torch.uint8, torch.int8]: + assert self.compute_dtype == torch.int32 + assert self.output_dtype == torch.int32 + assert self.input2_dtype == torch.int8 + return { + "torch": torch, + "kernel_name": self.name, + "input_dtype": self.input_dtype, + "input2_dtype": self.input2_dtype, + "output_dtype": self.output_dtype, + "compute_dtype": self.compute_dtype, + "input_t": DTYPE_TO_CPP[self.input_dtype], + "input2_t": DTYPE_TO_CPP[self.input2_dtype], + "output_t": DTYPE_TO_CPP[self.output_dtype], + "compute_t": DTYPE_TO_CPP[self.compute_dtype], + "alpha": self.alpha, + "kernel_extra_args_declare": self.get_kernel_extra_args_declare(), + "int8_gemm": self.input_dtype in [torch.uint8, torch.int8], + "vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2, + "restrict_keyword": get_restrict_keyword(), + "pack_vnni_B_locally": self.pack_vnni_B_locally, + "template": self, + "is_woq_int4": self.is_woq_int4(), + } + + def get_kernel_declaration(self): + options = self.get_common_options() + return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) + + def get_kernel_extra_args_declare(self) -> str: + return "" + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + return [] + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + raise NotImplementedError + + def codegen_call( + self, + kernel: CppTemplateKernel, + A: ir.Buffer, + B: ir.Buffer, + C: ir.Buffer, + accum: bool, + prefetch: bool = False, + **kwargs_for_extra_args, + ) -> str: + """ + Generate the code for calling the templated kernel that computes + `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. + """ + A_ptr = f"&({kernel.index(A, [0, 0])})" + B_ptr = f"&({kernel.index(B, [0, 0])})" + C_ptr = f"&({kernel.index(C, [0, 0])})" + M = kernel.size(C, 0) + N = kernel.size(C, 1) + K = kernel.size(A, 1) + lda = kernel.stride(A, 0) + ldb = kernel.stride(B, 0) + ldc = kernel.stride(C, 0) + res = IndentedBuffer() + res.writeline( + f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>(" + ) + with res.indent(): + kwargs_for_extra_args.update({"kernel": kernel}) + extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args) + for arg in extra_args: + res.writeline(arg) + res.writeline(f"{A_ptr},") + res.writeline(f"{B_ptr},") + res.writeline(f"{C_ptr},") + res.writeline(f"{M},") + res.writeline(f"{N},") + res.writeline(f"{K},") + res.writeline(f"{lda},") + res.writeline(f"{ldb},") + res.writeline(f"{ldc}") + res.writeline(");") + return res.getvalue() + + def use_local_vnni_blocking(self, should_block_weight: bool): + self.pack_vnni_B_locally = should_block_weight + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def get_b_layout(self) -> LayoutType: + return LayoutType.NORMAL + + ALLOCATE_WEIGHT_BUFFER = r""" + {%- if is_msvc_compiler %} + // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here. + auto heap_deq_b_buf_ptr = std::make_unique<{{buffer_dtype}}[]>({{buffer_size}}); + {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get(); + {%- else %} + // It's safe to use a stack-allocated array since the blocking strategy would + // require us to allocate an array that's smaller than the size of L1D cache, + // and the default per thread max stack size on Linux is quite higher, + // so we need not worry about stack overflow. + alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}]; + {%- endif %} +""" + + def codegen_allocate_weight_buffer( + self, buffer_name: str, buffer_dtype: str, *size_args + ) -> str: + buffer_size = " * ".join(map(str, size_args)) + return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render( + { + "buffer_name": buffer_name, + "buffer_dtype": buffer_dtype, + "buffer_size": buffer_size, + "is_msvc_compiler": cpp_builder.is_msvc_cl(), + } + ) + + def is_woq_int4(self): + return False + + +@dataclasses.dataclass +class CppMicroGemmConfig: + input_dtype: torch.dtype + input2_dtype: torch.dtype + output_dtype: torch.dtype + compute_dtype: torch.dtype + vec_isa_cls: type[VecISA] + register_blocking: GemmBlocking + extra_check: Optional[Callable[..., bool]] = None + + +micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {} + + +def register_micro_gemm(*configs): + def inner(cls): + assert cls not in micro_gemm_configs, ( + f"Duplicate micro_gemm registration for {cls}" + ) + assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" + micro_gemm_configs[cls] = list(configs) + return cls + + return inner + + +def generate_gemm_config( + vec_isa_cls, + register_blockings, + input_dtype=torch.float, + input2_dtype=None, + output_dtype=None, + compute_dtype=None, + extra_check=None, +): + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if input2_dtype is None: + input2_dtype = input_dtype + return [ + CppMicroGemmConfig( + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + vec_isa_cls, + GemmBlocking(*blocking), + extra_check, + ) + for blocking in register_blockings + ] + + +class CppMicroGemmRef(CppMicroGemm): + """ + A reference implementation of the CppMicroGemm class with naive C++ code. + It is used for correctness debugging. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + {{compute_t}} result = accum ? C[m * ldc + n] : 0; + for (int64_t k = 0; k < K; ++k) { + result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + } + C[m * ldc + n] = result; + } + } +} +""" + + def __init__( + self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + GemmBlocking(1, 1, 1), + alpha, + ) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + +def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k): + return ( + k % config.register_blocking.block_k == 0 + and n % config.register_blocking.block_n == 0 + and m < 16 + ) + + +# extra check for small M dimension for int8 WoQ case +def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs): + return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get( + "dynamic_M", False + ) + + +# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise +def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs): + return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs) + + +@register_micro_gemm( + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=do_not_use_with_small_m_for_int8_woq, + ), + *generate_gemm_config( + VecAVX512, + [ + (4, 32, 64), + (8, 32, 64), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_woq_small_m_dim, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=do_not_use_with_small_m_for_int8_woq, + ), + *generate_gemm_config( + VecAVX2, + [ + (2, 16, 64), + (4, 16, 64), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_woq_small_m_dim, + ), + *generate_gemm_config( + VecNEON, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), + *generate_gemm_config( + VecSVE256, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), +) +class CppMicroGemmFP32Vec(CppMicroGemm): + """ + This class generates the code for micro gemm using fp32 vec instructions for compute. + It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. + The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template, + if the desired output is BF16/FP16. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; + constexpr auto VLEN = Vectorized::size(); + {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + // TODO(jgong5): loop unroll for M and N + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + int64_t block_n = std::min(N - n, {{block_n}}); + if (block_m == {{block_m}} && block_n == {{block_n}}) { +{%- if not trans_b %} + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>( +{%- else %} + {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>( +{%- endif %} + A + m * lda, +{%- if not trans_b %} + B + n, +{%- else %} + B + n * ldb, +{%- endif %} + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); +{%- if tail_n %} + } else if (block_n == {{block_n}}){ +{%- else %} + } else { +{%- endif %} + switch (block_m) { +{%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {%- if not trans_b %} + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- else %} + {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- endif %} + A + m * lda, + {%- if not trans_b %} + B + n, + {%- else %} + B + n * ldb, + {%- endif %} + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + break; +{%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + } + +{%- if tail_n %} + } else { + switch (block_m) { + {%- for b in range(block_m, 0, -1) %} + case {{b}}: + {%- if not trans_b %} + {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- else %} + {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- endif %} + A + m * lda, + {%- if not trans_b %} + B + n, + {%- else %} + B + n * ldb, + {%- endif %} + C + m * ldc + n, + block_n, + K, + lda, + ldb, + ldc + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + } + } +{%- else %} + } +{%- endif %} + } + } +} +""" + + TEMPLATE_KERNEL = r""" + +template +{%- if not trans_b %} + {%- if tail_n %} +inline void {{kernel_name}}_ntail_kernel( + {%- else %} +inline void {{kernel_name}}_kernel( + {%- endif %} +{%- else %} + {%- if tail_n %} +inline void {{kernel_name}}_ntail_transpose_b_kernel( + {%- else %} +inline void {{kernel_name}}_transpose_b_kernel( + {%- endif %} +{%- endif %} + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, +{%- if tail_n %} + int64_t N, +{%- endif %} + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; +{%- if input2_dtype in [torch.bfloat16, torch.float16] %} + using VectorizedIn = at::vec::Vectorized<{{input_t}}>; +{%- endif %} + +{%- if not trans_b %} + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + + Vectorized va; + at::vec::VectorizedN<{{compute_t}}, COLS> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + + {%- if tail_n %} + int64_t rCOLS = (N + VLEN - 1) / VLEN; + int ntail = N % VLEN; + {%- endif %} + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + if (col < rCOLS) { + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size); + } + {%- else %} + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + {%- endif %} + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + {%- endif %} + if constexpr (col == 0) { + {%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); + {%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); + {%- endif %} + } + + if constexpr (row == 0) { + {%- if tail_n %} + if (col < rCOLS) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN, load_size); + vb[col] = at::vec::convert(b32); + {%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size); + {%- endif %} + } else { + vb[col] = Vectorized(0.0f); + } + + {%- else %} + + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN); + if constexpr (prefetch) { + _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0); + } + vb[col] = at::vec::convert(b32); + {%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); + {%- endif %} + {%- endif %} + + } + + constexpr int idx = row * COLS + col; + {%- if tail_n %} + if (col < rCOLS) { + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + } + {%- else %} + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + {%- endif %} + }; + + for (int k = 0; k < K; ++k) { + c10::ForcedUnroll{}(compute, k); + } + + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + if (col < rCOLS) { + vc[i].store(C + row * ldc + col * VLEN, store_size); + } + {%- else %} + vc[i].store(C + row * ldc + col * VLEN); + {%- endif %} + }; + c10::ForcedUnroll{}(storec); + +{%- else %} + // Use 2 implementations for the transposed B: + // First implementation: + // Transpose first and then perform outer product calculation in sub-blocks, + // which introduces an additional transpose overhead of [K, N] compared to the non-transpose version. + // Second implementation: + // Directly perform inner product calculation in sub-blocks, + // which introduces an additional vector reduction of [M, N] compared to the non-tranpose version. + // Therefore, when M * N / (K * N) is large, the first implementation has better performance. + {%- if tail_n %} + if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) { + {%- else %} + if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) { + {%- endif %} + // First implementation: + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + int _K = K / VLEN; + Vectorized va; + at::vec::VectorizedN<{{compute_t}}, VLEN> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN); + vb[i] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + auto b32 = at::vec::convert_to_int32(src_ptr + i * ldb, VLEN); + vb[i] = at::vec::convert(b32); + {%- else %} + vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN); + {%- endif %} + }; + auto compute_trans = [&, COLS](auto i, int k) { + constexpr int row = i % ROWS; + constexpr int col = i / ROWS; + constexpr int e_col = col * VLEN; + int idk = k * VLEN; + if constexpr (row == 0) { + c10::ForcedUnroll{}(unroll_loadB, B + e_col * ldb + idk); + at::vec::transpose_block(vb); + } + constexpr int idx = row * COLS + col; + {{kernel.unroll_pragma(16)}} + for (int j = 0; j < VLEN; j++) { + {%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}}); + {%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j])); + {%- endif %} + vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]); + } + }; + for (int k = 0; k < _K; ++k) { + c10::ForcedUnroll{}(compute_trans, k); + } + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i].store(C + row * ldc + col * VLEN); + }; + c10::ForcedUnroll{}(storec); + } else { + // Second implementation + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + constexpr auto VLEN = VectorizedIn::size(); + {%- else %} + constexpr auto VLEN = Vectorized::size(); + {%- endif %} + int _K = (K + VLEN - 1) / VLEN; + // sub-block size of BLOCK_N and BLOCK_M + constexpr int sM = {{sub_block_m}}; + constexpr int sN = {{sub_block_n}}; + {%- if tail_n %} + int bN = (N + sN - 1) / sN; + {%- else %} + constexpr int bN = (BLOCK_N + sN - 1) / sN; + {%- endif %} + constexpr int bM = (BLOCK_M + sM - 1) / sM; + + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + at::vec::VectorizedN<{{compute_t}}, 2> va; + at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb; + {%- else %} + at::vec::Vectorized<{{compute_t}}> va; + at::vec::VectorizedN<{{compute_t}}, sN> vb; + {%- endif %} + at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid; + + {%- if tail_n %} + int ntail = N % sN; + {%- else %} + constexpr int ntail = BLOCK_N % sN; + {%- endif %} + constexpr int mtail = BLOCK_M % sM; + int ktail = K % VLEN; + + auto compute_trans = [&](int m, int n, int k) { + {%- if tail_n %} + int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN; + {%- else %} + int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN; + {%- endif %} + int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM; + int e_k = (k == _K - 1 && ktail != 0) ? (K - k * VLEN) : VLEN; + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k); + std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + auto b32 = at::vec::convert_to_int32(B + (sN * n + i) * ldb + k * VLEN, e_k); + vb[i] = at::vec::convert(b32); + {%- else %} + vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k); + {%- endif %} + } + + {{kernel.unroll_pragma(sub_block_m)}} + for (int s = 0; s < e_m; s++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k); + std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a); + {%- elif input2_dtype == torch.int8 %} + auto a32 = at::vec::convert_to_int32(A + (sM * m + s) * lda + k * VLEN, e_k); + va = at::vec::convert(a32); + {%- else %} + va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k); + {%- endif %} + + {%- if alpha != 1 %} + va = va * Vectorized({{alpha}}); + {%- endif %} + if (k == 0) { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f)); + vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]); + {%- else %} + vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f)); + {%- endif %} + } + } else { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]); + vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]); + {%- else %} + vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]); + {%- endif %} + } + } + } + + // store to C + if (k == _K - 1) { + {{kernel.unroll_pragma(sub_block_m)}} + for (int s = 0; s < e_m; s++) { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]); + if constexpr (accum) { + auto c = *(C + (sM * m + s) * ldc + sN * n + i); + *(C + (sM * m + s) * ldc + sN * n + i) = c + v; + } else { + *(C + (sM * m + s) * ldc + sN * n + i) = v; + } + } + } + } + }; + + for (int n = 0; n < bN; ++n) { + for (int m = 0; m < bM; ++m) { + for (int k = 0; k < _K; ++k) { + compute_trans(m, n, k); + } + } + } + } +{%- endif %} +} +""" + + # set trans_b to generate gemm that supports transposed B matrix + # set tail_n to support the tail of N + # TODO add trans_b support for other micro gemms + # and move setting of trans_b to the init of CppMicroGemm + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + tail_n=False, + trans_b=False, + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha, + ) + self.tail_n = tail_n + # trans_b is only supported on platforms that + # support avx512 or avx2 since transpose_block is + # only implemented on these platforms + if trans_b: + vec_isa = pick_vec_isa() + assert issubclass(vec_isa.__class__, VecAVX512) or issubclass( + vec_isa.__class__, VecAVX2 + ) + self.trans_b = trans_b + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "trans_b": False, + "tail_n": False, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + if self.trans_b: + # TODO supports tuning of sub_block_m/sub_block_n + # to get better performance for specific shapes + sub_block_m = min(1, self.register_blocking.block_m) + sub_block_n = min(4, self.register_blocking.block_n) + # update options to generate kernel with trans_b and sub-block size + options.update( + { + "trans_b": self.trans_b, + "sub_block_m": sub_block_m, + "sub_block_n": sub_block_n, + } + ) + result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + # update options to generate the kernel for the tail of N + if self.tail_n: + options.update( + { + "tail_n": self.tail_n, + } + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + +def check_vnni_extra(config, m, n, k, alpha, num_threads, **kwargs): + assert config.input_dtype == torch.uint8 and config.input2_dtype == torch.int8 + vnni_size = 4 + return k % vnni_size == 0 + + +@register_micro_gemm( + *generate_gemm_config( + VecAVX512VNNI, + # (block_m, block_n, block_k) + [(6, 64, 4)], + input_dtype=torch.uint8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_vnni_extra, + ), +) +class CppMicroGemmAVX512VNNI(CppMicroGemm): + """ + This class generates the code for micro gemm using AVX512 VNNI instructions for compute. + It supports u8s8s32 GEMM only. + AVX512_VNNI ISA has been available since the 3rd gen of Intel Xeon. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{vnni_size}} == 0, "K dimension must be multiple of {{vnni_size}}"); + constexpr int64_t M_BLOCK = {{block_m}}; + const int64_t M_TAIL = M % M_BLOCK; + const int64_t M_MAIN = M - M_TAIL; + for (int64_t m = 0; m < M_MAIN; m += M_BLOCK) { + for (int64_t n = 0; n < N; n += {{block_n}}) { + {{kernel_name}}_kernel( + A + m * lda, + B + n, + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + } + } + if (M_TAIL > 0) { + switch (M_TAIL) { +{%- for m_tail in range(block_m - 1, 0, -1) %} + case ({{m_tail}}): + for (int64_t n = 0; n < N; n += {{block_n}}) { + {{kernel_name}}_kernel<{{m_tail}}, {{block_n}}, accum>( + A + M_MAIN * lda, + B + n, + C + M_MAIN * ldc + n, + K, + lda, + ldb, + ldc + ); + } + break; +{%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported M_TAIL: {}", M_TAIL); + } // switch M_TAIL + } // if M_TAIL +} +""" + + TEMPLATE_KERNEL = r""" +template +inline void {{kernel_name}}_kernel( + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + constexpr const int COLS = N / {{vec_len}}; + __m512i va; + __m512i vb[COLS]; + __m512i vc[M * COLS]; + + c10::ForcedUnroll{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); }); + + auto compute = [&](auto i, int k) { + constexpr const int row = i / COLS; + constexpr const int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k)); + } + + if constexpr (row == 0) { + // B block in VNNI layout: [K / {{vnni_size}}, N, {{vnni_size}}] + int64_t offset = k * ldb + col * {{vec_len}} * {{vnni_size}}; + vb[col] = _mm512_loadu_si512((__m512i const*)(B + offset)); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + + // Accumulate along k + constexpr const int k_unroll = 2; + int k = 0; + int k_limit = K / {{vnni_size}} / k_unroll; + for (; k < k_limit; k++) { + c10::ForcedUnroll{}( + [&](auto i) { + c10::ForcedUnroll{}(compute, {{vnni_size}} * (k * k_unroll + i)); + } + ); + } + k *= {{vnni_size}} * k_unroll; + for (; k < K; k += {{vnni_size}}) { + c10::ForcedUnroll{}(compute, k); + } + + // Store to C + auto store_c = [&](auto i) { + constexpr const int row = i / COLS; + constexpr const int col = i % COLS; + if constexpr (accum) { + __m512i vc_old = _mm512_loadu_si512((__m512i const*)(C + row * ldc + col * {{vec_len}})); + vc[i] = _mm512_add_epi32(vc[i], vc_old); + } + _mm512_storeu_si512((__m512i*)(C + row * ldc + col * {{vec_len}}), vc[i]); + }; + c10::ForcedUnroll{}(store_c); +} +""" + + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha, + ) + assert input_dtype == torch.uint8 and input2_dtype == torch.int8, ( + f"Only u8s8s32 GEMM is supported by AVX512VNNI microkernel, got A:{input_dtype}, B:{input2_dtype}, C:{output_dtype}." + ) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "restrict_keyword": get_restrict_keyword(), + "vec_len": 16, # = 512 / 32 for C + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + def get_b_layout(self): + return LayoutType.VNNI4 + + +# extra check for CppMicroGemmAMX +def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): + vnni_size = 4 if config.input_dtype in [torch.uint8, torch.int8] else 2 + return k % vnni_size == 0 and alpha == 1 + + +def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): + # We need avx512_bf16 to dequant int8 to bf16 + vec_isa = kwargs.get("vec_isa") + assert vec_isa is not None + return vec_isa.is_avx512_bf16_supported() and check_amx_extra( + config, m, n, k, alpha, num_threads, **kwargs + ) + + +# amx_fp16 need to be checked separately since it is not always supported when amx is supported +def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs): + assert config.input_dtype == torch.float16 and config.output_dtype == torch.float + vec_isa = kwargs.get("vec_isa") + assert vec_isa is not None + vnni_size = 2 + return vec_isa.is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 64), (48, 16, 64)], + input_dtype=torch.int8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_bf16_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 16, 32), (32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.float16, + output_dtype=torch.float, + extra_check=check_amx_fp16_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 64), (48, 16, 64)], + input_dtype=torch.uint8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmAMX(CppMicroGemm): + """ + This class generates the code for micro gemm using Advanced Matrix extension (AMX) + instructions available in 4th generation Intel Xeon for compute. + It supports input types of torch.bfloat16 with fp32 output. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); +{%- if pack_vnni_B_locally %} + {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}} +{%- endif %} +{%- if use_cached_dequantized_B %} + // Create a stack-allocated buffer for tiles of B. + // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. + // we cache K * {{block_n}} elements of dequantized B + {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} + const auto buf_size = K * {{block_n}}; + auto load_dequantized_B = [&](int base_idx) { + // Load a tile of B & cache it in L1D. + {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx; + for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) { + {%- for vec_idx in range(0, block_n, 32) %} + _mm_prefetch(base_addr + idx_q + 64 * ldb, _MM_HINT_T0); + {%- if (block_n - vec_idx) >= 32 %} + // 1) Load 32 x int8 + __m256i v8 = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}})); + // 2) Extract two halves + __m128i v8_lo = _mm256_extracti128_si256(v8, 0); + __m128i v8_hi = _mm256_extracti128_si256(v8, 1); + // 3) Widen each half to i32 + __m512i v32_lo = _mm512_cvtepi8_epi32(v8_lo); + __m512i v32_hi = _mm512_cvtepi8_epi32(v8_hi); + // 4) Convert to f32 + __m512 f_lo = _mm512_cvtepi32_ps(v32_lo); + __m512 f_hi = _mm512_cvtepi32_ps(v32_hi); + // 5) f32 -> bf16 (round-to-nearest-even) and pack 32 lanes to 512b + // Packs the second operand (f_lo) into the lower 16 bf16 lanes and the first (f_hi) into the upper 16. + __m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo); + // 6) Store 32 x bf16 (512 bits) + _mm512_storeu_si512((__m512i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf); + {%- elif (block_n - vec_idx) >= 16 %} + // 1) Load 16 x int8 (128 bits) + __m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}})); + // 2) Widen: 16 x i8 -> 16 x i32 + __m512i v32 = _mm512_cvtepi8_epi32(v8); + // 3) Convert to f32 + __m512 f32 = _mm512_cvtepi32_ps(v32); + // 4) Convert f32 -> bf16 (round-to-nearest-even) + __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32); + // 5) Store 16 x bf16 (256 bits) + _mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16); + {%- else %} + auto b_int8_tail = at::vec::Vectorized::loadu( + base_addr + idx_q + {{block_n - (block_n % 32)}}, + static_cast({{block_n % 32}}) + ); + auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail); + b_bf16_tail.store( + dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}}, + static_cast({{block_n % 32}}) + ); + {%- endif %} + {%- endfor %} + } + }; +{%- endif %} +// The ldb would not be block_n if N != block_n +{%- if use_cached_dequantized_B or pack_vnni_B_locally %} + const int64_t updated_ldb = {{block_n}}; +{%- else %} + const int64_t updated_ldb = ldb; +{%- endif %} + // TODO(jgong5): loop unroll for M and N + for (int64_t n = 0; n < N; n += {{block_n}}) { +{%- if pack_vnni_B_locally %} + // Pack non-constant weights into VNNI interleaved format in packed_B_buf + at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}}); +{%- elif use_cached_dequantized_B %} + // Dequantize K * block_n int8 B elements into BF16 + load_dequantized_B(n); +{%- endif %} + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; +{%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, +{%- if use_cached_dequantized_B %} + dequantized_B_buf, +{%- elif pack_vnni_B_locally %} + packed_B_buf, +{%- else %} + B + n, +{%- endif %} + C + m * ldc + n, + K, + lda, + updated_ldb, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } +{%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, +{%- if use_cached_dequantized_B %} + dequantized_B_buf, +{%- elif pack_vnni_B_locally %} + packed_B_buf, +{%- else %} + B + n, +{%- endif %} + C + m_tail * ldc + n, + K, + lda, + updated_ldb, + ldc, + block_m + ); + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" + +template +inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + AMXState& amx_state, + const {{input_t}}* {{restrict_keyword}} A, +{%- if use_cached_dequantized_B %} + const {{input_t}}* {{restrict_keyword}} B, +{%- else %} + const {{input2_t}}* {{restrict_keyword}} B, +{%- endif %} + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + uint8_t tilecfg_rows +) { + // TODO(jgong5): add prefetch hint for A, B, C + auto loadconfig = [](const amx_tilecfg& cfg) { + _tile_loadconfig(&cfg); + }; + const auto last_k_offset = K / {{block_k}} * {{block_k}}; + const auto tail_k_size = K - last_k_offset; + if C10_LIKELY (last_k_offset > 0) { + amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig); + } else { + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + } + auto load_c = [&]() { +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + auto zero_c = [&]() { +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_zero({{tile_idx}}); + {%- endfor %} +{%- endfor %} + }; + + if constexpr (accum) { + load_c(); + } else { + zero_c(); + } + + auto compute = [&](int k) { +{%- set tile_offset_a = num_rows // 16 * num_columns %} +{%- set tile_offset_b = tile_offset_a + num_rows // 16 %} +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx_a = tile_offset_a + tile_row %} + {%- set tile_idx_b = tile_offset_b + tile_col %} + {%- set tile_idx_c = tile_row * num_columns + tile_col %} + {%- if tile_col == 0 %} + _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}})); + {%- endif %} + {%- if tile_row == 0 %} + _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}})); + {%- endif %} + {%- if int8_gemm %} + {%- if input_dtype == torch.int8 %} + _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- else %} + _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- else %} + {%- if input_dtype == torch.float16 %} + _tile_dpfp16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- else %} + _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- endif %} + {%- endfor %} +{%- endfor %} + }; + + {{kernel.unroll_pragma(4)}} + for (int k = 0; k < last_k_offset; k += {{block_k}}) { + compute(k); + } + + auto store_c = [&]() { + // store to C +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + + // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead + if C10_UNLIKELY (tail_k_size > 0) { + if C10_LIKELY (last_k_offset > 0) { + store_c(); + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + load_c(); + } + compute(last_k_offset); + } + + store_c(); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + block_m, block_n, block_k = self.register_blocking + assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX" + assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX" + if self.input_dtype in [torch.uint8, torch.int8]: + assert block_k == 64, "Only support block_k = 64 for AMX INT8" + else: + assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16" + num_columns = block_n // 16 + options = { + "declare_kernel": self.get_kernel_declaration(), + "use_cached_dequantized_B": ( + self.input_dtype == torch.bfloat16 + and self.input2_dtype in [torch.int8, torch.uint8] + ), + "kernel": kernel, + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_columns": num_columns, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = "" + for num_rows in range(block_m, 0, -16): + amx_kernel_options = {**options, "num_rows": num_rows} + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + amx_kernel_options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "AMXState amx_state;" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "amx_state.release([]() { _tile_release(); });" + + def get_kernel_extra_args_declare(self) -> str: + return "AMXState& amx_state," + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + return ["amx_state,"] + + def get_b_layout(self): + if self.input_dtype in [torch.uint8, torch.int8]: + return LayoutType.VNNI4 + else: + return LayoutType.VNNI2 + + +# extra check for CppMicroBrgemm +def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs): + assert config.input_dtype == torch.half and config.output_dtype == torch.float + vnni_size = 2 + # use brgemm for Half when amx_fp16 is supported + return torch.cpu._is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.half, + output_dtype=torch.float, + extra_check=check_brgemm_extra, + ), +) +class CppMicroBrgemm(CppMicroGemm): + """ + This class generates the code for micro gemm using oneDNN brgemm. + It supports input types of torch.half. + """ + + TEMPLATE_ENTRY = r""" +#include +{{declare_kernel}} { +{%- if pack_vnni_B_locally %} + {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}} + at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N); +{%- endif %} + at::native::cpublas::brgemm( + M, N, K, + {%- if pack_vnni_B_locally %} + lda, N, ldc, + {%- else %} + lda, ldb, ldc, + {%- endif %} + accum, + A, + {%- if pack_vnni_B_locally %} + packed_B_buf, + {%- else %} + B, + {%- endif %} + C); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = "" + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "at::native::cpublas::brgemm_release();" + + def get_b_layout(self): + assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported() + return LayoutType.VNNI2 + + +def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs): + if alpha != 1: + return False + q_group_size = kwargs.get("q_group_size") + assert q_group_size is not None + if ( + q_group_size not in [32, 64, 128] + or k % q_group_size != 0 + or config.register_blocking.block_k > q_group_size + ): + return False + return k % config.register_blocking.block_k == 0 and n % 64 == 0 + + +@register_micro_gemm( + # TODO: support float/half input + *generate_gemm_config( + VecAVX512, + [(4, 64, 32), (4, 64, 64), (4, 64, 128)], + input_dtype=torch.bfloat16, + input2_dtype=torch.uint8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_woq_int4_extra, + ), +) +class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): + """ + This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics. + It is based on the corresponding ATen kernel. + Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2] + Shape of packed ScalesAndZeros = [K // group_size, N, 2] + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + auto group_size = q_group_size; + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + if (block_m == {{block_m}}) { + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( + A + m * lda, + reinterpret_cast(B) + n * ldb, + C + m * ldc + n, + K, + lda, + /* ldb */ {{block_n}} / 2, + ldc, + group_size, + ScaleAndZeros + n * 2, + lds, + k_start + ); + } else { + switch (block_m) { + {%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( + A + m * lda, + reinterpret_cast(B) + n * ldb, + C + m * ldc + n, + K, + lda, + /* ldb */ {{block_n}} / 2, + ldc, + group_size, + ScaleAndZeros + n * 2, + lds, + k_start + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + } + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { + return (k_start + index) % group_size == 0; +} + +inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) { + __m128i tmp = _mm_loadu_si64((const __m128i*)data); + __m128i bytes = _mm_cvtepu8_epi16(tmp); + const __m128i lowMask = _mm_set1_epi8(0xF); + __m128i high = _mm_andnot_si128(lowMask, bytes); + __m128i low = _mm_and_si128(lowMask, bytes); + high = _mm_slli_epi16(high, 4); + bytes = _mm_or_si128(low, high); + return bytes; +} + +template +inline void {{kernel_name}}_kernel( + const {{input_t}}* {{restrict_keyword}} A, + const uint8_t* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t q_group_size, + const at::BFloat16* {{restrict_keyword}} ScaleAndZeros, + int64_t lds, // leading dimension of ScaleAndZeros + int64_t k_start) { + constexpr int BLOCK_K = {{block_k}}; + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int PREFETCH_SIZE_K = 16 * 4; + const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; + + // number of blocks on K + const int KB = K / BLOCK_K; + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 scale[COLS]; + __m512 zero[COLS]; + + // Lookup table to de-quantize int4 values to bf16. + // Values are dequantized as truly int4 [-8, 7] range; + // + // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + static const __m512 lut = _mm512_set_ps( + 7.0f, 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f, 0.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + -5.0f, -6.0f, -7.0f, -8.0f); + + // index for transpose + static const __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + static const __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + // load scale and zero point + auto load_scale_and_zeros = [&](int i, int _kb) { + // load 2x bfloat16 vector + __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + + // convert to 2x f32 vector + __m512 a, b; + at::vec::cvtbf16_fp32(t, a, b); + + // transpose scale_and_zero from {16, 2} to {2, 16} + // inputs: + // a: {s0, z0, s1, z1, ..., s7, z7} + // b: {s8, z8, s9, z9, ..., s15, z15} + // output: + // scale: {s0, s1, s2, ..., s15} + // zero: {z0, z1, z2, ..., z15} + scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + }; + + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + float aa = static_cast(A[row * lda + k]); + _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0); + va = _mm512_set1_ps(aa); + } + + if constexpr (row == 0) { + if constexpr (COLS == 4) { + // when BLOCK_N = 64, handle each row at a time + // to reduce de-quantize overhead. + if constexpr (col == 0) { + __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb)); + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0); + + __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4)); + vb[0] = _mm512_permutexvar_ps(b32, lut); + vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); + vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); + vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]); + + b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1)); + vb[1] = _mm512_permutexvar_ps(b32, lut); + vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); + vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); + vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]); + } + } else { + __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8); + __m512i b32 = _mm512_cvtepu8_epi32(b8); + vb[col] = _mm512_permutexvar_ps(b32, lut); + vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]); + } + } + + constexpr int idx = row * COLS + col; + vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); + }; + + for (int k = 0, kb = 0; k < K; ++k) { + if ({{kernel_name}}_is_block_start(k, k_start, q_group_size)) { + c10::ForcedUnroll{}(load_scale_and_zeros, kb++); + } + c10::ForcedUnroll{}(compute, k); + } + + //store to C + auto storec = [&, COLS](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]); + }; + c10::ForcedUnroll{}(storec); +} +""" + + def get_kernel_extra_args_declare(self) -> str: + return ( + "const int64_t q_group_size,\n" + " const at::BFloat16* __restrict__ ScaleAndZeros,\n" + " const int64_t lds,\n" + " int64_t k_start," + ) + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + assert "kernel" in kwargs + assert "qscale_and_zeros" in kwargs + kernel = kwargs["kernel"] + qscale_and_zeros = kwargs["qscale_and_zeros"] + return [ + "group_size,", + f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),", + "N * 2,", # lds + "k_start,", + ] + + def is_woq_int4(self): + return True + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [ # (block_m, block_n, block_k) + (16, 32, 32), + (32, 32, 32), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.uint8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): + """ + This class generates the code for WoQ int4 micro gemm using AMX intrinsics, + which are available on 4th and newer generations of Intel Xeon. + Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2] + Shape of packed ScalesAndZeros = [K // group_size, N, 2] + Reuse TEMPLATE_KERNEL of CppMicroGemmAMX. + """ + + TEMPLATE_ENTRY = r""" +inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { + // check if (k_start + index) % group_size == 0, assuming group_size = 32/64/128 + return ((k_start + index) & (group_size - 1)) == 0; +} + +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); + {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4"); + + // Create a stack-allocated buffer for tiles of B. + // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. + // we cache K * {{block_n}} elements of dequantized B + {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} + + constexpr int BLOCK_K = {{block_k}}; + constexpr int64_t BLOCK_N = {{block_n}}; + constexpr int COLS = BLOCK_N / 16; + const int PREFETCH_SIZE_K = 16 * 4; + const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; + const int KB = K / BLOCK_K; + + __m512i b32[COLS * 2]; + __m512 vb[COLS * 2]; + __m512 scale[COLS]; + __m512 zero[COLS]; + + // Lookup table to de-quantize int4 values to bf16. + // Values are dequantized as truly int4 [-8, 7] range; + // + // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + static const __m512 lut = _mm512_set_ps( + 7.0f, 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f, 0.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + -5.0f, -6.0f, -7.0f, -8.0f); + + // index for transpose + static const __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + static const __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + // Indices for VNNI layout conversion + __m512i idx_low = _mm512_set_epi32( + 0x17, + 0x07, + 0x16, + 0x06, + 0x15, + 0x05, + 0x14, + 0x04, + 0x13, + 0x03, + 0x12, + 0x02, + 0x11, + 0x01, + 0x10, + 0x00); + __m512i idx_high = _mm512_set_epi32( + 0x1f, + 0x0f, + 0x1e, + 0x0e, + 0x1d, + 0x0d, + 0x1c, + 0x0c, + 0x1b, + 0x0b, + 0x1a, + 0x0a, + 0x19, + 0x09, + 0x18, + 0x08); + + // load scale and zero point + auto load_scale_and_zeros = [&](int i, int _kb) { + // load 2x bfloat16 vector + __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + + // convert to 2x f32 vector + __m512 a, b; + at::vec::cvtbf16_fp32(t, a, b); + + // transpose scale_and_zero from {16, 2} to {2, 16} + // inputs: + // a: {s0, z0, s1, z1, ..., s7, z7} + // b: {s8, z8, s9, z9, ..., s15, z15} + // output: + // scale: {s0, s1, s2, ..., s15} + // zero: {z0, z1, z2, ..., z15} + scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + }; + + // Dequantize a B block of 2 * block_n into bf16 + // So, it handles k and k+1 at the same time + auto dequantize_B = [&](int n) { + constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16 + for (int k = 0, kb = 0; k < K; k += 2) { + // Since block_k must be 32 for AMX microkernels, k_start may not be + // a multiple of q_group_size. In that case, we need to load scales + // and zero points immediately when k == 0 here + if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) { + c10::ForcedUnroll{}(load_scale_and_zeros, kb++); + } + + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0); + + // load 256 bits = 64 elements in int4 + __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4)); + b32[0] = _mm512_cvtepu8_epi32(b4); + b32[1] = _mm512_srli_epi32(b32[0], 4); + vb[0] = _mm512_permutexvar_ps(b32[0] , lut); + vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); + vb[1] = _mm512_permutexvar_ps(b32[1], lut); + vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); + + __m128i b4_2 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4)); + b32[0 + COLS] = _mm512_cvtepu8_epi32(b4_2); + b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4); + vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut); + vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]); + vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut); + vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]); + + for (int i = 0; i < COLS; i++) { + // convert to VNNI + auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]); + auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]); + // convert lower 16 float32 values to bfloat16 + auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low)); + // convert higher 16 float32 values to bfloat16 + auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high)); + // combine the lower 16 and higher 16 bfloat16 values + auto v = _mm512_castsi256_si512(v0_bf16); + v = _mm512_inserti64x4(v, v1_bf16, 1); + // store the VNNI format bfloat16 values + {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32; + _mm512_storeu_si512(addr, v); + } + } + }; + + for (int64_t n = 0; n < N; n += {{block_n}}) { + // Dequantize K * block_n int8 B elements into BF16 + dequantize_B(n); + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; + {%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, + dequantized_B_buf + n * K, + C + m * ldc + n, + K, + lda, + {{block_n}}, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } + {%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, + dequantized_B_buf + n * K, + C + m_tail * ldc + n, + K, + lda, + {{block_n}}, + ldc, + block_m + ); + } + } // for m + } // for n +} +""" + + def get_kernel_extra_args_declare(self) -> str: + return ( + "AMXState& amx_state,\n" + " const int64_t q_group_size,\n" + " const c10::BFloat16* __restrict__ ScaleAndZeros,\n" + " const int64_t lds,\n" + " int64_t k_start," + ) + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + assert "kernel" in kwargs + assert "qscale_and_zeros" in kwargs + kernel = kwargs["kernel"] + qscale_and_zeros = kwargs["qscale_and_zeros"] + return [ + "amx_state,", + "group_size,", + f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),", + "N * 2,", # lds + "k_start,", + ] + + def is_woq_int4(self): + return True + + +def create_micro_gemm( + name, + m, + n, + k, + input_dtype, + input2_dtype, + output_dtype=None, + compute_dtype=None, + alpha=1, + num_threads=-1, + use_ref=True, + q_group_size=None, +) -> Optional[CppMicroGemm]: + """ + Based on the provided info, try to find the config of the micro-kernel that would + deliver the best performance in terms of lower latency for this case. + """ + + def create_from_config(cls, config: CppMicroGemmConfig): + return cls( + name, + config.input_dtype, + config.input2_dtype, + config.output_dtype, + config.compute_dtype, + config.register_blocking, + alpha, + ) + + def skip_amx_kernel_for_woq(dynamic_M): + # For WoQ GEMM, AMX micro-kernel may not perform well if m is small. + # Exception: for dynamic shapes, we consider using the AMX micro-kernel. + if ( + dynamic_M + or input_dtype != torch.bfloat16 + or input2_dtype not in [torch.int8, torch.uint8] + ): + return False + m_threshold = 5 + return m < m_threshold + + assert isinstance(n, int) or n.is_number, n + assert isinstance(k, int) or k.is_number, k + from ..utils import has_free_symbols + + dynamic_M = has_free_symbols((m,)) + m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m + assert isinstance(m, int) or m.is_number, m + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if num_threads < 0: + num_threads = parallel_num_threads() + vec_isa = pick_vec_isa() + matched_configs = [] + for cls, configs in micro_gemm_configs.items(): + for config in configs: + if not issubclass(vec_isa.__class__, config.vec_isa_cls): + continue + if ( + config.input_dtype == input_dtype + and config.compute_dtype == compute_dtype + and config.input2_dtype == input2_dtype + and config.output_dtype == output_dtype + # The output_dtype here is the output dtype of the micro-kernel. + # In some cases, the actual output dtype of the op for which the micro-kernel + # is being created would be same as that of the activation, but the micro-kernels + # compute output in Float/int32, which is converted in the GEMM template. This is + # subject to change in the future. + ): + if config.extra_check is not None and not config.extra_check( + config, + m, + n, + k, + alpha, + num_threads, + dynamic_M=dynamic_M, + q_group_size=q_group_size, + vec_isa=vec_isa, + ): + continue + block_m, block_n, block_k = config.register_blocking + if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(dynamic_M): + continue + # Criteria on the ranking of configurations + # 1. ISA: AMX > VNNI > VEC + # 2. Dividable by block sizes (block_m, block_n, block_k) + # 3. Number of mxn blocks is large enough to occupy all the threads + # 4. Register blocks are larger + isa_score = 0 + if config.vec_isa_cls == VecAMX: + isa_score += 2 + elif config.vec_isa_cls == VecAVX512VNNI: + isa_score += 1 + dividable_score = 0 + if m % block_m == 0: + dividable_score += 1 + if n % block_n == 0: + dividable_score += 1 + if k % block_k == 0: + dividable_score += 1 + occupancy_score = 0 + n_blocks = (n + block_n - 1) // block_n + total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m) + if n_blocks >= num_threads: + occupancy_score += 1 + if total_mxn_blocks >= num_threads: + occupancy_score += 1 + register_bytes = ( + block_m * block_n * config.compute_dtype.itemsize + + (block_m * block_k + block_k * block_n) + * config.input_dtype.itemsize + ) + size_score = register_bytes + # if number of mxn blocks can not occupy all the threads, + # we favor smaller register blocks. + if occupancy_score == 0: + size_score = 0 - register_bytes + matched_configs.append( + ( + (isa_score, dividable_score, occupancy_score, size_score), + cls, + config, + ) + ) + if len(matched_configs) == 0: + if use_ref: + return CppMicroGemmRef( + name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) + else: + return None + # TODO(jgong5): allow autotuning on choices of configs + return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000000000000000000000000000000..c01ca4363685deff18328b48026fa2d33f92e29f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import ctypes +import functools +import itertools +import logging +import sys +from collections.abc import Callable, Iterable +from typing import Optional, Union +from unittest.mock import patch + +import sympy + +from .. import config, ir +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel + + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: ir.Layout, + num_threads: int, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ) -> None: + super().__init__(name) + self.input_nodes = input_nodes + self.index = next(self.index_counter) + self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer( + name=f"buf_out{self.index}", layout=layout + ) + self.layout = layout + self.num_threads = num_threads + self.epilogue_creator = epilogue_creator + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + patch.object(ir.FlexibleLayout, "allow_indexing", True), + V.graph.set_current_device(self.layout.device), + CppTemplateKernel( + kernel_name=kernel_name, num_threads=self.num_threads + ) as kernel, + ): + code = kernel.render(self, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + expected_args = list( + unique(input_node.get_name() for input_node in self.input_nodes) + ) + if isinstance(self.output_node, Iterable): + expected_args.extend([node.get_name() for node in self.output_node]) + else: + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + # Cast the size hint from int to ctypes.c_ulonglong explicitly + # since in cpp kernel, we bind it to C long + extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args) + + kernel_hash_name = f"cpp_{self.name}_{self.index}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + # pyrefly: ignore [bad-argument-type] + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ir.CppTemplateBuffer, + flag_template_buffer_has_other_users: bool, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads + ) + render = functools.partial( + kernel.render, + self, + template_buffer_node=template_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + # pyrefly: ignore [index-error] + self.output_node[0].get_layout() + if isinstance(self.output_node, Iterable) + else self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.writeline("#include ") + # TODO: add c10::ForcedUnroll test to test_aoti_abi_check + res.splice("""#include """) + res.splice("""#include """) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + res.writelines(["#include "]) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..1434398eac8a7e095a6ebb7d9c17e81cde8db11e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,621 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union +from unittest.mock import patch + +import sympy +from sympy.parsing.sympy_parser import parse_expr + +import torch +from torch._inductor.utils import do_bench_using_profiling +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import SymT + +from .. import config, cpp_builder, ir, lowering as L +from ..autotune_process import CppBenchmarkRequest +from ..loop_body import LoopBody +from ..select_algorithm import PartialRender +from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix +from ..virtualized import V +from .common import REMOVED +from .cpp import CppKernel, CppKernelProxy, KernelGroup, ParallelDepth +from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext + + +def parse_expr_with_index_symbols(expr): + if isinstance(expr, sympy.Expr): + return expr + elif isinstance(expr, (list, tuple)): + return [parse_expr_with_index_symbols(e) for e in expr] + else: + expr = parse_expr(str(expr)) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) + + +def wrap_with_tensorbox(node) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: + return ( + ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) + ) + + +class CppTemplateKernel(CppKernel): + def __init__(self, kernel_name, num_threads): + super().__init__(None, num_threads) + self.kernel_name = kernel_name + self.render_hooks = {} + self.local_buffers = {} + + def render(self, template, **kwargs): + return PartialRender( + template.render(kernel=self, **kwargs), self.render_hooks + ).finalize_all() + + def def_kernel( + self, + inputs: dict[str, ir.Buffer], + outputs: dict[str, ir.Buffer], + aliases: Optional[dict[str, str]] = None, + function_name: str = "", + extra_sizevars: Optional[list[sympy.Expr]] = None, + placeholder: str = "", + ) -> str: + if len(function_name) == 0: + function_name = str(self.kernel_name) + for name, inp in inputs.items(): + if inp is not None: + self.args.input_buffers[inp.get_name()] = name + for name, out in outputs.items(): + self.args.output_buffers[out.get_name()] = name + if aliases is not None: + for alias, orig in aliases.items(): + if orig in self.args.input_buffers: + self.args.input_buffers[alias] = self.args.input_buffers[orig] + if orig in self.args.output_buffers: + self.args.output_buffers[alias] = self.args.output_buffers[orig] + + unique_sizevars = OrderedSet( + s + for input in inputs.values() + if input is not None + for sym in itertools.chain(input.get_size(), input.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + unique_sizevars.update( + s + for sym in extra_sizevars or [] + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + unique_sizevars.update( + s + for output in outputs.values() + for sym in itertools.chain(output.get_size(), output.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + sizevars = sorted(unique_sizevars, key=str) + for sizevar in sizevars: + self.args.sizevars[sizevar] = f"k{sizevar}" + + def hook(): + # remove all aliases before generate function definition + if aliases is not None: + for alias in aliases: + if alias in self.args.input_buffers: + raise AssertionError( + f"input_buffers cannot be removed: {alias}" + ) + if alias in self.args.output_buffers: + self.args.output_buffers[alias] = REMOVED + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {function_name}({', '.join(cpp_argdefs)})" + + assert placeholder not in self.render_hooks + self.render_hooks[placeholder] = hook + return placeholder + + def call_kernel(self, name: str, node: ir.CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types) + + def dtype(self, node: ir.Buffer) -> str: + return DTYPE_TO_CPP[node.get_dtype()] + + def acc_dtype(self, node: ir.Buffer) -> str: + if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_size()[dim])) + + def stride(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_stride()[dim])) + + def index(self, node: ir.Buffer, indices: list[Any]) -> str: + indexer = node.get_layout().as_fixed().make_indexer() + index = indexer(parse_expr_with_index_symbols(indices)) + index = self.rename_indexing(index) + outer_name = node.get_name() + inner_name = ( + outer_name + if outer_name in self.local_buffers + else self.args.input(node.get_name()) + ) + return f"{inner_name}[{cexpr_index(index)}]" + + def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView: + """ + Slice the given node with a list of ranges (start and end) corresponding to its dims. + The dim is not sliced if the corresponding range is empty. + """ + assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}" + sliced = wrap_with_tensorbox(node) + for dim, _range in enumerate(ranges): + if len(_range) == 0: + continue + assert len(_range) == 2 + start, end = parse_expr_with_index_symbols(_range) + sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced, ir.TensorBox) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def select(self, node, dim: int, idx: int) -> ir.ReinterpretView: + # We avoid using L.select here because we need clamp=False so the dim after slicing + # is 1 instead of a sympy expression of symbol - dim_size. + node = wrap_with_tensorbox(node) + idx = ir.View.handle_negative_index(idx, node.get_size()[dim]) + sliced = L.squeeze(L.slice_(node, dim, idx, idx + 1, clamp=False), dim) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def view(self, node, sizes: list[Any]) -> ir.IRNode: + node = wrap_with_tensorbox(node) + sizes = parse_expr_with_index_symbols(sizes) + return L.view(node, sizes).data # type: ignore[arg-type] + + def permute(self, node, dims): + node = wrap_with_tensorbox(node) + permuted = L.permute(node, dims).data + assert isinstance(permuted, ir.ReinterpretView) + return permuted + + def maybe_codegen_profile(self) -> str: + if config.cpp.enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + handle_str = ( + "torch::aot_inductor::RAIIAtenRecordFunctionHandle " + f'record_{prefix}{self.kernel_name}_("{prefix}{self.kernel_name}", nullptr);' + ) + return handle_str + else: + return "" + + def unroll_pragma(self, unroll): + if cpp_builder.is_gcc(): + return f"#pragma GCC unroll {unroll}" + else: + return f"#pragma unroll {unroll}" + + def define_buffer(self, name, sizes: list[Any], dtype=torch.float) -> str: + """Define kernel local buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + ) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" + + def define_stack_allocated_buffer( + self, name, sizes: list[Any], dtype=torch.float + ) -> str: + """Define stack-allocated buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + ) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};" + + def reinit_buffer_if_null(self, name): + """Reinit the previously defined local buffer if it is null""" + assert name in self.local_buffers + buf = self.local_buffers[name] + ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}" + + def release_buffer(self, name): + """Codegen the code to release the ownership of a local buffer to others""" + assert name in self.local_buffers + return f"_{name}.release()" + + def store_pointwise_nodes( + self, + dst: ir.Buffer, + nodes: list[ir.IRNode], + offsets: Optional[list[sympy.Expr]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + ) -> str: + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + if not offsets: + offsets = [sympy.S.Zero] * len(var_sizes[0]) + if not reindexers: + reindexers = [None] * len(nodes) + assert len(offsets) == len(var_sizes[0]) + output_index = dst.get_layout().make_indexer()([*var_ranges.keys()]) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_grouped_gemm_pointwise_nodes( + self, + dst: tuple[ir.Buffer], + nodes: list[ir.IRNode], + offsets: list[sympy.Expr], + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], + output_names: list[str], + ) -> str: + ref_dst = dst[0] + var_sizes = (tuple(ref_dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + assert offsets, "offsets should be set outside" + assert all(len(offset) == len(var_sizes[0]) for offset in offsets) + output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()]) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = output_names[i] + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets[i])] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_output( + self, + dst: ir.Buffer, + src: ir.Buffer, + orig_src: Optional[ir.Buffer] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + offsets: Optional[list[Any]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + ): + """ + Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match. + If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues + before stored to `dst`. The `epilogues_nodes` are all pointwise. + + Notes: + 1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute + and stores. In case `epilogue_nodes` are not provided, we do nothing. + 2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since + they come form the original Inductor IR, they might need to be adjusted before working with + `src` and `dst` as outlined below: + a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on. + In this case, the `offsets` could be provided to adjust the indices passed to + `epilogue_nodes` during codegen and the data ranges are also configured according to + the sizes of `src` and `dst`. + b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is + needed on the indices to `epilogue_nodes` to match the indexing of `dst`. + c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer + in `epilogue_nodes` with `src`. + """ + assert isinstance(dst, (ir.Buffer, ir.ReinterpretView)) + assert dst.get_size() == src.get_size(), f"{dst=}, {src=}" + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + if epilogue_nodes: + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + if orig_src.get_name() != src.get_name(): + scope.add_local_buffer( + src, + [ + orig_src, + ], + ) + epilogue_nodes = scope.localize_nodes(epilogue_nodes) + return self.store_pointwise_nodes( + # pyrefly: ignore [bad-argument-type] + dst, + epilogue_nodes, # type: ignore[arg-type] + offsets, + reindexers, + ) + else: + if dst.get_name() != src.get_name(): + # src is local + copy = L.copy(dst, src).data.data + with LocalBufferContext(self.args) as scope: + scope.add_local_buffer(src) + # pyrefly: ignore [bad-argument-type] + return self.store_pointwise_nodes(dst, [copy]) + else: + assert dst.layout == src.layout, f"{dst=}, {src=}" + return "" + + def store_outputs( + self, + dst: tuple[ir.Buffer], + src: tuple[ir.IRNode], + orig_src: Optional[tuple[ir.IRNode]] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + offsets: Optional[list[Any]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + multi_output_buffers: Optional[tuple[ir.MultiOutput, ...]] = None, + ): + assert isinstance(dst, Iterable) + assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst)) + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + gemm_num = len(src) + final_offsets = [] + output_names = [] + if epilogue_nodes: + if not reindexers: + reindexers = [None] * len(epilogue_nodes) + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + localize_epilogue_nodes = [] + all_read_names = [] + for epilogue in epilogue_nodes: + all_read_names.extend(list(epilogue.get_read_names())) + localize_epilogue_nodes.extend(scope.localize_nodes(epilogue_nodes)) + final_offsets.extend([offsets] * len(localize_epilogue_nodes)) + output_names.extend( + [node.get_name() for node in localize_epilogue_nodes] + ) + for gemm_idx in range(gemm_num): + if orig_src[gemm_idx].get_name() != src[gemm_idx].get_name(): + if orig_src[gemm_idx].get_name() in all_read_names or ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() + in all_read_names + ): + # If any of the Epilogue nodes use this GEMM output, let's localize the GEMM output + global_buffers = [orig_src[gemm_idx]] + if ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() + in all_read_names + and orig_src[gemm_idx].get_name() not in all_read_names + ): + # Epilogue might directly read the MultiOutput, Locallize MultiOutput to the local Buffer + # if this MultiOutput has not been stored by in-template epilogue + # otherwise, use the cse store cache if it will be stored before used + global_buffers.append(multi_output_buffers[gemm_idx]) + scope.add_local_buffer( + src[gemm_idx], + global_buffers, + ) + else: + scope.add_local_buffer(src[gemm_idx]) + localize_epilogue_nodes.extend( + [L.copy(dst[gemm_idx], src[gemm_idx]).data.data] + ) + reindexers.append(None) + output_names.append(dst[gemm_idx].get_name()) + final_offsets.append( + [sympy.S.Zero] * len(dst[gemm_idx].get_size()) + ) + res = self.store_grouped_gemm_pointwise_nodes( + dst, + localize_epilogue_nodes, + final_offsets, + reindexers, + output_names=output_names, + ) + for gemm_idx in range(gemm_num): + if ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() in all_read_names + ): + # If the MultiOutput is used in the Epilogue, let's remove it from args + multi_output_name = multi_output_buffers[gemm_idx].get_name() + if ( + multi_output_name in self.args.output_buffers + and self.args.output_buffers[multi_output_name] + is not REMOVED + ): + self.remove_buffer(multi_output_name) + return res + else: + if dst[0].get_name() != src[0].get_name(): + copy_list = [] + with LocalBufferContext(self.args) as scope: + for _src, _dst in zip(src, dst): + copy_list.extend([L.copy(_dst, _src).data.data]) + scope.add_local_buffer(_src) + output_names.append(_dst.get_name()) + final_offsets.append([sympy.S.Zero] * len(_dst.get_size())) + reindexers = [None] * len(copy_list) + return self.store_grouped_gemm_pointwise_nodes( + dst, + nodes=copy_list, + offsets=final_offsets, + reindexers=reindexers, + output_names=output_names, + ) + else: + assert all( + _src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst) + ) + assert all( + _src.get_layout() == _dst.get_layout() + for _src, _dst in zip(src, dst) + ) + return "" + + def check_bounds(self, expr, size, lower, upper): + # CppTemplateKernel does not need codegen related operations + return + + +class CppTemplateCaller(ir.ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[ir.Buffer], + layout: ir.Layout, + make_kernel_render: Callable[ + [ + ir.CppTemplateBuffer, + bool, + Optional[list[ir.IRNode]], + ], + str, + ], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict( + self, + ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: + return ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2bcede213b8f90e66ae18e40e9e18a5e24652e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py @@ -0,0 +1,787 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import math +import sys +from collections import namedtuple +from collections.abc import Callable, Sequence +from typing import Any, Optional +from unittest.mock import patch + +import sympy + +import torch +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter as _CppPrinter +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ValueRanges + +from .. import ir +from ..dependencies import Dep +from ..loop_body import LoopBody +from ..scheduler import BaseSchedulerNode, SchedulerBuffer +from ..shape_propagation import BlockShapeType +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs +from ..virtualized import ops, OpsValue, V +from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext + + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "at::Half", + torch.int64: "int64_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint64: "uint64_t", + torch.uint32: "uint32_t", + torch.uint16: "uint16_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "at::BFloat16", + torch.complex32: "at::complex", + torch.complex64: "at::complex", + torch.complex128: "at::complex", + torch.float8_e4m3fn: "at::Float8_e4m3fn", + torch.float8_e5m2: "at::Float8_e5m2", + torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.uint32: "at::kUInt32", + torch.uint64: "at::kUInt64", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "meta": "at::kMeta", + "cpu": "at::kCPU", + "cuda": "at::kCUDA", + "xpu": "at::kXPU", + "mps": "at::kMPS", +} + +LAYOUT_TO_ATEN = { + torch.strided: "at::kStrided", + torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined] +} + +# matches c10/core/DeviceType.h +DEVICE_TO_INT = {"cpu": 0, "cuda": 1} + +_IS_WINDOWS = sys.platform == "win32" + +INDEX_TYPE = "int64_t" + +GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) + + +def get_promote_dtype(args): + return ( + functools.reduce( + torch.promote_types, # type: ignore[arg-type] + [n.dtype for n in args if isinstance(n, CppCSEVariable)], + ) + if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable)) + else None # not enough info to calculate the promote dtype + ) + + +def promote_args(new_args): + def promote_arg(arg, promote_type): + if ( + isinstance(arg, CppCSEVariable) + and arg.dtype + and promote_type + and arg.dtype != promote_type + ): + arg = ops.to_dtype(arg, promote_type) + arg = arg.value if isinstance(arg, OpsValue) else arg + arg.dtype = promote_type + return arg + + promote_type = get_promote_dtype(new_args) + promote_fn = functools.partial( + promote_arg, + promote_type=promote_type, + ) + if ( + all( + new_arg.dtype is not None + for new_arg in new_args + if isinstance(new_arg, CppCSEVariable) + ) + and promote_type + ): + new_args = list(map(promote_fn, new_args)) + return new_args + + +class CppCSEVariable(CSEVariable): + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + shape: BlockShapeType = None, + ) -> None: + super().__init__(name, bounds, dtype, shape=shape) + self.is_vec = False + self.dependent_itervars = OrderedSet[sympy.Symbol]() + + def __repr__(self) -> str: + return ( + f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, " + f"dependent_itervars: {self.dependent_itervars})" + ) + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[2] is index + self._set_dependent_itervars(args[2]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. + """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + +class CppPrinter(_CppPrinter): + def doprint(self, expr, *, simplify: bool = True, p=True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> str: + if isinstance(item, sympy.Mod): + # use parenthesis to enforce precedence. + # in sympy 1.13.3, -2*Mod(x,y) becomes -2*x%y, which is wrong. + return f"({self._print(item)})" + else: + return super().parenthesize(item, level, strict) + + +# A function to print, useful for printing sympy symbols. +cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + +def rewrite_index_for_function( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + # Local buffer at the inner dimensions + snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op + assert snode is not None + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + scheduler_nodes = snode.get_nodes() + _, (group, reduction_group) = max( + scheduler_nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + indices_to_keep = [ + f"x{len(call_ranges) - (idx + 1)}" + for idx in range(len(local_buf.get_layout().size)) + ] + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined] + replacements = {} + for x in sorted_symbols: + if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined] + # Only keep index used by local buffer + replacements[x] = sympy.core.numbers.Zero() + index = sympy_subs(index, replacements) # type: ignore[arg-type] + return index + + +def rewrite_index_for_nodes( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + used_vars = OrderedSet( + s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX) + ) + index_vars = [] + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + for i in range(len(local_buf.get_size())): + var = sympy_index_symbol_with_prefix(SymT.INDEX, i) + index_vars.append(var if var in used_vars else 0) + index = local_buf.get_layout().make_indexer()(index_vars) + return index + + +class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] + def __init__( + self, + inner, + global_to_local: dict[str, ir.Buffer], + rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr], + ) -> None: + super().__init__(inner) + self.global_to_local = global_to_local + self.rewrite_index = rewrite_index + + def localize(self, name: str, index: sympy.Expr): + if self.global_to_local and name in self.global_to_local: + assert self.rewrite_index is not None + index = self.rewrite_index(self, index, name) + name = self.global_to_local[name].get_name() + return name, index + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(*self.localize(name, index)) + + def store(self, name, index, value, mode=None): + local_buffer_name, local_buffer_index = self.localize(name, index) + res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) + if ( + self.global_to_local + and name in self.global_to_local + and isinstance(V.kernel, Kernel) + ): + # Remove name of local buffer from Kernel.store_buffer_names + # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. + V.kernel.store_buffer_names.discard(local_buffer_name) + return res + + def store_reduction(self, name, index, value): + # pyrefly: ignore [bad-argument-count] + return self._inner.store_reduction(*self.localize(name, index), value) + + +class LocalBufferContext: + """ + This class creates a context that helps to generate code involving Inductor IR with + function local buffers. These buffers are constructed during the codegen process and + are used to store intermediate results such as local accumulators. We do not want to + add them to `V.graph` since they are not global and we do not want to add them as + function arguments either. So we patch the codegen processes under this scope to support + these buffers without exposure to the outside world. + """ + + def __init__(self, kernel_args: KernelArgs) -> None: + self.kernel_args = kernel_args + self.exit_stack = contextlib.ExitStack() + # map local buffer name to local buffer + self.local_buffers: dict[str, ir.Buffer] = {} + # map global buffer name to global buffer + self.global_buffers: dict[str, ir.Buffer] = {} + # map global buffer name to local buffer + self.global_to_local: dict[str, ir.Buffer] = {} + # record the global buffers that are removed by this LocalBufferContext + self.removed_buffers: OrderedSet[str] = OrderedSet() + + def __enter__(self): + self.exit_stack.__enter__() + original_get_dtype = V.graph.get_dtype + + def get_dtype(name): + if name in self.local_buffers: + return self.local_buffers[name].get_dtype() + return original_get_dtype(name) + + self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) + + original_input = self.kernel_args.input + + def input(name): + if name in self.local_buffers: + return name + return original_input(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) + + original_output = self.kernel_args.output + + def output(name): + if name in self.local_buffers: + return name + return original_output(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) + + # Set current LocalBufferContext into V + self.exit_stack.enter_context(V.set_local_buffer_context(self)) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.local_buffers.clear() + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def add_local_buffer( + self, local_buffer: ir.Buffer, global_buffers: Optional[list[ir.Buffer]] = None + ): + assert local_buffer.get_name() not in self.local_buffers + self.local_buffers[local_buffer.get_name()] = local_buffer + if global_buffers: + for global_buffer in global_buffers: + global_buffer_name = global_buffer.get_name() + assert ( + global_buffer_name not in self.global_buffers + and global_buffer_name not in self.global_to_local + ) + self.global_buffers[global_buffer_name] = global_buffer + self.global_to_local[global_buffer_name] = local_buffer + if global_buffer_name not in V.graph.removed_buffers: + # Record the global buffers that are removed by this LocalBufferContext + # since which may need to restore. Refer to issue: + # https://github.com/pytorch/pytorch/issues/144186 + self.removed_buffers.add(global_buffer_name) + V.graph.removed_buffers.add(global_buffer_name) + + def localize_function( + self, + fn: Callable[..., Any], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_function, + ): + def inner(*args, **kwargs): + with V.set_ops_handler( + LocalizeBufferHandler( + V.get_ops_handler(), + global_to_local=self.global_to_local, + rewrite_index=rewrite_index, + ) + ): + return fn(*args, **kwargs) + + return inner + + def localize_nodes( + self, + nodes: list[ir.IRNode], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_nodes, + ) -> list[ir.IRNode]: + """ + Given `local_buf` and `global_buf` registered in current `LocalBufferContext` + though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf` + for the given `nodes` and returns a new list of IR nodes that work on `local_buf` + instead of `global_buf`, i.e., all the loads and stores are redirected to + `local_buf`. This helps the fused loops to work on smaller-sized local buffers + for better data locality. + + The data access of `local_buf` is assumed to be contiguous with the + same order as the `global_buf`. + """ + assert len(nodes) > 0 + + def wrap_inner_fn_for_node(node: ir.IRNode): + loops = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(loops, ir.Loops) + new_inner_fn = self.localize_function( + loops.inner_fn, + rewrite_index, + ) + + new_loops = dataclasses.replace(loops, inner_fn=new_inner_fn) + if isinstance(node, ir.ComputedBuffer): + new_node = ir.ComputedBuffer( + name=node.get_name(), layout=node.get_layout(), data=new_loops + ) + else: + new_node = new_loops # type: ignore[assignment] + + return new_node + + return [wrap_inner_fn_for_node(node) for node in nodes] + + +def unify_mask_base_type( + buffer: IndentedBuffer, + vars: tuple[CSEVariable, ...], + dtype=torch.float, +): + """ + Given list of cse variables, + Cast each to new mask base dtype and return casted cse variable. + """ + new_vars = ( + V.kernel.cse.generate( + buffer, + f"{V.kernel._get_mask_cast(var, dtype)}", + ) + for var in vars + ) + return new_vars + + +def may_unify_binary_op_mask_type(a, b): + """ + Given two cse variables, when dtype is bool, unify them to the same mask dtype and return casted cse variable. + """ + if a.dtype == torch.bool: + assert b.dtype == torch.bool + mask_dtype = torch.int32 + return unify_mask_base_type(V.kernel.compute, (a, b), mask_dtype) + return a, b + + +def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32): + assert is_integer_dtype(offset.dtype) + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];" + ) + code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];") + code.writeline(f"{offset}.store(offset);") + code.writeline( + f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )" + ) + with code.indent(): + code.writeline(rand_function) + num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype) + if num_vectors == 1: + code.writeline( + f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);" + ) + else: + code.writeline( + f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);" + ) + code.writeline("()") + return code + + +def get_gemm_template_output_and_compute_dtype(input_dtype): + if input_dtype in [torch.uint8, torch.int8]: + return (torch.int32, torch.int32) + else: + return (torch.float32, torch.float32) + + +def create_epilogue_with_attr(input_buffer, attr, **kwargs): + input_loader = input_buffer.make_loader() + dtype = input_buffer.get_dtype() + if attr == "relu": + + def inner_fn(index): + input = input_loader(index) + zero = ops.constant(0, dtype) + return ops.maximum(input, zero) + + elif attr == "gelu": + assert "algorithm" in kwargs + if kwargs["algorithm"] == "none": + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) + const = ops.constant(0.7071067811865476, torch.float) + result = input * half * (ops.erf(input * const) + one) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + else: + assert kwargs["algorithm"] == "tanh" + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) + const1 = ops.constant(0.7978845608028654, torch.float) + const2 = ops.constant(0.044715, torch.float) + result = ( + half + * input + * ( + one + + ops.tanh(const1 * (input + const2 * input * input * input)) + ) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "swish": + + def inner_fn(index): + input = input_loader(index) + result = input * ops.sigmoid(input) + return result + + elif attr == "sigmoid": + + def inner_fn(index): + return ops.sigmoid(input_loader(index)) + + elif attr == "tanh": + + def inner_fn(index): + return ops.tanh(input_loader(index)) + + elif attr == "hardswish" or attr == "hardsigmoid": + + def hardsigmoid_float(input): + zero = ops.constant(0, torch.float) + six = ops.constant(6, torch.float) + three = ops.constant(3, torch.float) + one_over_six = ops.constant(0.16666666666666666, torch.float) + max = ops.maximum(input + three, zero) + min = ops.minimum(max, six) + return min * one_over_six + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = hardsigmoid_float(input) + if attr == "hardswish": + result = input * result + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "leaky_relu": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 1 + negative_slope = kwargs["scalars"][0] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + zero = ops.constant(0, torch.float) + result = ops.where( + input > zero, input, input * ops.constant(negative_slope, torch.float) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "hardtanh": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 2 + min_value = kwargs["scalars"][0] + max_value = kwargs["scalars"][1] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = ops.minimum( + ops.maximum(input, ops.constant(min_value, torch.float)), + ops.constant(max_value, torch.float), + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr in ["add", "sub", "mul"]: + assert "other" in kwargs + other = kwargs["other"] + num_input_dims = len(input_buffer.get_size()) + num_other_dims = len(other.get_size()) + dims_diff = num_input_dims - num_other_dims + other_loader = other.make_loader() + + def inner_fn(index): + op = getattr(ops, attr) + if dims_diff != 0: + return op(input_loader(index), other_loader(index[dims_diff:])) + else: + return op(input_loader(index), other_loader(index)) + + elif attr == "bias_add": + assert "other" in kwargs + assert "beta" in kwargs + assert "dtype" in kwargs + beta = kwargs["beta"] + other = kwargs["other"] + dtype = kwargs["dtype"] + bias_loader = other.make_loader() + + def inner_fn(index): + bias = bias_loader(index) + input = input_loader(index) + if beta != 1: + result = ops.constant(beta, torch.float) * bias + input + else: + result = bias + input + return result + + else: + raise ValueError(f"Unsupported epilogue attribute: {attr}") + return ir.Pointwise( + device=input_buffer.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + +def _get_loop_body(fn_list): + if all(isinstance(fn, LoopBody) for fn in fn_list): + loop_bodies = fn_list + else: + if hasattr(fn_list[0], "original_fn"): + # For the case of local buffer, we wrap the fn with localize_function + assert all(hasattr(fn, "original_fn") for fn in fn_list) + assert all( + isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list + ) + loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] + else: + assert all(isinstance(fn, functools.partial) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) + loop_bodies = [fn.args[0]._body for fn in fn_list] + assert loop_bodies is not None + return loop_bodies + + +def _get_dtype_from_loopbodies(loop_bodies): + dtypes = OrderedSet[torch.dtype]() + for loop_body in loop_bodies: + graphs = [loop_body.root_block.graph] + [ + body.graph for body in list(loop_body.subblocks.values()) + ] + for graph in graphs: + for node in graph.nodes: + if node.op != "call_method": + continue + dtypes.add(node.meta[OptimizationContext.key].dtype) + return dtypes + + +def template_fusion_with_epilogues_supported( + template: BaseSchedulerNode, epilogues: list[BaseSchedulerNode] +) -> tuple[bool, bool]: + def _get_indexes_of_template_buf_read( + epilogue_node: ir.Operation, template_buf_names: list[str] + ) -> list[sympy.Expr]: + return [ + read.index + for read in epilogue_node.get_reads() + if read.name in template_buf_names + ] + + def _check_supported_and_same_indexes( + index_of_template_buf_read: Sequence[sympy.Expr], + epilogue_writes: OrderedSet[Dep], + ) -> tuple[bool, bool]: + num_indexes = len(OrderedSet(index_of_template_buf_read)) + + if num_indexes > 1: + same_index = False + supported = False # Different read indexes not supported + elif num_indexes == 0: + same_index = True + supported = True # No reads, automatically supported + elif num_indexes == 1: + iotbr = index_of_template_buf_read[0] + same_index = all(write.index == iotbr for write in epilogue_writes) + # TODO: Add support of fusion when the read of template buffer and the write of epilogue output + # in the epilogue node don't have the same index and change supported to True + supported = same_index + else: + raise AssertionError("Should not reach here") + + return supported, same_index + + def _template_fusion_supported( + template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: list[ir.Operation] + ) -> tuple[bool, bool]: + template_buf_names = [x.get_name() for x in template_outputs] + indexes_of_template_buf_reads = [ + _get_indexes_of_template_buf_read(epilogue_node, template_buf_names) + for epilogue_node in epilogue_nodes + ] + epilogue_nodes_writes = [ + epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes + ] + + results = [ + _check_supported_and_same_indexes(reads, writes) + for reads, writes in zip( + indexes_of_template_buf_reads, epilogue_nodes_writes + ) + ] + supported, same_indexes = zip(*results) + return all(supported), all(same_indexes) + + assert template.is_template() + template_outputs = template.get_outputs() + + epilogue_nodes = [ + n.node + for epilogue in epilogues + for n in epilogue.get_nodes() + if n.node is not None + ] + return _template_fusion_supported(template_outputs, epilogue_nodes) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..16522d9832ec0e9e8ce7686fe5537e3c4a647410 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -0,0 +1,3010 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import ctypes +import functools +import math +import os +import sys +import textwrap +from itertools import chain, count +from typing import Any, Optional, Protocol, TYPE_CHECKING, Union + +import sympy + +import torch +import torch._higher_order_ops.torchbind +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. import config, cpp_builder, ir +from ..ir import ExternKernel +from ..utils import _align, DeferredLineBase, LineContext, normalize_name +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import get_device_op_overrides, IndentedBuffer, Kernel +from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP +from .wrapper import ( + codegen_reinterpret_view_helper, + EnterSubgraphLine, + ExitSubgraphLine, + PythonWrapperCodegen, + SymbolicCallArg, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ..graph import GraphLowering + + # At most, the list nesting can go one layer deep. + _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]] + + from ..scheduler import BaseSchedulerNode + + +class HasWriteLine(Protocol): + def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ... + + +class CppWrapperCpu(PythonWrapperCodegen): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + """ + + def __init__(self): + if not hasattr(self, "device"): + self.device = "cpu" + # must be initialized prior to calling super().__init__() + self.included_devices: OrderedSet[str] = OrderedSet() + self.model_class_name_suffix = ( + "" + if config.aot_inductor.dynamic_linkage + else config.aot_inductor.model_name_for_generated_files + ) + self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}" + + super().__init__() + + self.declare = "auto " + self.declare_maybe_reference = "decltype(auto) " + self.ending = ";" + self.comment = "//" + self.none_str = "nullptr" + self.supports_intermediate_hooks = False + self.kernel_callsite_id = count() + self.int_array_id = count() # for int array local variable declarations + self.declared_int_array_vars: OrderedSet[str] = OrderedSet() + self.tmp_tensor_id = count() # for tmp tensor local variable declarations + self.arg_var_id = count() + self.used_cached_devices: OrderedSet[str] = OrderedSet() + self.used_cached_dtypes: OrderedSet[str] = OrderedSet() + self.used_cached_layouts: OrderedSet[str] = OrderedSet() + self.used_cached_memory_formats: OrderedSet[str] = OrderedSet() + self.used_cond_predicate: OrderedSet[str] = OrderedSet() + self.cached_output_id = count() + self.scalar_to_tensor_id = count() + self.custom_op_wrapper_loaded = False + # For GEMM kernels that must be initialized and are resolved at linking. + self.initialized_kernels: dict[str, Kernel] = {} + self.device_codegen = get_device_op_overrides(self.device) + # only need to include each header once + self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] + self._include_extra_header + ) + self.codegen_int_array_var_cache = {} + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpu() + + @staticmethod + def _generate_temporary_array_pointer( + c_type: str, elements: Sequence[str], *, force_mutable: bool = False + ) -> str: + """Get a pointer to an array that only exists for the duration of the C++ + statement it's used in.""" + # If the c_type is already a pointer, return a mutable pointer to the array. + # Otherwise, return a const pointer. In the C-shim API, pointer types are only + # const-qualified with respect to the underlying value, not any nested pointers. + # e.g. const double** is possible, but not const double* const*. This means + # that an array containing pointers must _already_ be properly const-qualified + # by the c_type, and not add additional const-ness. + # MSVC does not support implicitly converting a const iterator to a const pointer. + ptr_call = ( + "data()" + if force_mutable or c_type.endswith("*") or cpp_builder.is_msvc_cl() + else "cbegin()" + ) + return ( + f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" + ) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call:\n" + f"call_args: {call_args}\n" + f"arg_types: {arg_types}" + ) + new_args = [] + for idx, arg in enumerate(call_args): + if isinstance(arg_types[idx], str) and "*" in arg_types[idx]: + new_args.append(f"({arg_types[idx]})({arg}.data_ptr())") + else: + # arg is a scalar - ensure it's a string for C++ codegen + # With Triton support, arg might be a SymPy expression or other type + new_args.append(str(arg) if not isinstance(arg, str) else arg) + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def write_constant(self, name, hashed): + # include a hash so our code cache gives different constants different files + self.header.writeline(f"// {name} {hashed}") + + @staticmethod + def get_device_include_path(device: str) -> str: + if V.graph.aot_mode: + return f"#include " + return f"#include " + + def add_device_include(self, device: str) -> None: + if device in self.included_devices: + return + + self.included_devices.add(device) + + # Add the default header for this device, plus any C-shim extensions that are + # present. + self.header.splice(self.get_device_include_path(device)) + extend_aoti_c_shim_include = ( + f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h" + ) + extend_aoti_c_shim_path = os.path.join( + os.path.dirname(torch.__file__), + "include", + extend_aoti_c_shim_include, + ) + if os.path.exists(extend_aoti_c_shim_path): + self.header.splice(f"#include <{extend_aoti_c_shim_include}>") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. + return + + if not V.graph.aot_mode: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + r''' + """ + ) + + for device in V.graph.device_types: + if device != "meta": + self.add_device_include(device) + + if V.graph.aot_mode: + if config.aot_inductor.dynamic_linkage: + with open( + os.path.join( + os.path.dirname(__file__), "aoti_runtime", "interface.cpp" + ) + ) as f: + self.header.splice(f.read()) + else: + # we produce a separate model header for each model in static linkage + self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""") + self.header.splice("\n") + + if config.cpp.enable_kernel_profile: + self.header.splice( + "#include " + ) + self.header.splice( + """ + namespace torch::aot_inductor { + thread_local KernelContext* tls_kernel_context = nullptr; + } + """ + ) + + def _include_extra_header(self, header: str): + # This is needed for cpp to python dtype conversion + self.header.splice(f"#include <{header}>") + + def mark_output_type(self): + # mark output type to unwrap tensor back to python scalar + from ..ir import ShapeAsConstantBuffer + + output_is_tensor = {} + for idx, x in enumerate(V.graph.graph_outputs): + if isinstance(x, ShapeAsConstantBuffer): + output_is_tensor[idx] = False + else: + output_is_tensor[idx] = True + + self.output_is_tensor = output_is_tensor + + def write_prefix(self): + if V.graph.is_const_graph: + # We do not write prefix for constant graph, it will be written by main module. + return + if config.aot_inductor.custom_ops_to_c_shims: + # custom_ops_to_c_shims contains declaration of custom ops with C shim. + # TODO: this could be auto-generated from a passed-in custom op schema + custom_c_shims = list( + chain(*config.aot_inductor.custom_ops_to_c_shims.values()) + ) + declarations = "\n".join( + [f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims] + ) + self.prefix.splice( + f""" + extern "C" {{ + {declarations} + }} + """ + ) + if V.graph.aot_mode: + self.prefix.writeline("namespace torch::aot_inductor {") + + def write_input_output_info( + self, + info_kind: str, + idx: int, + name: str, + ): + self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") + + def codegen_input_symbol_assignment( + self, + name: str, + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + code = self.prefix + + @functools.cache + def sizeof(name): + self.codegen_input_size_var_decl(code, name) + return f"{name}_size" + + @functools.cache + def strideof(name): + self.codegen_input_stride_var_decl(code, name) + return f"{name}_stride" + + def codegen_symbol( + sym_or_exp: Union[sympy.Symbol, sympy.Expr], + base_name: str, + name_fn: Callable[[str], str], + dim: int, + ): + if isinstance(sym_or_exp, sympy.Symbol): + if sym_or_exp in bound_vars: + return + code.writeline(f"int64_t {sym_or_exp} = {name_fn(base_name)}[{dim}];") + bound_vars.add(sym_or_exp) + elif isinstance(sym_or_exp, sympy.Expr): + undefined_symbols = [ + sym for sym in sym_or_exp.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) != 1: + # Skip if expression contains no symbols or if multiple + # symbols exists since we assume each base symbol is defined + # by other codegen_symbol calls. + return + + from torch.utils._sympy.solve import try_solve + + free_symbol = undefined_symbols.pop() + base_name = name_fn(base_name) + # Use a size symbol to solve the free symbol + size_symbol = sympy.Symbol(f"{base_name}_{dim}", integer=True) + code.writeline(f"int64_t {size_symbol} = {base_name}[{dim}];") + solution = try_solve(sympy.Eq(sym_or_exp, size_symbol), free_symbol) + if solution is not None: + code.writeline(f"int64_t {free_symbol} = {cexpr(solution[1])};") + bound_vars.add(free_symbol) + else: + raise AssertionError( + str(sympy.Eq(sym_or_exp, size_symbol)) + " is not solvable" + ) + + if isinstance(value, sympy.Expr): + if not isinstance(value, sympy.Symbol) or value in bound_vars: + return + if value.is_integer: + decl = "int64_t" + elif value.is_float: + decl = "double" + else: + raise AssertionError("Unexpected symbol type") + code.writeline(f"{decl} {value} = {name};") + bound_vars.add(value) + elif isinstance(value, ir.TensorBox): + for dim, size in enumerate(value.get_size()): + codegen_symbol(size, name, sizeof, dim) + for dim, stride in enumerate(value.get_stride()): + codegen_symbol(stride, name, strideof, dim) + elif isinstance(value, ir.TorchBindObject): + # torchbind objects are loaded in proxy executor + pass + else: + raise AssertionError(f"Unknown value type: {type(value)}") + + def generate_input_output_runtime_checks(self): + """ + In debug_compile mode, we generate checks to ensure the dtype/shape/stride/device of each + real input/output tensor match ones provided at compile time via sample + input/output. + """ + + def gen_check(handle_kind, idx, name, tensor): + # Wrap AtenTensorHandle with ConstantHandle for cleaner utility function access + self.prefix.writeline( + f"ConstantHandle {name} = ConstantHandle({handle_kind}[{idx}]);" + ) + self.codegen_tensor_dtype_var_decl(self.prefix, name) + expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype] + dtype_str = str(tensor.dtype).split(".")[-1] + self.prefix.splice( + f""" + int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}(); + if ({name}_expected_dtype != {name}_dtype) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dtype, " + << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), " + << "but got: " << {name}_dtype << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + self.codegen_input_size_var_decl(self.prefix, name) + for dim_idx, d in enumerate(tensor.get_size()): + if isinstance(d, (int, sympy.Integer)): + self.prefix.splice( + f""" + if ({d} != {name}_size[{dim_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, " + << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + else: + from torch.utils._sympy.value_ranges import bound_sympy + + sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range) + if config.aot_inductor.check_lowerbound and not math.isinf( + sym_range.lower + ): + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] < {sym_range.lower}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, " + << "expected it to be >= {sym_range.lower}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + if not math.isinf(sym_range.upper): + # Limit upper bound to max C long long value (2^63 - 1) + max_long_long = ctypes.c_longlong(2**63 - 1).value + upper_bound = min(sym_range.upper, max_long_long) + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] > {upper_bound}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, " + << "expected to be <= {upper_bound}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + self.codegen_input_stride_var_decl(self.prefix, name) + for stride_idx, s in enumerate(tensor.get_stride()): + if not isinstance(s, (int, sympy.Integer)): + continue + self.prefix.splice( + f""" + if ({s} != {name}_stride[{stride_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, " + << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # check input device type + if isinstance(tensor, ir.TensorBox): + tensor_device = tensor.get_device() + if tensor_device is not None: + expected_device_type = DEVICE_TO_INT.get(tensor_device.type) + if expected_device_type is not None: + self.codegen_input_device_type_var_decl(self.prefix, name) + device_type_str = str(tensor_device.type) + self.prefix.splice( + f""" + int32_t {name}_expected_device_type = {expected_device_type}; + if ({name}_expected_device_type != {name}_device_type) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched device type, " + << "expected: " << {name}_expected_device_type << "{expected_device_type}({device_type_str}), " + << "but got: " << {name}_device_type << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # Create a separate function for each input check to avoid "too big to optimize" error + for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): + self.prefix.splice( + f""" + AOTI_NOINLINE static void check_input_{idx}( + AtenTensorHandle* input_handles + ) {{ + """ + ) + with self.prefix.indent(): + gen_check("input_handles", idx, name, tensor) + self.prefix.writeline("}") + + # force noinline to avoid any potential compilation slowdown due to aggressive + # inline done by the host compiler + self.prefix.splice( + """ + static bool _check_aoti_runtime_check_inputs_env() { + const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS"); + const static bool result = env_var_value != nullptr && env_var_value[0] != '0'; + return result; + } + + AOTI_NOINLINE static void __check_inputs_outputs( + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + if (!_check_aoti_runtime_check_inputs_env()){ + return; + } + """ + ) + with self.prefix.indent(): + for idx in range(len(V.graph.graph_inputs)): + self.prefix.writeline(f"check_input_{idx}(input_handles);") + self.prefix.writeline("}") + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + self.codegen_additional_funcs() + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + + assert V.graph.const_wrapper_code is not None + self.prefix.splice(V.graph.const_wrapper_code) + + assert V.graph.const_kernel_code is not None + self.kernel_declarations.splice(V.graph.const_kernel_code) + + if V.graph.is_const_graph: + self.prefix.splice( + f""" + void {self.aoti_model_class_name}::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {{ + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. + self.prefix.splice( + f""" + void {self.aoti_model_class_name}::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {{}} + + """ + ) + + run_impl_proto = f""" + void {self.aoti_model_class_name}::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {{ + __check_inputs_outputs(input_handles, output_handles); + """ + + self.generate_input_output_runtime_checks() + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release_simple release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + # debug printing for all input args to AOTI model + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.codegen_model_inputs_value_print( + input_args_to_print=[ + input_key + for input_key in V.graph.graph_inputs + if input_key.startswith("arg") + ] + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by ConstantHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. + self.prefix.writeline( + f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"[[maybe_unused]] auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs() + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"int32_t {name}_dtype;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));" + ) + + def codegen_input_size_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"auto {name}_size = {name}.sizes();") + + def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"auto {name}_stride = {name}.strides();") + + def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"int32_t {name}_device_type;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));" + ) + + def codegen_additional_funcs(self): + pass + + def codegen_model_kernels(self): + self.prefix.writeline("namespace {") + + # Tell compiler we need to link with the non-mangled symbols + for kernel in self.initialized_kernels.values(): + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) + signature = kernel.get_signature() + self.prefix.writeline(f'extern "C" {signature};') + + self.prefix.writeline( + "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {" + ) + self.prefix.writeline(" public:") + declare_kernel = OrderedSet(self.src_to_kernel.values()) - OrderedSet( + self.initialized_kernels.keys() + ) + declare_kernel.update( + entry[0] for entry in self.user_defined_kernel_cache.values() + ) + if V.graph.const_module: + declare_kernel.update( + V.graph.const_module.wrapper_code.src_to_kernel.values() + ) + for kernel in sorted(declare_kernel): + self.prefix.writeline( + maybe_hipify_code_wrapper( + f" {self.device_codegen.cpp_kernel_type()} {kernel}{{nullptr}};" + ) + ) + for name, kernel in self.initialized_kernels.items(): + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) + kernel_ptr = f"(*{name})" + signature = kernel.get_signature().replace(name, kernel_ptr) + self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};") + self.prefix.writeline("};") + self.prefix.writeline("} // namespace\n\n") + + if config.aot_inductor.embed_kernel_binary: + self.prefix.writeline('extern "C" {') + for name in sorted(declare_kernel): + self.prefix.writeline( + f" extern const unsigned char __{name}_start[];" + ) + if torch.xpu.is_available(): + self.prefix.writeline( + f" extern const unsigned char __{name}_end[];" + ) + self.prefix.writeline("}") + + # MSVC string was longer than the limit of 16380 single-byte characters. + # https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026 + MSVC_C2026_MAX_STRING_LENGTH = 16000 + + def codegen_write_arg_with_large_length_string( + self, + arg_name: str, + arg_str_val: str, + max_truncate_length: int = MSVC_C2026_MAX_STRING_LENGTH, + ): + def truncate_string(s: str, length: int) -> list[str]: + return [s[i : i + length] for i in range(0, len(s), length)] + + if len(arg_str_val) > max_truncate_length: + truncated_strs = truncate_string(arg_str_val, max_truncate_length) + self.prefix.writeline(f"{arg_name} =") + for truncate_str in truncated_strs: + self.prefix.writeline(f'R"({truncate_str})"') + self.prefix.writeline(";") + else: + self.prefix.writeline(f'{arg_name} = R"({arg_str_val})";') + + def codegen_model_constructor(self): + """ + // Generated code example + AOTInductorModel::AOTInductorModel() + : AOTInductorModelBase(4, 1) { + inputs_info_[0].name = "input0"; + inputs_info_[0].dtype = "torch.float16"; + ... + constants_info_[0].name = "L__self___weight"; + constants_info_[0].dtype = at::kFloat; + constants_info_[0].offset = 0; + constants_info_[0].data_size = 8192; + constants_info_[0].shape = {64, 32}; + constants_info_[0].stride = {32, 1}; + ... + outputs_info_[0].name = "output0"; + outputs_info_[0].dtype = "torch.float16"; + } + """ + + num_inputs = len(V.graph.graph_inputs) + num_outputs = len(V.graph.graph_outputs) + num_constants = len(V.graph.constants) + include_weights = ( + "true" + if config.aot_inductor.package_constants_in_so + and config.aot_inductor.package_constants_on_disk_format != "binary_blob" + else "false" + ) + self.prefix.splice( + f""" + {self.aoti_model_class_name}::{self.aoti_model_class_name}(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir) + : AOTInductorModelBase({num_inputs}, + {num_outputs}, + {num_constants}, + device_str, + std::move(cubin_dir), + {include_weights}) {{ + """ + ) + + with self.prefix.indent(): + for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()): + assert not isinstance(inp, sympy.Expr), ( + f"input {name=} cannot be symbolic" + ) + self.write_input_output_info("inputs_info_", idx, name) + + all_cuda = all( + V.graph.get_original_value_of_constant(name).is_cuda + for name in V.graph.constants + if name not in V.graph.folded_constants + ) + for idx, name in enumerate(V.graph.constants.keys()): + tensor = V.graph.get_original_value_of_constant(name) + assert isinstance(tensor, torch.Tensor) + self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""") + self.prefix.writeline( + f"constants_info_[{idx}].dtype = static_cast({self.codegen_dtype(tensor.dtype)});" + ) + self.prefix.writeline( + f"constants_info_[{idx}].offset = {tensor.storage_offset()};" + ) + + # If constants to serialize contain cpu tensors, we always align data_size it to 64. + # When loading the constants, the valid data will depends on the size + # not the data_size so there won't be correctness issue. + data_size = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};" + ) + + from_folded = "true" if name in V.graph.folded_constants else "false" + self.prefix.writeline( + f"constants_info_[{idx}].from_folded = {from_folded};" + ) + + if name in V.graph.folded_constants: + constant_type_str = "FoldedConstant" + elif name.startswith("_tensor_constant"): + constant_type_str = "TensorConstant" + elif any( + name == normalize_name(parameter_name) + for parameter_name in V.graph.named_parameters + ): + constant_type_str = "Parameter" + elif any( + name == normalize_name(buffer_name) + for buffer_name in V.graph.named_buffers + ): + constant_type_str = "Buffer" + else: + constant_type_str = "Unknown" + self.prefix.writeline( + f"constants_info_[{idx}].type = static_cast(torch::aot_inductor::ConstantType::{constant_type_str});" + ) + + size_str = ", ".join([str(s) for s in tensor.size()]) + self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};") + + stride_str = ", ".join([str(s) for s in tensor.stride()]) + self.prefix.writeline( + f"constants_info_[{idx}].stride = {{{stride_str}}};" + ) + self.prefix.writeline( + f"constants_info_[{idx}].layout = static_cast({self.codegen_layout(tensor.layout)});" + ) + + if tensor.is_mkldnn: + opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md( + tensor + ) + assert opaque_metadata_tensor.dim() == 1, ( + "Expect opaque_metadata_tensor to be 1-D" + ) + + opaque_metadata_list = opaque_metadata_tensor.tolist() + opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list) + self.prefix.writeline( + f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};" + ) + if name in V.graph.dynamo_flat_name_to_original_fqn: + original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get( + name, name + ) + elif name in V.graph.allocated_constant_name: + original_fqn = V.graph.allocated_constant_name[name] + else: + raise AssertionError("original_fqn must be set for constant") + self.prefix.writeline( + f"""constants_info_[{idx}].original_fqn = "{original_fqn}";""" + ) + self.prefix.writeline("update_constants_map(std::move(constants_map));") + self.prefix.writeline("update_constants_array(std::move(constants_array));") + + def escape_string(x): + return ( + x.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\t", "\\t") + ) + + # Origin code: self.prefix.writeline(f'in_spec_ = R"({config.aot_inductor.serialized_in_spec})";') + # Fix msvc C2026 error via codegen_write_arg_with_large_length_string + self.codegen_write_arg_with_large_length_string( + arg_name="in_spec_", arg_str_val=config.aot_inductor.serialized_in_spec + ) + # Origin code: self.prefix.writeline(f'out_spec_ = R"({config.aot_inductor.serialized_out_spec})";') + # Fix msvc C2026 error via codegen_write_arg_with_large_length_string + self.codegen_write_arg_with_large_length_string( + arg_name="out_spec_", + arg_str_val=config.aot_inductor.serialized_out_spec, + ) + + for idx, output in enumerate(V.graph.graph_outputs): + assert not isinstance(output, sympy.Expr), ( + f"output {name=} cannot be symbolic" + ) + name = f"output{idx}" + self.write_input_output_info("outputs_info_", idx, name) + + self.prefix.writeline( + "this->kernels_ = std::make_unique();" + ) + + self.prefix.writeline("}") + + def codegen_const_run_driver(self): + """ + // Generated code example + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + std::unordered_map folded_constants_map; + std::vector output_handles; + // build up output_handles over here. + _const_run_impl(output_handles, stream, proxy_executor); + // build up folded_constants_map + return folded_constants_map; + } + """ + + self.prefix.splice( + f""" + std::unordered_map {self.aoti_model_class_name}::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) {{ + """ + ) + if not config.aot_inductor.use_runtime_constant_folding: + self.prefix.splice( + """ + if (!initialization) { + std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: " + << "aot_inductor.use_runtime_constant_folding=False\\n"; + } + return {}; + } + """ + ) + return + + with self.prefix.indent(): + # This is a mapping to the index of constant folding graph's output + const_index_mapping: list[Optional[tuple[int, str]]] = [None] * len( + V.graph.const_output_index + ) + for idx, (name, _) in enumerate(V.graph.constants.items()): + if name in V.graph.const_output_index: + const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] + assert None not in const_index_mapping, ( + "Not all constant gets mapped for constant folding graph." + ) + + self.prefix.writeline( + f""" + std::unordered_map folded_constants_map; + folded_constants_map.reserve({len(const_index_mapping)}); + std::vector output_handles({len(const_index_mapping)}); + """ + ) + + self.prefix.splice( + """ + // The below assignment of output_handles to constants is not used directly. + // It's only used to memo the correspondence of handle and constants. + """ + ) + + for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f"output_handles[{output_idx}] = constants_->at({const_idx});" + ) + + self.prefix.writeline( + "_const_run_impl(output_handles, stream, proxy_executor);" + ) + + for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];' + ) + self.prefix.writeline("return folded_constants_map;") + + self.prefix.writeline("}") + + def generate(self, is_inference): + with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True): + self.write_wrapper_decl() + return super().generate(is_inference) + + def finalize_prefix(self): + prior = self.prefix + self.prefix = aot_mode_decls = IndentedBuffer() + if V.graph.aot_mode and not V.graph.is_const_graph: + aot_mode_decls.writeline("namespace torch::aot_inductor {") + self.codegen_model_kernels() + self.codegen_model_constructor() + self.codegen_const_run_driver() + aot_mode_decls.writeline("} // namespace torch::aot_inductor") + aot_mode_decls.writeline("using namespace torch::aot_inductor;") + + self.prefix = cache_decls = IndentedBuffer() + for dtype in self.used_cached_dtypes: + cache_decls.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cache_decls.writeline(f"CACHE_TORCH_DEVICE({device});") + for layout in self.used_cached_layouts: + cache_decls.writeline(f"CACHE_TORCH_LAYOUT({layout});") + for memory_format in self.used_cached_memory_formats: + cache_decls.writeline(f"CACHE_TORCH_MEMORY_FORMAT({memory_format});") + + self.prefix.splice(aot_mode_decls) + self.prefix.splice(prior) + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = False, + cpp_definition: Optional[str] = None, + ): + if cpp_definition is not None: + self.header.splice(cpp_definition) + self.kernel_declarations.splice(f"\n{kernel_body}\n") + else: + self.header.splice(f"\n{kernel_body}\n") + + def codegen_scalar_to_tensor(self, output: str): + name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + self.wrapper_call.writeline( + f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});" + ) + return name + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + + def generate_return(self, output_refs: list[str]): + cst_names = V.graph.constants.keys() + output2idx: dict[str, int] = {} + + # If any output ref represents an rvalue tensor, materialize it to an lvalue + # RAIIAtenTensorHandle first. This prevents situations where the code for the + # rvalue tensor references tensor handles whose contents are modified below. + output_refs = [ + self.create_tmp_raii_handle_var_if_needed(o, self.wrapper_call) + for o in output_refs + ] + + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + + if output not in output2idx: + output2idx[output] = idx + + def generate_before_suffix(self, result): + if not V.graph.is_const_graph: + if V.graph.aot_mode: + result.writeline(f"}} // {self.aoti_model_class_name}::run_impl") + else: + result.writeline("} // inductor_entry_impl") + + def generate_end(self, result): + """Generates the end of the code block, and any code needed to call it.""" + if V.graph.aot_mode: + if V.graph.is_const_graph: + result.writeline(f"}} // {self.aoti_model_class_name}::_const_run_impl") + else: + result.writeline("} // namespace torch::aot_inductor\n\n\n") + return + + if config.cpp_wrapper_build_separate: + # Close the wrapper code block, then write any kernel definitions. + result.splice("'''\n)") + if self.kernel_declarations: + result.splice("\nkernel_src = (\nr'''") + result.splice(self.kernel_declarations.getvalue()) + result.splice("'''\n)") + else: + result.splice( + """ + kernel_src = '' + """ + ) + else: + # Merge main code and kernel code + result.splice(self.kernel_declarations.getvalue()) + self.kernel_declarations.clear() + # Close the wrapper code block + result.splice("'''\n)") + + kernel_code = "kernel_src" if config.cpp_wrapper_build_separate else "None" + # Cpp entry function for JIT with cpp wrapper + result.splice( + f""" + inductor_entry = CppWrapperCodeCache.load_pybinding( + argtypes=["std::vector"], + main_code=cpp_wrapper_src, + device_type="{self.device}", + num_outputs={len(V.graph.graph_outputs)}, + kernel_code={kernel_code}, + ) + """ + ) + + wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]" + if V.graph.constants: + # Append constants to the input args for cpp wrapper. + # Python wrapper directly gets the value inside the wrapper call + # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__). + # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly. + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + constants_str = f"[{', '.join(V.graph.constants.keys())}]" + wrapper_body += f""" + constants_tensor = {constants_str} + input_tensors.extend(constants_tensor) + """ + # Convert vector of at::Tensor to vector of AtenTensorHandle. + # If we pass at::Tensor, the compilation will be too slow. + wrapper_body += """ + input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) + """ + # Release the inputs for memory reuse. + wrapper_body += """ + args.clear() + del input_tensors + """ + + # unwrap output tensor back to python scalar + if all(x for x in self.output_is_tensor.values()): + # If no ShapeAsConstantBuffer in the output, directly return the output as tensors + outputs_str = "output_tensors" + else: + outputs = [ + ( + f"output_tensors[{i}]" + if self.output_is_tensor[i] + else f"output_tensors[{i}].item()" + ) + for i in range(len(V.graph.graph_outputs)) + ] + outputs_str = f"[{', '.join(outputs)}]" + wrapper_body += f""" + output_handles = f(input_handles) + output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) + return {outputs_str} + """ + + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + {wrapper_body} + return g + + call = _wrap_func(inductor_entry) + """ + ) + + @staticmethod + def get_c_shim_func_name(kernel: str, device: str) -> str: + if kernel.startswith("aoti_torch_"): + return kernel + + assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" + kernel_tokens = kernel.split("::") + kernel_suffix = kernel_tokens[-1] + if kernel_suffix == "call": + kernel_suffix = kernel_tokens[-2] + + shim_fn = f"aoti_torch_{device}_{kernel_suffix}" + return shim_fn + + def generate_c_shim_extern_kernel_call( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + stack_traces: Optional[OrderedSet[str]] = None, + ) -> None: + """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in + place of args while preserving debug printer output.""" + # We can do this unconditionally, since we cache this call. + self.add_device_include(device) + + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + debug_args if debug_args is not None else args, kernel, None, None, "extern" + ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + with debug_printer_manager: + shim_fn = self.get_c_shim_func_name(kernel, device) + shim_fn_codes = [ + f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" + ] + if enable_kernel_profile: + stack_trace_str = 'R"(' + if stack_traces: + for stack_trace in stack_traces: + for line in stack_trace.split("\n"): + stack_trace_str += f"\n{line}" + stack_trace_str += "\n" + stack_trace_str += ')"' + + shim_fn_codes = [ + "{", + f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""", + f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""", + shim_fn_codes[0], + "}", + ] + self.writelines(shim_fn_codes) + + def generate_c_shim_extern_kernel_alloc( + self, extern_kernel: ir.ExternKernelAlloc, args: list[str] + ) -> None: + # registered output buffer name + name = extern_kernel.name + output_handle_name = f"{name}_handle" + is_inplace = ( + isinstance(extern_kernel.op_overload, torch._ops.OpOverload) + and torch.Tag.inplace_view in extern_kernel.op_overload.tags + ) + + if not is_inplace: + self.writeline(f"AtenTensorHandle {output_handle_name};") + args = [*args, f"&{output_handle_name}"] + + device = d.type if (d := extern_kernel.get_device()) else self.device + + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args, device + ) + + if extern_kernel.python_kernel_name in ( + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.wait_tensor.default", + ): + # all_reduce_ is an inplace op and its returned tensor is not used anywhere. + # wait_tensor returns its input without any modification and the returned tensor is not used anywhere. + # In both cases, we can immediately delete the returned AtenTensorHandle to reduce its lifetime. + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object({output_handle_name}));" + ) + elif not is_inplace: + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): + if getattr(extern_kernel, "outputs", None): + # ir.ExternKernelAlloc may have outputs if it returns a tuple + self.generate_c_shim_fallback_kernel(extern_kernel, args) + else: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + + def generate_c_shim_fallback_kernel( + self, fallback_kernel: ir.FallbackKernel, args: list[str] + ) -> None: + output_args = [] + output_raii_handles = [] + output_name_base = fallback_kernel.get_name() + for idx, output in enumerate(fallback_kernel.outputs): + if isinstance(output, ir.MultiOutput): + # TODO: handle integer output (e.g., as in attention) + name = f"{output.get_name()}" + output_handle_name = f"{name}_handle" + if output.indices: + assert output.indices[0][1] == idx, ( + f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + ) + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_args.append(f"&{output_handle_name}") + output_raii_handles.append( + f"RAIIAtenTensorHandle {name}({output_handle_name});" + ) + elif isinstance(output, int): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"int64_t {output_name} = {output};") + output_args.append(f"&{output_name}") + elif isinstance(output, sympy.Expr): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"auto {output_name} = {cexpr(output)};") + output_args.append(f"&{output_name}") + elif output is None: + output_args.append("nullptr") + else: + raise NotImplementedError(f"unsupported type of {output=}") + args = args + output_args + device = d.type if (d := fallback_kernel.get_device()) else self.device + + self.generate_c_shim_extern_kernel_call( + fallback_kernel.cpp_kernel_name, # type: ignore[arg-type] + args, + device, + ) + for raii_handle in output_raii_handles: + self.writeline(raii_handle) + + def _generate_extern_kernel_out_helper( + self, + kernel: str, + out: str, + out_view: Optional[str], + args: list[str], + device: str, + stack_traces: Optional[OrderedSet[str]] = None, + ) -> None: + if out_view: + out_name = f"{out}_as_strided" + self.writeline(f"auto {out_name} = {out_view};") + args.insert(0, out_name) + else: + args.insert(0, out) + + self.generate_c_shim_extern_kernel_call( + kernel, args, device, stack_traces=stack_traces + ) + + def _get_scatter_reduce_enum(self, reduce): + # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum + get_operator_enum = {"add": "sum", "multiply": "prod"} + if reduce in get_operator_enum: + reduce = get_operator_enum[reduce] + + return reduce + + def _generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + device, + ): + reduce = self._get_scatter_reduce_enum(reduce) + + # call the ABI shim function instead of the ATen one + self.add_device_include(device) + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + inputs_wrapped = [str(x) for x in inputs] + line = f"{cpp_kernel_name}({output}, {','.join(inputs_wrapped)}" + + if python_kernel_name.startswith("aten.scatter_reduce"): + line += f", {','.join(kwargs)}" + else: + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) + line += ");" + self.writeline(line) + + def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding + # tensor prematurely deallocated, thus the temporary array trick here. + indices_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", indices + ) + args = [ + x, + indices_str, + str(len(indices)), + values, + accumulate, + ] + args.insert(0, x) # set x as the output tensor, this fallback mutates x. + self.writeline(self.wrap_kernel_call(kernel, args)) + + def add_benchmark_harness(self, output): + if V.graph.aot_mode: + return + super().add_benchmark_harness(output) + + def codegen_cpp_sizevar(self, x: sympy.Expr, *, simplify: bool = True) -> str: + return cexpr(V.graph.sizevars.simplify(x) if simplify else x) + + def codegen_sizevar(self, x: sympy.Expr) -> str: + return self.codegen_cpp_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + # in the abi_compatible mode, outputs are returned via arguments + return name + + def codegen_shape_tuple(self, shape: Sequence[sympy.Expr]) -> str: + parts = [*map(self.codegen_sizevar, shape)] + if len(parts) == 0: + return "{}" + if len(parts) == 1: + return f"{{{parts[0]}, }}" + return f"{{{', '.join(parts)}}}" + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline(f"int64_t {sym} = {cexpr(expr)};") + + def _generate_symbolic_call_arg_helper( + self, arg: SymbolicCallArg, graph: GraphLowering + ) -> None: + if (arg.inner, graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((arg.inner, graph)) + self.writeline(f"int64_t {arg.inner} = {cexpr(arg.inner_expr)};") + else: + self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};") + + def _codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw") + + if len(node.keypath) == 0: + self.writeline(f"auto {node.sym} = {node.sym}_raw;") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + # TODO: assert divisibility here + self.writeline( + f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.sym)) + + def codegen_dynamic_select_index(self, node, clamp): + index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int) + size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int) + + # codegen index + sym = node.unbacked_offset_symbol + index_str = ( + f"{index_cpp_str} < 0 ? {index_cpp_str} + " + f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}" + ) + self.writeline(f"auto {sym}_index = {index_str};") + index_str_clamped = ( + f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)" + if clamp + else f"{sym}_index" + ) + self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};") + self.writeline( + f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + " + f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;" + ) + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(sym)) + + def codegen_dynamic_slice_size(self, node): + start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int) + end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int) + size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int) + step_cpp_str = self.val_to_arg_str_for_prim_type(node.step, int) + sym = node.unbacked_size_symbol + + def codegen_clamp(index_str, start=True): + suf = "st" if start else "en" + index_ = f"{sym}_{suf}_index" + self.writeline( + f"int64_t {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};" + ) + self.writeline( + f"int64_t {sym}_{suf}_cl = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});" + ) + + codegen_clamp(start_cpp_str, start=True) + codegen_clamp(end_cpp_str, start=False) + if node.step == 1: + step_str = f"{sym}_en_cl - {sym}_st_cl" + else: + step_str = ( + f"({sym}_en_cl - {sym}_st_cl + {step_cpp_str} - 1) / {step_cpp_str}" + ) + self.writeline(f"int64_t {sym}_with_step = {step_str};") + self.writeline(f"int64_t {sym} = {sym}_with_step < 0 ? 0 : {sym}_with_step;") + self.unbacked_symbol_decls.add(str(sym)) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout) + or isinstance(buffer, ir.TMADescriptor) + else f"{buffer.get_name()}.reset();" + ) + + def make_free_by_names(self, names_to_del: list[str]): + return " ".join(f"{name}.reset();" for name in names_to_del) + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"auto {new_name} = std::move({old_name}); // reuse" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline( + 'RAIIAtenRecordFunctionHandle record_inductor_wrapper_call_("inductor_wrapper_call", nullptr);' + ) + + def generate_start_graph(self): + pass + + def generate_end_graph(self): + pass + + def generate_inf_and_nan_checker(self, nodes): + for buf in nodes.get_names(): + # TODO: Add buf name directly into check_inf_and_nan. + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));" + ) + + def codegen_device(self, device): + assert device.type in DEVICE_TO_ATEN, ( + device.type + " not found in DEVICE_TO_ATEN" + ) + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + self.used_cached_devices.add(device_str) + return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}" + + def codegen_dtype(self, dtype): + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" + + def codegen_layout(self, layout): + layout_str = str(layout).split(".")[-1] + self.used_cached_layouts.add(layout_str) + return f"cached_torch_layout_{layout_str}" + + def codegen_memory_format(self, memory_format): + memory_format_str = str(memory_format).split(".")[-1] + self.used_cached_memory_formats.add(memory_format_str) + return f"cached_torch_memory_format_{memory_format_str}" + + def codegen_int_array_var( + self, + int_array: str, + writeline: Callable[..., None], + known_statically=False, + graph=None, # for per-graph caching + ) -> str: + # Use id(graph) for caching to avoid circular references + cache_key = ( + int_array, + id(writeline), + known_statically, + id(graph) if graph else None, + ) + if cache_key not in self.codegen_int_array_var_cache: + self.codegen_int_array_var_cache[cache_key] = ( + self._codegen_int_array_var_impl(int_array, writeline, known_statically) + ) + + return self.codegen_int_array_var_cache[cache_key] + + def _codegen_int_array_var_impl( + self, + int_array: str, + writeline: Callable[..., None], + known_statically: bool, + ) -> str: + # Used for size/stride declaration + # + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass. + # This is why writeline needs to explicitly passed in as a parameter. + var = f"int_array_{next(self.int_array_id)}" + ctype = "int64_t" + if int_array == "{}": + # An array of unknown bound cannot be initialized with {}. + if known_statically: + if config.cpp.use_constexpr_for_int_array: + writeline(f"static constexpr {ctype} *{var}=nullptr;") + else: + writeline(f"static const {ctype} *{var}=nullptr;") + else: + writeline(f"const {ctype} *{var}=nullptr;") + else: + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + if config.cpp.use_constexpr_for_int_array: + writeline(f"static constexpr {ctype} {var}[] = {int_array};") + else: + writeline(f"static const {ctype} {var}[] = {int_array};") + else: + writeline(f"const {ctype} {var}[] = {int_array};") + return var + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + V.graph.get_allocation_size(buffer), + buffer.get_is_pinned(), + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False + ): + if allocation_shape is None: + allocation_shape = shape + + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + allocation_size = self.codegen_shape_tuple(allocation_shape) + stride = self.codegen_shape_tuple(orig_stride) + + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + + if allocation_size != size: + allocation_size_array_var = self.codegen_int_array_var( + allocation_size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints( + allocation_shape + ), + graph=self.get_codegened_graph(), + ) + else: + allocation_size_array_var = size_array_var + + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + + handle_name = f"{name}_handle" + args = [ + str(len(shape)), + allocation_size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{handle_name}", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + pinned_str = "_pinned" if is_pinned else "" + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" + ) + + if allocation_size != size: + old_handle_name, handle_name = handle_name, f"{name}_handle_restrided" + self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + args = [ + old_handle_name, + size_array_var, + stride_array_var, + f"&{handle_name}", + ] + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_as_strided({', '.join(args)}));" + ) + self.wrapper_call.writeline( + f"wrap_with_raii_handle_if_needed({old_handle_name});" + ) + + return f"RAIIAtenTensorHandle {name}({handle_name});" + + def codegen_alloc_from_pool( + self, name, offset, dtype, shape, stride + ) -> tuple[str, list[str]]: + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + cexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call.writeline, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call.writeline, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + # We return the lines instead of writing here because writing here is bug prune. + # If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle + # as well, otherwise you get memory leaks + allocations_to_write = [ + f"AtenTensorHandle {tmp_name};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));", + ] + return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + """Returns a newly-created, temporary RAII tensor handle containing the + reinterpreted tensor data. Callers of this function are responsible for saving + the handle if persistent access is needed.""" + + d_size, d_stride, d_offset, d_dtype, collapsible = ( + codegen_reinterpret_view_helper(data) + ) + + dim = str(len(size)) + original_offset = offset + offset = self.codegen_sizevar(offset) + call_strs = [] + final_tensor_str = None + + def create_reinterpret_call() -> str: + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + self.codegen_shape_tuple(size), + writeline, + known_statically=self.is_statically_known_list_of_ints(size), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + self.codegen_shape_tuple(stride), + writeline, + known_statically=self.is_statically_known_list_of_ints(stride), + graph=self.get_codegened_graph(), + ), + offset, + ] + return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))" + + def create_dtypeview_call(reinterpret_call: str) -> tuple[str, list[str]]: + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] + device_name = data.layout.device.type + dtypeview_function = f"aoti_torch_{device_name}_view_dtype" + tmp_call_strs.append( + f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" + f"({reinterpret_call}, {self.codegen_dtype(dtype)}, &{tmp_AtenTensorHandle}));" + ) + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + def create_new_tensor_handle() -> tuple[str, list[str]]: + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + collapsed = collapsible and original_offset == d_offset + if collapsed: + same_layout = size == d_size and stride == d_stride + base_dtype = d_dtype + else: + same_layout = ( + size == data.layout.size + and stride == data.layout.stride + and original_offset == data.layout.offset + ) + base_dtype = data.dtype + + if same_layout: + # pure dtypeview + if dtype is not None and dtype != base_dtype: + final_tensor_str, tmp_call_strs = create_dtypeview_call(data.get_name()) + else: + final_tensor_str, tmp_call_strs = create_new_tensor_handle() + call_strs.extend(tmp_call_strs) + else: + # firstly create reinterpretview + final_tensor_str = create_reinterpret_call() + if dtype is not None and dtype != base_dtype: + # wrap it with dtypeview + final_tensor_str, tmp_call_strs = create_dtypeview_call( + final_tensor_str + ) + call_strs.extend(tmp_call_strs) + + for line in call_strs: + writeline(line) + + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. + # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. + # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::array{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.cbegin() + # ); + # ``` + return final_tensor_str + + def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): + """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to + handle cases where dst is not an AtenTensorHandle.""" + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));" + ) + + def codegen_multi_output(self, node: ir.MultiOutput): + # in the abi_compatible mode, outputs are retrieved by passing + # output pointers, so we skip its codegen here. + pass + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + + for (inner_input, inner_input_val), outer_input in zip( + subgraph.graph.graph_inputs.items(), outer_inputs + ): + if not isinstance(inner_input_val, ir.TensorBox): + continue + + # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional + # input (outer_input) into another at::Tensor to be used as a subgraph input + # (inner_input) in the nested scope. we can't std::move here, as the codegened + # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we + # can't necessarily std::move it back to the origin (x). + self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);") + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + src = inner_output.codegen_reference() + if not isinstance(inner_output, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. + src = f"std::move({src})" + # in case the outer_output carried a value + # before (e.g., in the while_loop codegen) + self.writeline(f"{outer_output}.reset();") + self.writeline(f"{outer_output} = {src};") + + def codegen_invoke_subgraph(self, invoke_subgraph): + raise NotImplementedError( + "codegen invoke_subgraph is not implemented for cpp wrapper" + ) + + def codegen_conditional(self, conditional): + outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the underlying scalar bool Tensor + predicate = f"{conditional.predicate.get_name()}_scalar" + if predicate not in self.used_cond_predicate: + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, + ) + self.used_cond_predicate.add(predicate) + else: + # the predicate is not a Tensor: SymBool or Python bool + predicate = conditional.predicate.codegen_reference() + + self.writeline(f"if ({predicate}) {{") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("} else {") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # PythonWrapperCode `codegen_subgraph` function. We should perhaps + # support lifting of subgraphs as functions for cpp wrapper as well. + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"// subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + + def codegen_while_loop(self, while_loop, stack_output=False): + if stack_output: + raise NotImplementedError("NYI cpp wrapper for while_loop_stack_output") + is_bool_pred = isinstance( + while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer + ) + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + cond_result_name = f"{name}_cond_result" + if is_bool_pred: + self.writeline(f"bool {cond_result_name};") + else: + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + + cond_outer_inputs = [] + for inp, out in zip(outer_carried_inputs, while_loop.outputs): + # in ABI-compatible mode, the carried inputs are codegened + # as buffers outside the while loop and set to the initial + # values. at the end of each while_loop iteration, they + # will be assigned the carried values. + out_name = out.get_name() + self.writeline(f"AtenTensorHandle {out_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") + cond_outer_inputs.append(out_name) + + # additional inputs will be assigned within the while_loop + # iteration directly from the corresponding outer graph buffers + cond_outer_inputs.extend(outer_additional_inputs) + + cond_outer_outputs = [cond_result_name] + body_outer_inputs = list(cond_outer_inputs) + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while (1) {") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + self.codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + + if is_bool_pred: + cond_result = f"{cond_result_name}" + else: + cond_result = f"{cond_result_name}_scalar" + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) + self.writeline(f"if (!{cond_result}) break;") + + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.codegen_subgraph( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def generate_extern_kernel_args_decl_if_needed( + self, + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + output_args: _OUTPUT_ARGS_TYPE, + raw_outputs: Sequence[ir.Buffer], + ): + """ + Generates declarations for external kernel arguments if needed, based on the provided + operator and its arguments. It processes both input and output arguments, categorizing + them into tensor and integer arguments for further code generation. + """ + schema = None + if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind): + obj = raw_args[0] + method = raw_args[1] + schema = op_overload.schema(obj, method) + else: + assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) + schema = op_overload._schema + assert schema is not None + arg_types = [x.real_type for x in schema.arguments] + return_types = [x.type for x in schema.returns] + + new_tensor_args = [] + new_int_args = [] + + def fill_args(arg, arg_type): + static_arg_types = ( + torch.FloatType, + torch.BoolType, + torch.StringType, + torch.Type, + torch.DeviceObjType, + ) + inductor_tensor_buffers = ( + ir.Buffer, + ir.ReinterpretView, + ) + + if isinstance(arg_type, torch.TensorType): + assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}" + new_tensor_args.append(f"{arg.codegen_reference()}") + elif isinstance(arg_type, torch.IntType): + # int + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg + new_int_args.append(cexpr(expr)) + elif isinstance(arg_type, torch.NumberType): + # Scalar of type int + assert isinstance(arg, (int, float, bool)) + # Only treat int Scalar as dynamic + if isinstance(arg, int): + new_int_args.append(str(arg)) + elif isinstance(arg, ir.TorchBindObject): + # torchbind objects are loaded in proxy executor + pass + elif isinstance(arg_type, torch.ListType): + assert isinstance(arg, (list, tuple)) + + # List[Tensor] + if isinstance(arg_type.getElementType(), torch.TensorType): + new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg]) + # List[Optional[Tensor]] + elif isinstance( + arg_type.getElementType(), torch.OptionalType + ) and isinstance( + arg_type.getElementType().getElementType(), torch.TensorType + ): + new_tensor_args.extend( + [f"{a.codegen_reference()}" for a in arg if a is not None] + ) + # List[int] + elif isinstance(arg_type.getElementType(), torch.IntType): + new_int_args.extend([str(a) for a in arg]) + # List[SymInt] + elif isinstance(arg_type.getElementType(), torch.SymIntType): + expressions = [ + a.node.expr if isinstance(a, torch.SymInt) else a for a in arg + ] + new_int_args.extend([cexpr(expr) for expr in expressions]) + # List[Scalar] + elif isinstance(arg_type.getElementType(), torch.NumberType): + # Only treat int Scalar as dynamic + is_int_type = [isinstance(a, int) for a in arg] + if any(is_int_type): + assert all(is_int_type), ( + "AOTInductor only supports int scalars of the same type" + ) + new_int_args.extend([str(a) for a in arg]) + else: + assert isinstance( + arg_type.getElementType(), + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) + else: + assert isinstance( + arg_type, + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) + + for arg, arg_type in zip(raw_args, arg_types): + if arg is not None: + if isinstance(arg_type, torch.OptionalType): + fill_args(arg, arg_type.getElementType()) + else: + fill_args(arg, arg_type) + + def fill_output_arg( + arg: str, return_type: torch.JitType, is_mutated_output: bool + ) -> None: + if isinstance(return_type, torch.TensorType): + if not is_mutated_output: + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + new_tensor_args.append(f"{arg}") + elif isinstance(return_type, torch.SymIntType): + raise NotImplementedError("NYI support for return type: SymInt") + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.SymIntType + ): + raise NotImplementedError("NYI support for return type: List[SymInt]") + else: + raise AssertionError(f"Unsupported return type found: {return_type}") + + # TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet + for return_type in return_types: + if isinstance( + return_type, (torch.TensorType, torch.NoneType, torch.IntType) + ): + pass + elif isinstance(return_type, torch.OptionalType): + assert isinstance(return_type.getElementType(), torch.TensorType) + elif isinstance(return_type, torch.ListType): + assert isinstance(return_type.getElementType(), torch.TensorType) + else: + raise NotImplementedError( + f"return type {return_type} is not yet supported." + ) + + for output_arg, raw_output_arg in zip(output_args, raw_outputs): # type: ignore[arg-type] + # None output is supported, but Optional return types are not yet supported + if output_arg is None: + continue + elif isinstance(raw_output_arg, int): + new_int_args.append(str(raw_output_arg)) + elif isinstance(output_arg, list): + for out in output_arg: + assert out is not None, out + fill_output_arg( + out, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) + else: + fill_output_arg( + output_arg, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) + + return new_tensor_args, new_int_args + + @staticmethod + def _compatible_with_stableivalue(op: torch._ops.OpOverload) -> bool: + """Returns true if op_overload._schema only utilizes types supported by the AOT + C-shim *internal* function to_ivalue. to_ivalue is an implementation detail, so + these types are not guaranteed to be supported long-term. When generating code + for cpp_wrapper mode, we don't have to be forward-compatible, so changing this + function's implementation in future is fine.""" + supported_types = ( + torch.BoolType, + torch.DeviceObjType, + torch.FloatType, + # ScalarTypeType, LayoutType, and MemoryFormatType are seen as IntType + # when queried via torch.JitType.type. + torch.IntType, + torch.TensorType, + ) + + def type_supported(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return type_supported(t.getElementType()) + return isinstance(t, supported_types) + + return all( + type_supported(a.type) + for a in chain(op._schema.arguments, op._schema.returns) + ) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + """Generate a call to a kernel not contained in the C-shim. This results in + different code paths for AOT Inductor vs cpp_wrapper Inductor mode.""" + + def extract_output_name( + out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]], + ) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]: + if out is None: + return None + if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): + return out.get_name() + if isinstance(out, ir.MutationOutput): + mutated_buf_names = out.get_mutation_names() + assert ( + isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1 + ), "Expect only one mutated buffer in MutationOutput" + return mutated_buf_names[0] + if isinstance(out, (list, tuple)): + return [extract_output_name(o) for o in out] # type: ignore[misc] + if isinstance(out, int): + return str(out) + raise AssertionError(f"Unexpected output: {type(out)}") + + if isinstance(op_overload, torch._ops.HigherOrderOperator): + assert isinstance( + op_overload, torch._higher_order_ops.torchbind.CallTorchBind + ), type(op_overload) + assert len(raw_args) > 1 + obj = raw_args[0] + method = raw_args[1] + return_schema = op_overload.schema(obj, method).returns + else: + return_schema = op_overload._schema.returns + + # output_args has the same pytree structure as outputs + if not return_schema: + # kernel does not return a value + output_args: _OUTPUT_ARGS_TYPE = [] + elif isinstance(output_name := extract_output_name(outputs), str): + output_args = [output_name] + else: + # If the schema indicates a return value, we should have a non-None value by + # this point. + assert isinstance(output_name, list), type(output_name) + output_args = output_name + + # In AOT mode, we use a ProxyExecutor to run fallback kernels. + if V.graph.aot_mode: + self.generate_fallback_kernel_with_runtime_lookup_aot( + op_overload, + raw_args, + output_args, + outputs, + ) + return + + assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) + for output in output_args: + assert output is None or isinstance(output, str), ( + "fallback kernels with runtime lookup currently only support tensor " + "returns, not more complicated types (such as list-of-list-of-tensor)" + ) + + # In non-AOT mode, we use aoti_torch_call_dispatcher if all the inputs and + # outputs of the op can be represented with StableIValue. This avoids the + # overhead of calling back into Python, and covers most remaining fallback ops. + if self._compatible_with_stableivalue(op_overload): + self.generate_fallback_kernel_with_runtime_lookup_nopython( + get_args, + op_overload, + output_args, # type: ignore[arg-type] + outputs, + ) + return + + # Otherwise, we call back into Python, which has some extra runtime overhead, + # but handles situations like list[Tensor] (currently unrepresentable via + # StableIValue). + self.generate_fallback_kernel_with_runtime_lookup_python( + buf_name, + python_kernel_name, + op_overload, + raw_args, + output_args, # type: ignore[arg-type] + outputs, + ) + + def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope): + scoped_lines = IndentedBuffer() + for declaration in declarations_before_scope: + scoped_lines.writeline(declaration) + + scoped_lines.writeline("{") + with scoped_lines.indent(): + scoped_lines.writeline("py::gil_scoped_acquire_simple acquire;") + scoped_lines.writelines(lines_in_scope.split("\n")) + scoped_lines.writelines("}") + return scoped_lines._lines + + def load_custom_op_wrapper(self): + # TODO: need to support control flow + if self.custom_op_wrapper_loaded: + return + + lines = """ +RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); +if (!codecache_module) { + throw std::runtime_error("Failed to load torch._inductor.codecache"); +} +custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); +if (!custom_op_wrapper) { + throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); +}""" + + declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + self.custom_op_wrapper_loaded = True + + def generate_float_value(self, val): + assert isinstance(val, float) + if val == float("inf"): + return "std::numeric_limits::infinity()" + elif val == float("-inf"): + return "-std::numeric_limits::infinity()" + elif math.isnan(val): + return "std::numeric_limits::quiet_NaN()" + else: + return f"{val}" + + def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): + def generate_py_arg_inner(lines, raw_arg, arg_type): + def handle_scalar(scalar): + if isinstance(scalar, int): + return f"PyLong_FromLongLong({scalar})" + if isinstance(scalar, float): + return f"PyFloat_FromDouble({self.generate_float_value(scalar)})" + if isinstance(scalar, bool): + return f"PyBool_FromLong({1 if scalar else 0})" + if isinstance(scalar, complex): + real = self.generate_float_value(scalar.real) + imag = self.generate_float_value(scalar.imag) + return f"PyComplex_FromDoubles({real}, {imag})" + if isinstance(scalar, SymTypes): + scalar_var = cexpr(scalar.node.expr) + if isinstance(scalar, torch.SymBool): + return f"PyBool_FromLong({scalar_var})" + if isinstance(scalar, torch.SymFloat): + return f"PyFloat_FromDouble({scalar_var})" + return f"PyLong_FromLongLong({scalar_var})" + raise NotImplementedError( + f"scalar {scalar}, {type(scalar)} cannot be handled by handle_scalar" + ) + + if raw_arg is None: + # Py_None is a singleton, so we have to explicitly incref it here + lines.append("Py_INCREF(Py_None);\n") + return "Py_None" + elif isinstance(arg_type, torch.TensorType): + # In some cases, scalar arguments may be passed in place of tensors. + if not hasattr(raw_arg, "codegen_reference"): + return handle_scalar(raw_arg) + + # Store AtenTensorHandle as void*. All Python args are constructed in a + # nested scope, so this handle will self-destruct after the function + # call. + base_handle = self.create_tmp_raii_handle_var_if_needed( + raw_arg.codegen_reference(), lines + ) + return f"PyCapsule_New(reinterpret_cast({base_handle}.get()), NULL, NULL)" + elif isinstance(arg_type, torch.OptionalType): + return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType()) + elif isinstance(arg_type, torch.IntType): + # int + return f"PyLong_FromLongLong({raw_arg})" + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = ( + raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg + ) + return f"PyLong_FromLongLong({cexpr(expr)})" + elif isinstance(arg_type, torch.FloatType): + return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})" + elif isinstance(arg_type, torch.BoolType): + return f"PyBool_FromLong({1 if raw_arg else 0})" + elif isinstance(arg_type, torch.StringType): + return f'PyUnicode_FromString("{raw_arg}")' + elif isinstance(arg_type, torch.NumberType): + # Union[bool, int, float, complex] + # torch/_prims_common/__init__.py + return handle_scalar(raw_arg) + elif isinstance(raw_arg, torch.device): + device_str, device_index = self.codegen_device(raw_arg).split(", ") + return f"THPDevice_New(c10::Device(static_cast({device_str}), {device_index}))" + elif isinstance(raw_arg, torch.dtype): + return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" + elif isinstance(raw_arg, torch.layout): + return f"Py_NewRef(torch::getTHPLayout(static_cast({self.codegen_layout(raw_arg)})))" + elif isinstance(raw_arg, torch.memory_format): + return ( + "Py_NewRef(torch::utils::getTHPMemoryFormat(static_cast(" + f"{self.codegen_memory_format(raw_arg)})))" + ) + else: + raise NotImplementedError( + f"arg type {arg_type} is not yet supported by custom_op_wrapper" + ) + + lines = [] + if isinstance(arg_type, torch.ListType): + assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list" + lines.append( + f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + ) + for i, elem in enumerate(raw_arg): + lines.append( + f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n" + ) + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + ) + else: + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n" + ) + return "".join(lines) + + def generate_fallback_kernel_with_runtime_lookup_nopython( + self, + get_args: Callable[[], Sequence[str]], + op_overload: torch._ops.OpOverload, + output_args: Sequence[Optional[str]], + raw_outputs: Sequence[ir.Buffer], + ) -> None: + """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can + only be called in cpp_wrapper mode, and assumes that the input is a non-None + OpOverload. + + In the future, we may switch over to directly calling c10::Dispatcher if we need + to support more datatypes.""" + if raw_outputs: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] + if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) + ] + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not None + ] + + dispatch_lines = IndentedBuffer() + dispatch_lines.writelines(declarations_before_scope) + dispatch_lines.writeline("{") + + with dispatch_lines.indent(): + tmp_var_number = count() + + def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: + # Strip off any temporary references; we're in an indented context, so + # any saved-off variables will be auto-destroyed. + new_codegen_arg = codegen_arg.removeprefix("&temporary_reference(") + if new_codegen_arg != codegen_arg: + # If we removed temporary_reference, there's a good chance the + # variable ends with get() (which would retrieve an ATenTensorHandle + # from a temporary RAII handle). Strip that off too, since we're + # going to save this in a temporary RAII handle. + if codegen_arg.endswith(".get())"): + codegen_arg = new_codegen_arg.removesuffix(".get())") + else: + codegen_arg = new_codegen_arg.removesuffix(")") + + if isinstance(arg_type, torch.OptionalType): + # If we have a pointer to a variable, strip it off and let + # from handle any internal pointers. + codegen_arg = codegen_arg.removeprefix("&") + + if codegen_arg == "nullptr": + return "torch::stable::detail::from(std::nullopt)" + + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline( + f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};" + ) + return f"torch::stable::detail::from({var_name})" + + raii_var = self.create_tmp_raii_handle_var_if_needed( + codegen_arg, dispatch_lines + ) + temp_handle = raii_var != codegen_arg + + if isinstance(arg_type, torch.TensorType): + if not temp_handle: + # If the RAII tensor being referenced _isn't_ a temporary, + # scoped to this fallback call, then create a new handle + # referencing it which from can steal. + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline(f"AtenTensorHandle {var_name};") + dispatch_lines.writeline( + f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});" + ) + return f"torch::stable::detail::from({var_name})" + # If the RAII tensor _is_ a temporary scoped to this fallback call, + # simply release and steal the handle. + return f"torch::stable::detail::from({raii_var}.release())" + return f"torch::stable::detail::from({codegen_arg})" + + codegen_args = get_args() + ivalue_args = ( + parse_arg(a.type, c) + for a, c in zip(op_overload._schema.arguments, codegen_args) + ) + array_len = max(len(codegen_args), len(output_args)) + dispatch_lines.writeline( + f"std::array dispatch_vars{{{', '.join(ivalue_args)}}};" + ) + dispatch_lines.writeline("AOTI_TORCH_ERROR_CODE_CHECK(") + with dispatch_lines.indent(): + dispatch_lines.writeline( + f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())' # noqa: B950 + ) + dispatch_lines.writeline(");") + + if len(output_args) == 1 and (output := output_args[0]) is not None: + # result is a single tensor + dispatch_lines.writeline( + f"{output} = torch::stable::detail::to(dispatch_vars[0]);" + ) + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue + dispatch_lines.writeline( + f"{output_arg} = torch::stable::detail::to(dispatch_vars[{idx}]);" + ) + + dispatch_lines.writeline("}") + self.writelines(dispatch_lines.getvalue().splitlines()) + + def generate_fallback_kernel_with_runtime_lookup_python( + self, + buf_name: str, + python_kernel_name: str, + op_overload: torch._ops.OpOverload, + raw_args: Sequence[Any], + output_args: Sequence[Optional[str]], + raw_outputs: Sequence[ir.Buffer], + ) -> None: + """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can + only be called in cpp_wrapper mode, and assumes that the input is a non-None + OpOverload. + + This function calls into Python to dispatch, which allows it to handle datatypes + that cannot be contained in StableIValue, at the cost of some performance.""" + self.load_custom_op_wrapper() + + num_args = len(raw_args) + py_args_var = f"py_args_{next(self.arg_var_id)}" + # First arg is always the python op name + lines = textwrap.dedent( + f""" + RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1})); + if (!{py_args_var}) {{ + throw std::runtime_error("PyTuple_New {py_args_var} failed"); + }} + PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}")); + """ + ) + + for idx, (raw_arg, schema_arg) in enumerate( + zip(raw_args, op_overload._schema.arguments) + ): + lines += self.generate_py_arg( + py_args_var, idx + 1, raw_arg, schema_arg.real_type + ) + + lines += textwrap.dedent( + f""" + // Call the custom op in Python + RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var})); + if (!py_{buf_name}) {{ + if (PyErr_Occurred()) {{ + return; + }} + throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed"); + }} + """ + ) + + if len(output_args) == 1 and (output := output_args[0]) is not None: + # result is a single tensor + lines += f"{output} = reinterpret_cast(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));\n" + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue + lines += f"{output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n" # noqa: B950 + + if raw_outputs: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] + if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) + ] + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not None + ] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + def generate_fallback_kernel_with_runtime_lookup_aot( + self, + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + output_args: _OUTPUT_ARGS_TYPE, + raw_outputs: Sequence[ir.Buffer], + ) -> None: + ( + tensor_call_args, + int_call_args, + ) = self.generate_extern_kernel_args_decl_if_needed( + op_overload, + raw_args, + output_args, + raw_outputs, + ) + # force both temporary arrays to generate mutable data pointers, since the proxy + # executor signature requires that datatype + int_call_str = self._generate_temporary_array_pointer( + "int64_t", int_call_args, force_mutable=True + ) + tensor_call_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", tensor_call_args, force_mutable=True + ) + + extern_kernel_node_index = len(V.extern_kernel_nodes) - 1 + self.writeline( + f"aoti_torch_proxy_executor_call_function(proxy_executor, " + f"{extern_kernel_node_index}, " + f"{len(int_call_args)}, " + f"{int_call_str}, " + f"{len(tensor_call_args)}, " + f"{tensor_call_str});" + ) + + def generate_reset_kernel_saved_flags(self): + pass + + def generate_save_uncompiled_kernels(self): + pass + + def c_type_for_prim_type(self, val, type_) -> str: + if isinstance(type_, torch.OptionalType): + return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" + elif isinstance(type_, torch.TensorType): + return "AtenTensorHandle" + elif isinstance(type_, (torch.IntType, torch.SymIntType)): + return "int64_t" + elif isinstance( + type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) + ) or repr(type_) in ("Layout", "MemoryFormat", "ScalarType"): + return "int32_t" + elif isinstance(type_, torch.FloatType): + return "double" + elif isinstance(type_, torch.NumberType): + if isinstance(val, bool): + return "int32_t" + elif isinstance(val, (int, float)): + return "double" + elif val is None: + # This could happen when val is an optional value + return "double" + else: + raise AssertionError( + f"Unexpected type in c_type_for_prim_type: {type_=}" + ) + elif isinstance(type_, torch.StringType): + return "const char*" + else: + raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") + + def val_to_arg_str_for_prim_type(self, val, type_) -> str: + # TODO: not using type_ as the first step of refactoring. Will update this later. + if isinstance(val, bool): + return "1" if val else "0" + elif isinstance(val, int): + # uint64_t is long on Linux, but long long on MacOS and Windows + return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L" + elif isinstance(val, complex): + return f"c10::complex{{ {self.generate_float_value(val.real)}, {self.generate_float_value(val.imag)} }}" + elif isinstance(val, str): + return f'"{val}"' + elif isinstance( + val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) + ): + return val.codegen_reference() + elif isinstance(val, torch.device): + return self.codegen_device(val) + elif isinstance(val, torch.dtype): + return self.codegen_dtype(val) + elif isinstance(val, torch.layout): + return self.codegen_layout(val) + elif isinstance(val, torch.memory_format): + return self.codegen_memory_format(val) + elif isinstance(val, float): + return self.generate_float_value(val) + elif isinstance(val, (list, tuple)): + # FIXME: This happens because type_ is not always properly set to torch.ListType + return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" + elif isinstance(val, SymTypes): + return cexpr(val.node.expr) + elif isinstance(val, sympy.Expr): + return cexpr(val) + else: + return repr(val) + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. It either represent nullopt or an empty tensor + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.DeviceObjType, + torch.ListType, + torch.TupleType, + ), + ): + return "nullptr, 0" + return "nullptr" + + if isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);") + return var_name + + raise AssertionError("Can not map None to a known data type") + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + arg_str = self.val_to_arg_str(val, element_type) + # Handle optional iterables as a special case. Utilize the + # temporary_reference function to avoid saving them off and increasing + # memory usage. + if isinstance(element_type, (torch.ListType, torch.TupleType)): + main_value, aux = arg_str.rsplit(", ", maxsplit=1) + return f"&temporary_reference({main_value}), {aux}" + + # Handle optional tensors as a special case, as above. + if isinstance(element_type, torch.TensorType): + base_handle = self.val_to_arg_str(val, element_type) + return f"&temporary_reference({base_handle}.get())" + + var_name = f"var_{next(self.arg_var_id)}" + if isinstance(element_type, torch.DeviceObjType): + main_value, aux = arg_str.rsplit(", ", maxsplit=1) + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + + self.writeline( + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {arg_str};" + ) + return f"&{var_name}" + + if isinstance(type_, (torch.ListType, torch.TupleType)): + assert isinstance(val, (list, tuple)), ( + f"{val} does not match with arg type {type_}" + ) + element_type = type_.getElementType() + + if len(val) == 0: + # Zero-size array is not supported in the C or C++ standard, so return a + # nullptr. + return "nullptr, 0" + + result = [self.val_to_arg_str(x, element_type) for x in val] + if isinstance(element_type, torch.TensorType): + result = [f"{t}.get()" for t in result] + + c_type = self.c_type_for_prim_type(val[0], element_type) + # see the comment in self._generate_temporary_array_pointer for an + # explanation of why this c_type gets modified + if isinstance(element_type, torch.OptionalType) and not c_type.startswith( + "const" + ): + c_type = f"const {c_type}" + + # need to pass the array length, because we can't use the std::array member + # function + return ( + f"{self._generate_temporary_array_pointer(c_type, result)}, {len(val)}" + ) + + val_is_scalar = isinstance(val, (bool, complex, float, int, *SymTypes)) + if isinstance(type_, torch.TensorType) and val_is_scalar: + val_str = self.val_to_arg_str_for_prim_type(val, None) + return self.codegen_scalar_to_tensor(val_str) + + return self.val_to_arg_str_for_prim_type(val, type_) + + def create_tmp_raii_handle_var_if_needed( + self, handle: str, writer: Optional[Union[HasWriteLine, list[str]]] = None + ) -> str: + """If the input handle is an rvalue RAII tensor, creates an lvalue variable for + it in writer. Returns a variable name that can be used to access handle.""" + if not handle.startswith( + ( + "borrow_arrayref_tensor_as_tensor(", + "copy_arrayref_tensor_to_tensor(", + "wrap_with_raii_handle_if_needed(", + "RAIIAtenTensorHandle(", + ) + ): + return handle + + tmp_var_name = f"var_{next(self.arg_var_id)}" + call_str = f"auto {tmp_var_name} = {handle};" + + writer = writer if writer is not None else self + if isinstance(writer, list): + writer.append(call_str) + else: + writer.writeline(call_str) + + return tmp_var_name + + def write_kernel_context_guard_begin( + self, + ): + # Beginning of a kernel context guarded block. + # The block looks like this: + # { + # KernelContextGuard _ctx("{kernel_name}", {stack_trace_str}); + # ... operations... + # } + self.writeline("{") + + def write_kernel_context_guard_end( + self, + ): + # End of a kernel context guarded block. + self.writeline("}") + + def write_kernel_context_guard( + self, + kernel_name: str, + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ): + def aggregate_stack_traces( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ) -> OrderedSet[str]: + if isinstance(node_schedule, list): + return functools.reduce( + lambda a, b: a | b, + [ + # pyrefly: ignore [missing-attribute] + node.node.get_stack_traces() + for node in node_schedule + if hasattr(node, "node") and node.node + ], + OrderedSet(), + ) + elif isinstance(node_schedule, ExternKernel): + return node_schedule.get_stack_traces() + else: + return OrderedSet() + + stack_trace_str = 'R"(' + stack_traces = aggregate_stack_traces(node_schedule) + + for stack_trace in stack_traces: + for line in stack_trace.split("\n"): + stack_trace_str += f"\n{line}" + stack_trace_str += "\n" + stack_trace_str += ')"' + self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});') diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c9aef609ba483ad9178f0653f52a20b1b2ea2f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -0,0 +1,897 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops + +from .. import config, ir +from ..utils import sympy_product +from ..virtualized import V +from .cpp_utils import DTYPE_TO_CPP +from .cpp_wrapper_cpu import CppWrapperCpu +from .wrapper import ( + BufferLike, + EnterSubgraphLine, + ExitSubgraphLine, + MemoryPlanningLine, + MemoryPlanningState, + PythonWrapperCodegen, +) + + +BufferName = str + +# Default thread stack sizes vary by platform: +# - Linux: 8 MB +# - macOS: 512 KB +# - Windows: 1 MB +# Just pick something comfortably smaller than the smallest for now. +MAX_STACK_ALLOCATION_SIZE = 1024 * 100 + + +class CppWrapperCpuArrayRef(CppWrapperCpu): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + + This class is forked from CppWrapperCpu, with a difference that tensors may be + represented as ArrayRef, see torch/csrc/inductor/aoti_runtime/arrayref_tensor.h + """ + + def __init__(self): + super().__init__() + assert self.device == "cpu", "ArrayRefTensor only supported on CPU!" + self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation + self.stack_allocated_buffers: dict[BufferName, BufferLike] = {} + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpuArrayRef() + + @staticmethod + def get_input_cpp_type(input): + assert config.aot_inductor.use_minimal_arrayref_interface + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + @staticmethod + def get_device_include_path(device: str) -> str: + assert device == "cpu", "ArrayRef only supported on CPU!" + if V.graph.aot_mode: + return "#include " + return "#include " + + def codegen_input_numel_asserts(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + numel = buf.get_numel() + self.prefix.writeline(f"assert_numel({name}, {numel});") + + def generate_extern_kernel_alloc(self, *args, **kwargs): + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_extern_kernel_alloc(*args, **kwargs) + + def generate_extern_kernel_out(self, *args, **kwargs): + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_extern_kernel_out(*args, **kwargs) + + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_fallback_kernel(node) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + assert not triton, ( + "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU" + ) + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call" + ) + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});") + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if ( + config.aot_inductor.use_minimal_arrayref_interface + and not V.graph.is_const_graph + ): + input_cpp_types = ", ".join( + f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}" + for x in V.graph.graph_inputs.values() + ) + output_arrayref_types = ", ".join( + f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" + for x in V.graph.graph_outputs + ) + + self.prefix.splice( + f""" + using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; + using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; + """ + ) + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + + assert V.graph.const_wrapper_code is not None + self.prefix.splice(V.graph.const_wrapper_code) + + assert V.graph.const_kernel_code is not None + self.kernel_declarations.splice(V.graph.const_kernel_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + + self.generate_input_output_runtime_checks() + run_impl_proto += """ + __check_inputs_outputs(input_handles, output_handles); + """ + + if config.aot_inductor.use_minimal_arrayref_interface: + self.prefix.splice( + """ + template <> + AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< + AOTInductorModelInputs, AOTInductorModelOutputs>( + const AOTInductorModelInputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + self.suffix.splice(run_impl_proto) + self.suffix.splice( + """ + AOTInductorModelInputs inputs; + convert_handles_to_inputs(input_handles, inputs); + auto outputs = run_impl_minimal_arrayref_interface( + inputs, stream, proxy_executor); + // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this + // interface to perform well for a DSO using the minimal arrayref interface, all we need + // to do is provide ThreadLocalCachedTensor for each one! + convert_outputs_to_handles(outputs, output_handles); + } + """ + ) + + self.suffix.splice( + """ + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( + AOTInductorModelHandle model_handle, + const AOTInductorModelInputs& inputs, + AOTInductorModelOutputs& outputs) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + outputs = model->run_impl_minimal_arrayref_interface( + inputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }) + } + """ + ) + else: + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not config.aot_inductor.use_minimal_arrayref_interface: + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release_simple release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + if config.aot_inductor.use_minimal_arrayref_interface: + self.prefix.writeline( + f"auto {input_key} = std::get<{idx}>(inputs);" + ) + continue + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs() + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + if config.aot_inductor.use_minimal_arrayref_interface: + # TODO: input shape checking for regular tensor interface as well? + self.codegen_input_numel_asserts() + else: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def generate_return(self, output_refs: list[str]): + cst_names = V.graph.constants.keys() + arr_iface = ( + not V.graph.is_const_graph + and config.aot_inductor.use_minimal_arrayref_interface + ) # For brevity. + + def use_thread_local_cached_output_tensor(idx, output): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + cache_type = "Array" if arr_iface else "Tensor" + self.wrapper_call.writeline( + f"thread_local ThreadLocalCachedOutput{cache_type}> " + f"{cached_output_name}({output});" + ) + if arr_iface: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + output_entry = f"std::get<{idx}>(output_arrayref_tensors)" + element_type = f"std::decay_t" + self.wrapper_call.writeline( + f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" + ) + + if arr_iface: + self.wrapper_call.writeline( + "AOTInductorModelOutputs output_arrayref_tensors;" + ) + + output2idx: dict[str, int] = {} + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): + if arr_iface: + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + self.wrapper_call.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if is_constant_buffer: + # NOTE(return_constant): In some rare cases where we return + # a constant, we have to return a copy of this constant, + # because (1) constants are not owned by the Model instance + # (2) constants remain the same cross inference runs, + # assuming they are not updated at runtime Basically, we + # cannot release or transfer the ownership of any original + # constant to the user. + self.wrapper_call.writeline( + f"AtenTensorHandle {cached_output_name}_tmp;" + ) + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" + ) + self.wrapper_call.writeline( + f"{cached_output_name} = {cached_output_name}_tmp;" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name} = {output}.release();" + ) + self.wrapper_call.writeline( + f"convert_handle_to_arrayref_tensor({cached_output_name}, " + f"std::get<{idx}>(output_arrayref_tensors));" + ) + else: + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") + + if output not in output2idx: + output2idx[output] = idx + if arr_iface: + self.wrapper_call.writeline("return output_arrayref_tensors;") + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + # TODO: integrate memory planning & stack allocation? + self.allow_stack_allocation = False + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + self.allow_stack_allocation = ( + self.allow_stack_allocation is not False + and config.aot_inductor.allow_stack_allocation + and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE + ) + + def can_stack_allocate_buffer(self, buffer): + return ( + self.allow_stack_allocation + and buffer.get_device().type == "cpu" + and self.can_prove_buffer_has_static_shape(buffer) + and ir.is_contiguous_strides_for_shape( + buffer.get_stride(), buffer.get_size() + ) + ) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout) + or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers) + or ( + config.aot_inductor.use_minimal_arrayref_interface + and V.graph.aot_mode + and buffer.get_name() in V.graph.graph_inputs + ) + else f"{buffer.get_name()}.reset();" + ) + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + buffer if self.can_stack_allocate_buffer(buffer) else None, + buffer.get_is_pinned(), + ) + + def make_allocation( + self, + name, + device, + dtype, + shape, + stride, + buffer_if_can_stack_allocate=None, + is_pinned=False, + ): + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(orig_stride) + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. + self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") + args = [ + f"{name}_storage", + size_array_var, + stride_array_var, + device_type, + device_idx, + ] + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + pinned_str = "_pinned" if is_pinned else "" + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + if old_name in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline + ) + if reinterpret_view in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + # The only way to get into this case is via an exact buffer reuse, since all + # other options result in a new tensor handle. + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + return f"{self.declare}{new_name} = {reinterpret_view}{del_line} // reuse" + + def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self): + # Borrowing arguments to shim functions is only safe because we know + # that the arguments can't be stack-allocated. Otherwise, to be sure + # we can't return a dangling pointer, we need to either 1) be + # certain that the shim function cannot return an alias of a + # borrowed argument, or 2) be certain that the returned Tensor from + # the shim function cannot escape. + assert self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(), ( + "borrowing arguments to shim functions is unsafe with " + "stack allocation on! (see comment above this assertion)" + ) + + def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self): + return not self.allow_stack_allocation and not self.stack_allocated_buffers + + def generate_c_shim_extern_kernel_call( + self, kernel: str, args: list[str], device: str, **_ + ) -> None: + # In the abi_compatible mode, we call fallback aten ops through a C shim layer + # Setting self.allow_stack_allocation to False because the exchange between + # ArrayRefTensor and at::Tensor is still fragile. + self.allow_stack_allocation = False + + wrapped_args = [] + for arg in args: + # We only really *need* borrow_arrayref_tensor_as_tensor for + # ArrayRefTensors. The code flowing into here uses `0` for nullptr, which + # borrow_arrayref_tensor_as_tensor would blindly coerce to int, so just + # avoid wrapping integers. Name matching is to find tensor is hacky, but + # fixing all the ArrayRefTensor issues is not a priority for now. + if isinstance(arg, str) and arg.startswith( + ("buf", "arg", "wrap_with_raii_handle_if_needed") + ): + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + arg = f"borrow_arrayref_tensor_as_tensor({arg})" + wrapped_args.append(arg) + + super().generate_c_shim_extern_kernel_call( + kernel, wrapped_args, device, debug_args=args + ) + + def generate_scatter_fallback(self, node: ir.ScatterFallback): + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + super().generate_scatter_fallback(node) + + def _generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + device, + ): + reduce = self._get_scatter_reduce_enum(reduce) + + # call the ABI shim function instead of the ATen one + self.add_device_include(device) + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device) + + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + inputs_wrapped = [ + (f"borrow_arrayref_tensor_as_tensor({x})" if isinstance(x, str) else str(x)) + for x in inputs + ] + line = f"{cpp_kernel_name}(borrow_arrayref_tensor_as_tensor({output}), {','.join(inputs_wrapped)}" + + if python_kernel_name.startswith("aten.scatter_reduce"): + line += f", {','.join(kwargs)}" + else: + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) + line += ");" + self.writeline(line) + + def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None: + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + super().generate_index_put_fallback(node) + + def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding + # tensor prematurely deallocated, thus the temporary array trick here. + indices_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", + [f"borrow_arrayref_tensor_as_tensor({i})" for i in indices], + ) + args = [ + f"borrow_arrayref_tensor_as_tensor({x})", + indices_str, + str(len(indices)), + f"borrow_arrayref_tensor_as_tensor({values})", + accumulate, + ] + args.insert( + 0, f"borrow_arrayref_tensor_as_tensor({x})" + ) # set x as the output tensor, this fallback mutates x. + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + super().generate_fallback_kernel_with_runtime_lookup( + buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs + ) + + def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): + # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, + # while stack-allocation results in ArrayRefTensor + # so disable stack allocation here + self.allow_stack_allocation = False + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" + ) + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + """Returns a newly-created, temporary RAII tensor handle containing the + reinterpreted tensor data. Callers of this function are responsible for saving + the handle if persistent access is needed.""" + dim = str(len(size)) + + def create_reinterpret_call() -> str: + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + self.codegen_shape_tuple(size), + writeline, + known_statically=self.is_statically_known_list_of_ints(size), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + self.codegen_shape_tuple(stride), + writeline, + known_statically=self.is_statically_known_list_of_ints(stride), + graph=self.get_codegened_graph(), + ), + offset, + ] + return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))" + + def create_new_tensor_handle() -> tuple[str, list[str]]: + # Calling reset() on ArrayRefTensor does nothing, since the array is + # const-allocated on the stack. Thus, it's safe to return a reference to + # the original array. + if (name := data.get_name()) in self.stack_allocated_buffers: + return name, [] + + tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + if ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + and (dtype is None or dtype == data.dtype) + ): + final_tensor_str, call_strs = create_new_tensor_handle() + for line in call_strs: + writeline(line) + return final_tensor_str + + return super().codegen_reinterpret_view( + data, size, stride, offset, writeline, dtype + ) + + def val_to_arg_str(self, val, type_=None) -> str: + if ( + val is not None + and isinstance(type_, torch.OptionalType) + and isinstance(type_.getElementType(), torch.TensorType) + ): + # Handle optional tensors as a special case, as in the parent class. + base_handle = self.val_to_arg_str(val, torch.TensorType) + if config.aot_inductor.use_minimal_arrayref_interface: + if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(): + base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})" + else: + base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})" + return f"&temporary_reference({base_handle}.get())" + + return super().val_to_arg_str(val, type_) + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + + # We know that item_ doesn't alias the input, so borrowing should be safe. + tensor = f"borrow_arrayref_tensor_as_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + + # We know that item_ doesn't alias the input, so borrowing should be safe. + tensor = f"borrow_arrayref_tensor_as_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..42c082d9d92af7585c1d56dd35a1b79ba55f9ede --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -0,0 +1,891 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import re +import sys +from itertools import count, zip_longest +from typing import Any, Optional, Union +from typing_extensions import Self + +import sympy + +import torch +from torch import dtype as torch_dtype +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name +from torch._inductor.runtime.runtime_utils import dynamo_timed + +from .. import config +from ..codecache import CudaKernelParamCache +from ..ir import ( + GraphPartitionSignature, + TensorBox, + TMADescriptorExperimental, + TMADescriptorStable, +) +from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import get_device_op_overrides, TritonScratchWorkspace +from .cpp_utils import cexpr +from .cpp_wrapper_cpu import CppWrapperCpu +from .multi_kernel import MultiKernelCall +from .triton_utils import should_unwrap_unspec_arg +from .wrapper import PythonWrapperCodegen, SymbolicCallArg + + +_cpp_string_literal_escapes = { + "\\": "\\\\", + '"': '\\"', + "\n": "\\n", + "\t": "\\t", + "\r": "\\r", +} +_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]') + + +def cpp_string_literal(s: str) -> str: + escaped = _cpp_string_literal_pattern.sub( + lambda match: _cpp_string_literal_escapes[match.group(0)], s + ) + return f'"{escaped}"' + + +@dataclasses.dataclass +class DeferredTritonCallWrapper: + """ + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred generating the final wrapper around + the triton kernel until right before the prefix is written. + """ + + wrapper_name: str + kernel_name: str + kernel_name_to_body: dict[str, str] + arg_types: list[Any] + + def generate(self, wrapper: CppWrapperGpu): + """ + Generate the GPU kernel definition, as well as load and launch code. + """ + prefix = wrapper.prefix + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + params = CudaKernelParamCache.get(self.kernel_name) + assert params, f"CudaKernelParamCache not populated for {self.kernel_name}" + def_args = params["def_args"] + arg_types = self.arg_types + inductor_meta = params["inductor_meta"] + + if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types): + # extra_launcher_args should already be in def_args + assert len(def_args) == len(arg_types) - len( + inductor_meta["extra_launcher_args"] + ) + arg_types = arg_types + [SymbolicCallArg] * len( + inductor_meta["extra_launcher_args"] + ) + + if not V.graph.aot_mode: + prefix.writeline( + maybe_hipify_code_wrapper( + f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;" + ) + ) + kernel_var_name = self.kernel_name + else: + kernel_var_name = f"kernels_.{self.kernel_name}" + + # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types + template_types = [ + f"typename {name}_type_" + for name, arg_type in zip(def_args, arg_types) + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)) + ] + if V.graph.aot_mode: + template_types.append("typename kernels_type_") + if template_types: + prefix.writeline(f"template <{', '.join(template_types)}>") + prefix.writeline(f"static inline void {self.wrapper_name}(") + with prefix.indent(): + assert len(def_args) == len(arg_types), (def_args, arg_types) + for name, arg_type in zip(def_args, arg_types): + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)): + prefix.writeline(f"const {name}_type_& {name},") + elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)): + prefix.writeline(f"int64_t {name},") + elif arg_type is float: + prefix.writeline(f"float {name},") + elif arg_type is bool: + prefix.writeline(f"bool {name},") + else: + raise ValueError(f"Unexpected arg type {arg_type}") + prefix.writeline("int32_t device_idx_,") + prefix.writeline( + maybe_hipify_code_wrapper( + f"{wrapper.device_codegen.cpp_stream_type()} stream_," + ) + ) + if V.graph.aot_mode: + prefix.writeline("kernels_type_& kernels_,") + prefix.writeline( + "const std::optional& cubin_dir_ = std::nullopt" + ) + prefix.writeline("){") + with prefix.indent(): + if V.graph.aot_mode: + # Emit the original Triton kernel for debugging purposes + prefix.writeline("/*") + prefix.splice(self.kernel_name_to_body[self.kernel_name]) + prefix.writeline("*/") + self.generate_grid(prefix, inductor_meta, params) + self.generate_load_kernel(prefix, kernel_var_name, params) + self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) + prefix.writeline("}") + + if not config.aot_inductor.embed_kernel_binary: + # Ensure the cubin file is included in the package + V.graph.wrapper_code.additional_files.append( + params[get_cpp_wrapper_cubin_path_name()] + ) + + def generate_grid( + self, + prefix: IndentedBuffer, + inductor_meta: dict[str, Any], + params: dict[str, Any], + ): + from ..runtime.triton_heuristics import GridExpr + + grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp") + for line in grid.prefix: + prefix.writeline(line) + prefix.splice( + f"""\ + uint32_t grid_0 = {grid.x_grid}; + uint32_t grid_1 = {grid.y_grid}; + uint32_t grid_2 = {grid.z_grid}; + """ + ) + prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;") + + def generate_load_kernel(self, prefix, kernel_var_name, params): + prefix.writeline(f"if ({kernel_var_name} == nullptr) {{") + with prefix.indent(): + embed_kernel_args = [f"__{params['inductor_meta']['kernel_name']}_start"] + if torch.xpu.is_available(): + # XPU needs the end address of the kernel to calculate the size of the kernel binary. + embed_kernel_args.append( + f"__{params['inductor_meta']['kernel_name']}_end" + ) + + load_kernel_args = ( + [ + *embed_kernel_args, + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + ] + if V.graph.aot_mode and config.aot_inductor.embed_kernel_binary + else [ + cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]), + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + "cubin_dir_", + ] + ) + prefix.writeline( + f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); " + ) + prefix.writeline("}") + + def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): + """ + Generate the GPU kernel launching code. + This is where all the call args being sorted out and generated. + If enable_kernel_profile is enabled, all args related information would be packed in this function. + """ + triton_meta = params["triton_meta"] + assert len(self.arg_types) == len(params["def_args"]), ( + self.arg_types, + params["def_args"], + ) + arg_type_loookup = dict(zip(params["def_args"], self.arg_types)) + # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants + call_args = [ + name for name in params["call_args"] if name not in triton_meta["constants"] + ] + arg_types = [arg_type_loookup[name] for name in call_args] + arg_signatures = [triton_meta["signature"][name] for name in call_args] + scratch_spaces = { + name: params[name] + for name in ["global_scratch", "profile_scratch"] + if params.get(name, None) is not None + } + call_args_str = wrapper.generate_args_decl( + prefix, + call_args, + arg_types, + arg_signatures, + scratch_spaces=scratch_spaces, + ) + prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") + launch_kernel_args = [ + kernel_var_name, + "grid_0", + "grid_1", + "grid_2", + str(params["num_warps"]), + str(params["shared_mem"]), + "kernel_args_", + "stream_", + ] + if wrapper.device == "xpu": + launch_kernel_args.append(str(params["threads_per_warp"])) + + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + normalized_kernel_name = re.sub(r"[^a-zA-Z0-9_]", "_", f"{kernel_var_name}") + prefix.writeline("{") + with prefix.indent(): + prefix.writelines( + [ + f"std::unordered_map kwargs_{normalized_kernel_name};", + "", + ] + ) + # Add launch args info + record_launch_kernel_args = [ + ("grid_0", "grid_0"), + ("grid_1", "grid_1"), + ("grid_2", "grid_2"), + ("num_warps", str(params["num_warps"])), + ("shared_mem", str(params["shared_mem"])), + ] + for k, v in record_launch_kernel_args: + arg_name = f"{normalized_kernel_name}_{k}" + prefix.writelines( + [ + f"// Create c10::IValue for {k}", + f"C10IValueHandle tmp_{arg_name};", + f"aoti_torch_int64_to_ivalue({v}, &tmp_{arg_name});", + f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});", + f'kwargs_{normalized_kernel_name}.emplace("{k}", RAII_{arg_name});', + ] + ) + + # Add input info (This copies the logic from args_decl) + signature2dtype = { + "i32": "int32_t", + "i64": "int64_t", + "fp32": "float", + } + + def signature_is_tma_desc(sig): + if not sig: + return False + if sig == "nvTmaDesc": + return True + if sig.startswith("tensordesc<"): + return True + return False + + curr_arg_id = -1 + total_args = [] + ordered_argsname = [] + + def write_dummy_scalar_ivalue(arg_name): + # We only care about the shape, therefore we create a dummy scalar here. + prefix.writelines( + [ + f"// Create c10::IValue for arg_{curr_arg_id}", + f"C10IValueHandle tmp_{arg_name};", + f"aoti_torch_int64_to_ivalue(0, &tmp_{arg_name});", + f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});", + ] + ) + # pyrefly: ignore [bad-argument-type] + total_args.append(f"tmp_{arg_name}") + + def process_args_for_input_shape(arg, arg_type, arg_signature=None): + nonlocal curr_arg_id + curr_arg_id += 1 + arg_name = f"{normalized_kernel_name}_arg_{curr_arg_id}" + # ignore tma descriptors, as host-side TMA descriptors need + # to be passed to the compiled Triton kernel by value + if isinstance( + arg_type, UnwrapUnspecArg + ) and not signature_is_tma_desc(arg_signature): + write_dummy_scalar_ivalue(arg_name) + elif isinstance( + arg_type, torch_dtype + ) and not signature_is_tma_desc(arg_signature): + # This is an at::Tensor. + prefix.writelines( + [ + f"// Create c10::IValue for arg_{curr_arg_id}", + f"C10IValueHandle tmp_{arg_name};", + f"aoti_torch_tensor_to_ivalue({arg}, &tmp_{arg_name});", + f"RAIIC10IValueHandle RAII_{arg_name}(tmp_{arg_name});", + ] + ) + # pyrefly: ignore [bad-argument-type] + total_args.append(f"tmp_{arg_name}") + elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype + ) or arg_type in (sympy.Integer, int, sympy.Float, float): + write_dummy_scalar_ivalue(arg_name) + elif arg_signature and arg_signature.startswith("tensordesc<"): + # Skip tma related args + pass + else: + write_dummy_scalar_ivalue(arg_name) + + # Add input name and shape information + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + # pyrefly: ignore [bad-argument-type] + ordered_argsname.append(f'"{arg}"') + process_args_for_input_shape(arg, arg_type, arg_signature) + + # Add input name into kwargs + name_var = f"{normalized_kernel_name}_input_names" + prefix.writelines( + [ + "// Create c10::IValue for input names", + f"C10IValueHandle tmp_{name_var};", + f"std::vector {name_var}({{{', '.join(ordered_argsname)}}});", + f"aoti_torch_strlist_to_ivalue({name_var}.data(), {len(ordered_argsname)}, &tmp_{name_var});", + f"RAIIC10IValueHandle RAII_{name_var}(tmp_{name_var});", + f'kwargs_{normalized_kernel_name}.emplace("Input Args", RAII_{name_var});', + ] + ) + + inputs_info_ = f"{normalized_kernel_name}_inputs_info_" + # We pass in the non-RAII handles, since C10 doesn't automatically free them. + # The RAII will make sure they get freed when they are out of scope. + tmp_args = ",".join(total_args) + prefix.writelines( + [ + "// Aggregate all c10::IValue for inputs", + f"std::vector {inputs_info_}({{{tmp_args}}});", + ] + ) + + # Start recording Function + prefix.writelines( + [ + "", + ( + "torch::aot_inductor::RAIIAtenRecordFunctionHandle " + f"record_{normalized_kernel_name}_" + f'("{kernel_var_name}", ' + f"reinterpret_cast(&kwargs_{normalized_kernel_name}), " + f"{inputs_info_});" + ), + "", + f"launchKernel({', '.join(launch_kernel_args)});", + ] + ) + prefix.writeline("}") + else: + prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});") + + +class CppWrapperGpu(CppWrapperCpu): + """ + Generates cpp wrapper for running on GPU and calls CUDA kernels + """ + + def __init__(self) -> None: + self.device = get_gpu_type() + self.device_codegen = get_device_op_overrides(self.device) + super().__init__() + self.grid_id = count() + self._kernel_name_to_body: dict[str, str] = {} + self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {} + self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT" + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperGpu() + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. + return + + super().write_header() + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) + + @cache_on_self + def write_tma_descriptor_helpers_once(self): + self.header.splice(self.device_codegen.tma_descriptor_helpers()) + + def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: + name = f"stream{device_idx}" + self.writeline( + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));" + ) + return name + + def get_autotuning_input_name(self, idx): + return f"{self.autotune_input_prefix}_{idx}" + + def codegen_inputs(self): + # See Note: [Input Alignment handling in Inductor] + # + # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to + # copy misaligned inputs to aligned buffers. For AOTInductor, we need to do the same in cpp. + + if config.is_fbcode(): + # TODO: This is added because FC. Remove this once the newly added shim symbols, + # e.g. aoti_torch_clone_preserve_strides, have landed + return super().codegen_inputs() + + if V.graph.aot_mode and V.graph.inputs_to_check: + for idx in V.graph.inputs_to_check: + input_name = V.graph.graph_input_names[idx] + assert input_name in V.graph.graph_inputs, ( + f"{input_name} not found in graph inputs" + ) + value = V.graph.graph_inputs[input_name] + assert isinstance(value, TensorBox), ( + f"{input_name} is expected to be tensor but found as {type(value)}" + ) + warn_msg = ( + f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, " + "but it is not aligned at run time. Copying to an aligned tensor " + "to guarantee correctness, but expect a performance hit." + ) + self.prefix.splice( + f""" + if ((reinterpret_cast({input_name}.data_ptr()) & ({GPU_ALIGN_BYTES} -1)) != 0) {{ + AOTI_TORCH_WARN("{warn_msg}"); + AtenTensorHandle {input_name}_aligned; + aoti_torch_clone_preserve_strides({input_name}, &{input_name}_aligned); + {input_name} = std::move(RAIIAtenTensorHandle({input_name}_aligned)); + }} + """ + ) + + super().codegen_inputs() + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + if gpu: + self._kernel_name_to_body[kernel_name] = kernel_body + if config.triton.autotune_at_compile_time: + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen._define_kernel_helper( + self, kernel_name, kernel_body, metadata, gpu, cpp_definition + ) + else: + return CppWrapperCpu._define_kernel_helper( + self, kernel_name, kernel_body, metadata, gpu, cpp_definition + ) + + def generate(self, is_inference): + with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True): + return super().generate(is_inference) + + def finalize_prefix(self): + """Define the triton kernels now that autotuning is finished""" + old_prefix = self.prefix # new content should go at start of prefix + + # Generating triton kernel callers can modify the prefix (cached dtypes), + # so do this before running finalize_prefix(), but put the generated code + # after the finalize_prefix() code. + self.prefix = IndentedBuffer() + for kernel in self._triton_call_wrappers.values(): + self.prefix.writeline("\n") + kernel.generate(self) + triton_prefix = self.prefix + + self.prefix = IndentedBuffer() + super().finalize_prefix() + + self.prefix.splice(triton_prefix) + + self.prefix.writeline("\n") + self.prefix.splice(old_prefix) + + def generate_tma_descriptor(self, desc): + self.write_tma_descriptor_helpers_once() + + if isinstance(desc, TMADescriptorExperimental): + self._generate_experimental_tma_descriptor(desc) + else: + assert isinstance(desc, TMADescriptorStable) + self._generate_stable_tma_descriptor(desc) + + def _generate_experimental_tma_descriptor(self, desc): + # generate data pointer for the source tensor + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + self.writeline(f"alignas(64) CUtensorMap {desc_name};") + + # `source` is in the form of `&var_x`, where `var_x` is the data pointer + # (CUdeviceptr); we dereference `source` and cast to `void*` to pass to + # the data pointer of the source tensor to the helper function + # `init{1,2}DTMADescriptor` + ptr = f"reinterpret_cast(*({source}))" + dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims) + block_dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.block_dims) + element_size = self.val_to_arg_str(desc.element_size) + fn = f"init{desc.rank}DTMADescriptor" + args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}" + self.writeline(f"{fn}({args});") + + def _generate_stable_tma_descriptor(self, desc): + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + # Pack the relevant information into a StableTMADescriptor struct. + # See [Note: AOTI TMA Stable handling] for more details. + self.writeline(f"alignas(64) StableTMADescriptor {desc_name};") + + def fill_array(name, values): + for i, val in enumerate(values): + self.writeline(f"{name}[{i}] = {val};") + + ptr = f"reinterpret_cast(*({source}))" + rank = len(desc.tensor.get_size()) + + fill_array(f"{desc_name}.block_shape", desc.block_shape) + fill_array(f"{desc_name}.global_shape", desc.tensor.get_size()) + fill_array(f"{desc_name}.strides", desc.tensor.get_stride()) + + element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize) + fn = "initTMADescriptor" + args = ", ".join( + str(x) + for x in [ + f"&{desc_name}.m", + ptr, + element_size, + rank, + f"{desc_name}.block_shape", + f"{desc_name}.global_shape", + f"{desc_name}.strides", + ] + ) + self.writeline(f"{fn}({args});") + + def generate_args_decl( + self, + code: Union[IndentedBuffer, Self], + call_args, + arg_types, + arg_signatures, + is_triton_kernel=True, + scratch_spaces: Optional[dict[str, int]] = None, + ): + """ + Generates any declarations of args to pass into a kernel call, and then returns the arg names. + + In more detail: + * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;` + * returns: a string with the list of args, e.g. "var_0, var_1" + + call_args: list of call arguments + arg_types: list of argument types + arg_signatures: list with signatures of all the args + is_triton_kernel: whether these are passed into a triton kernel or not. In particular, + calls to triton kernels will have an additional global scratch space + arg injected at the front of the arg list. + """ + new_args: list[str] = [] + + # Add more cases for other types as needed + signature2dtype = { + "i32": "int32_t", + "i64": "int64_t", + "fp32": "float", + } + + def signature_is_tma_desc(sig): + if not sig: + return False + if sig == "nvTmaDesc": + return True + if sig.startswith("tensordesc<"): + return True + return False + + def process_tma_stable_arg(arg, arg_type, arg_signature, var_name): + # [Note: AOTI TMA Stable handling] + # For most args, a single arg passed to the python triton interface + # maps to a single arg in the cubin interface. However, for host-side + # TMA descriptors, a single python arg turns into 1 + 2 * N args in the + # cubin interface (where N is the rank). + # + # To do this: at TMA codegen time (for aoti), we generate a struct + # (StableTMADescriptor) containing the necessary information; and then + # when we call the function (i.e. here), we unpack the struct members. + code.writeline(f"auto {var_name} = {cexpr(arg)};") + + result = [] + result.append(f"&{var_name}.m") + + # from https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111 # noqa: B950 + match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature) + assert match is not None + shape = match.group(2) + ndim = shape.count(",") + 1 + + for i in range(ndim): + result.append(f"&{var_name}.block_shape[{i}]") + + for i in range(ndim): + result.append(f"&{var_name}.strides[{i}]") + + return result + + def process_args(arg, arg_type, arg_signature=None): + var_name = f"var_{next(self.arg_var_id)}" + # ignore tma descriptors, as host-side TMA descriptors need + # to be passed to the compiled Triton kernel by value + if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc( + arg_signature + ): + self.codegen_tensor_item( + arg_type.dtype, + arg, + var_name, + indented_buffer=code, + ) + new_args.append(f"&{var_name}") + elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc( + arg_signature + ): + device_ptr_type = self.device_codegen.cpp_device_ptr() + code.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" + ) + ) + new_args.append(f"&{var_name}") + # For symbolic call arguments, examine the arg signatures from triton meta + # to explicitly cast to the right type + # Reason: `auto` can infer unexpected type against kernel input signature. + elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype + ): + code.writeline( + f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" + ) + new_args.append(f"&{var_name}") + elif arg_type in (sympy.Integer, int): + code.writeline(f"int {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + elif arg_type in (sympy.Float, float): + code.writeline(f"float {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + elif arg_signature and arg_signature.startswith("tensordesc<"): + new_args.extend( + process_tma_stable_arg(arg, arg_type, arg_signature, var_name) + ) + else: + code.writeline(f"auto {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + process_args(arg, arg_type, arg_signature) + + for scratch_name, workspace_size in (scratch_spaces or {}).items(): + if ( + is_triton_kernel + and ( + scratch := self.device_codegen.cpp_scratch( + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=( + lambda: self.codegen_dtype(torch.uint8) + ), + ), + prefix=scratch_name, + ) + ) + is not None + ): + scratch_def, scratch_var = scratch + code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def]) + new_args.append(f"&{scratch_var}") + + return ", ".join(new_args) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Override the default value of argument 'gpu' to True here. + generate_kernel_call can still be called with gpu=False because of + a mix of cpu kernels and gpu kernels. + """ + device = device or V.graph.get_current_device_or_throw() + if device.type == "cpu": + # Even in CppWrapperGpu, we may see cpp kernels + return CppWrapperCpu._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + ) + + if ( + triton + and config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + original_fxnode_name=original_fxnode_name, + ) + + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device.index, graph_name) + ) + + if triton: + call_args, arg_types = self.prepare_triton_wrapper_args( + call_args, + # pyrefly: ignore [bad-argument-type] + arg_types, + ) + wrapper_name = f"call_{kernel_name}" + if wrapper_name not in self._triton_call_wrappers: + self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper( + wrapper_name, + kernel_name, + self._kernel_name_to_body, + arg_types, + ) + device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index) + call_args.append(device_idx) + call_args.append(stream) + if V.graph.aot_mode: + call_args.append("kernels") + call_args.append("this->cubin_dir_") + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args[: len(arg_types)], kernel_name, arg_types, None + ) + with debug_printer_manager: + self.writeline(f"{wrapper_name}({', '.join(call_args)});") + else: + casted = [] + # pyrefly: ignore [no-matching-overload] + for arg_type, arg in zip(arg_types, call_args): + new_arg = arg + if arg_type.endswith("*") and arg != "nullptr": + new_arg = f"{arg}.data_ptr()" + # pyrefly: ignore [bad-argument-type] + casted.append(f"({arg_type}){cexpr(new_arg)}") + call_args_str = ", ".join(casted) + self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") + + @staticmethod + def prepare_triton_wrapper_args( + call_args: list[Any], arg_types: list[Any] + ) -> tuple[list[Any], list[Any]]: + assert len(call_args) == len(arg_types), (call_args, arg_types) + new_args = [] + new_args_types = [] + for arg, arg_type in zip(call_args, arg_types): + if isinstance(arg, str): + if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + arg_type = UnwrapUnspecArg(dtype=arg_type) + new_args.append(arg) + elif isinstance(arg, bool): + new_args.append(str(arg).lower()) + elif isinstance(arg, (int, float, SymbolicCallArg)): + new_args.append(str(arg)) + else: + new_args.append(cexpr(V.graph.sizevars.simplify(arg))) + new_args_types.append(arg_type) + return new_args, new_args_types + + def make_zero_buffer(self, name): + return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" + + +@dataclasses.dataclass +class UnwrapUnspecArg: + """Marker that we need to call .item() on the tensor""" + + dtype: torch_dtype diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5638f37b7856927f612a37c867c46c3e76785e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -0,0 +1,301 @@ +from typing import Any, Optional + +import sympy + +import torch +from torch.utils._ordered_set import OrderedSet + +from ..ir import GraphPartitionSignature +from ..virtualized import V +from .cpp_wrapper_cpu import CppWrapperCpu +from .cpp_wrapper_gpu import CppWrapperGpu +from .wrapper import KernelCallLine, PythonWrapperCodegen + + +class CppWrapperMps(CppWrapperGpu): + """ + Generates cpp wrapper for running on MPS and calls metal kernels + """ + + def __init__(self) -> None: + super().__init__() + self._used_kernel_names: OrderedSet[str] = OrderedSet() + self._lambda_counter: int = 0 + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> "CppWrapperMps": + return CppWrapperMps() + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args: list[str], + *, + device: Optional[torch.device] = None, + triton: bool = True, + arg_types: Optional[tuple[Any, ...]] = None, + raw_keys: Optional[tuple[Any, ...]] = None, + raw_args: Optional[tuple[Any, ...]] = None, + triton_meta: Optional[dict[str, Any]] = None, + graph_name: str = "", + original_fxnode_name: Optional[str] = None, + ) -> None: + """ + Generates MPS kernel call code. It should look something like: + ``` + auto mps_lib_0_lambda = [&](AOTIMetalKernelFunctionHandle handle) { + aoti_torch_mps_start_encoding(handle); + aoti_torch_mps_set_arg_tensor(handle, 0, buf0); + aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1); + aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1); + aoti_torch_mps_dispatch_single(handle, static_cast(10LL)); + }; + + std::function mps_lib_0_func_wrapper = mps_lib_0_lambda; + aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper); + ``` + """ + device = device or V.graph.get_current_device_or_throw() + if device.type == "cpu": + # Even in CppWrapperGpu, we may see cpp kernels + return CppWrapperCpu._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + ) + + assert device.type == "mps" + + assert arg_types is not None + + new_args = [] + for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): + if isinstance(arg_type, torch.dtype): + new_args.append(f"aoti_torch_mps_set_arg_tensor(handle, {idx}, {arg});") + elif arg_type in (int, sympy.core.symbol.Symbol): + new_args.append(f"aoti_torch_mps_set_arg_int(handle, {idx}, {arg});") + else: + raise NotImplementedError( + f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}" + ) + + threads, group_size = call_args[-2], call_args[-1] + if threads is None: + raise NotImplementedError("No threads or group_size provided") + + # Check if threads is a single value or an array-like structure + threads_str = str(threads) + is_single_value = ( + threads_str.startswith("{") + and threads_str.endswith("}") + and threads_str.count(",") == 0 + ) or not threads_str.startswith(("{", "[")) + + if is_single_value: + # Extract single value from braces if present + if threads_str.startswith("{") and threads_str.endswith("}"): + single_value = threads_str[1:-1].strip() # Remove braces + else: + single_value = threads_str + + if group_size is None: + new_args.append( + f"aoti_torch_mps_dispatch_single(handle, {single_value});" + ) + else: + # Extract group size value if it's also in braces + group_size_str = str(group_size) + if group_size_str.startswith("{") and group_size_str.endswith("}"): + group_size_value = group_size_str[1:-1].strip() + else: + group_size_value = group_size_str + new_args.append( + f"aoti_torch_mps_dispatch_single_with_group_size(handle, {single_value}, {group_size_value});" + ) + else: + # Handle array case - need to convert initializer list to array + # Use kernel name to make variable names unique + threads_var = f"{kernel_name}_threads_array" + group_size_var = f"{kernel_name}_group_size_array" + + # Extract array size from the initializer list string + def get_array_size(array_str: str) -> int: + # Remove braces and whitespace + content = array_str.strip() + if content.startswith("{") and content.endswith("}"): + content = content[1:-1].strip() + + if not content: # Empty array + return 0 + + # Count elements by counting commas, accounting for nested structures + depth = 0 + comma_count = 0 + for char in content: + if char in "({[<": + depth += 1 + elif char in ")}]>": + depth -= 1 + elif char == "," and depth == 0: + comma_count += 1 + + return comma_count + 1 # Number of elements = commas + 1 + + threads_size = get_array_size(threads_str) + + if group_size is None: + new_args.append("{") + new_args.append(f" uint64_t {threads_var}[] = {threads};") + new_args.append( + f" aoti_torch_mps_dispatch_array(handle, {threads_var}, {threads_size});" + ) + new_args.append("}") + else: + group_size_str = str(group_size) + group_size_size = get_array_size(group_size_str) + new_args.append("{") + new_args.append(f" uint64_t {threads_var}[] = {threads};") + new_args.append(f" uint64_t {group_size_var}[] = {group_size};") + dispatch_args = f"handle, {threads_var}, {threads_size}, {group_size_var}, {group_size_size}" + new_args.append( + f" aoti_torch_mps_dispatch_array_with_group_size({dispatch_args});" + ) + new_args.append("}") + + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args[:-2], + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.write_mps_kernel_call(kernel_name, new_args) + + def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: + # Generate unique variable names to avoid duplicate declarations + # when the same MPS lib is used multiple times + unique_suffix = self._lambda_counter + self._lambda_counter += 1 + + lambda_name = f"{name}_lambda_{unique_suffix}" + wrapper_name = f"{name}_func_wrapper_{unique_suffix}" + + # Generate the function call code (in current location) + # Create lambda that captures by reference and pass its pointer through void* + self.writeline( + f"auto {lambda_name} = [&](AOTIMetalKernelFunctionHandle handle) {{" + ) + self.writeline(" aoti_torch_mps_start_encoding(handle);") + + # Output call args directly since we're capturing by reference + for call_arg in call_args: + self.writeline(f" {call_arg}") + self.writeline("};") + self.writeline("") + + # Pass lambda pointer through void* + self.writeline( + f"std::function {wrapper_name} = {lambda_name};" + ) + self.writeline( + f"aoti_torch_mps_run_command_block(get_{name}_handle(), aoti_torch_mps_shared_callback, &{wrapper_name});" + ) + + @staticmethod + def get_device_include_path(device: str) -> str: + assert V.graph.aot_mode + return ( + "#include \n" + "#include " + ) + + def codegen_additional_funcs(self) -> None: + """ + Generate thread-safe lazy singleton pattern for MPS shader libraries with RAII cleanup. + + The generated code will look like: + ``` + AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { + static auto kernel_handle = []() { + AOTIMetalShaderLibraryHandle lib_handle = nullptr; + AOTIMetalKernelFunctionHandle kern_handle = nullptr; + + aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle); + aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle); + + // RAII wrapper with custom deleter + auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) { + if (h) aoti_torch_mps_delete_shader_library(h); + }; + + using LibDeleter = decltype(lib_deleter); + using LibPtr = std::unique_ptr; + + // Return pair of kernel handle and library smart pointer for cleanup + return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter)); + }(); + return kernel_handle.first; + } + ``` + """ + + # Add shimified handles and functions + shader_libraries: OrderedSet[str] = OrderedSet() + for line in self.lines: + if not isinstance(line, KernelCallLine): + continue + if line.device.type != "mps": + continue + + # Extract library name from kernel name (e.g., "mps_lib_0" from kernel calls) + if line.kernel_name not in self._used_kernel_names: + self._used_kernel_names.add(line.kernel_name) + shader_libraries.add(line.kernel_name) + + # NOTE: For shimified version, we expect the shader source constant to be generated + # by the existing MPS shader generation process, but instead of instantiating the + # DynamicMetalShaderLibrary directly, we'll use our shim functions. + # The existing codegen should produce something like: + # const char* mps_lib_0_source = R"MTL(...shader_source...)MTL"; + # instead of: + # at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...shader_source...)MTL"); + + # Generate thread-safe lazy singleton with RAII for each library + for lib_name in shader_libraries: + self.prefix.splice(f""" +AOTIMetalKernelFunctionHandle get_{lib_name}_handle() {{ + static auto kernel_handle = []() {{ + AOTIMetalShaderLibraryHandle lib_handle = nullptr; + AOTIMetalKernelFunctionHandle kern_handle = nullptr; + + aoti_torch_mps_create_shader_library({lib_name}_source, &lib_handle); + aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle); + + // RAII wrapper with custom deleter + auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{ + if (h) aoti_torch_mps_delete_shader_library(h); + }}; + + using LibDeleter = decltype(lib_deleter); + using LibPtr = std::unique_ptr; + + // Return pair of kernel handle and library smart pointer for cleanup + return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter)); + }}(); + return kernel_handle.first; +}} +""") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..ccada837abbd4dbdaf16984c0a44ff7f90cedc04 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from textwrap import dedent + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class CpuDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def cpp_kernel_type(self) -> str: + return "void*" + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + + +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67536bc859127ccadce9292bc883c0b1e18687ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6292e7972d1882b1175897fd556eacc08b327b8b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d6f0fc6bab3964bff149773be0a377957d86698 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80cacc1c641fd3478ada04033dfc0180adc483ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af69372cd3543abc07fd74bbf736a08cd54cae4f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dea14a07132df67ad5a07edb60422ff9a6ee2ab8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb4ce6dcc1ac221f8654a8446084751e6ccce643 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..2496860ca1f7c72eadd86f908384e2f81983af4f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,296 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + MockCutlassHandler, +) +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ...._dynamo.utils import counters +from ... import config +from ...codecache import code_hash, get_path +from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, + WhyNoFuse, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class WhyNoFuseNames(WhyNoFuse): + def __init__(self, name1: str, name2: str) -> None: + self.name1 = name1 + self.name2 = name2 + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode): + assert node1.node, "node1.node should not be None" + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), + [], + node2, # type: ignore[arg-type] + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, BaseSchedulerNode + ): + assert node1.node, "node1.node should not be None" + assert node2.node, "node2.node should not be None" + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node(), # type: ignore[arg-type] + self._unwrap_epilogue_nodes(fnode1), + node2, # type: ignore[arg-type] + ) + + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + # use the original src_code as the key + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + # no EVT kernel, use the original kernel name + kernel_name = f"cutlass_{kernel_hash}" + else: + kernel_name = f"cutlass_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template(template_node), ( + "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc] + assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) + kernel, render = ctb.make_kernel_render( # type: ignore[misc] + ctb, epilogue_nodes=epilogue_nodes + ) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + # typically there is a codegen pass which runs after mark_run + # for this kernel we've already generated the C++ code, but we still + # need to let the kernel know about loads/stores that occur in the fused + # kernel for memory planning to properly optimize allocations + ctb.emulate_store_fn() + for node in epilogue_ir_nodes: + with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())): + assert isinstance( + node, ComputedBuffer + ) # Not sure why we need to do this again + node.get_store_function()(CutlassEVTCodegen.get_index_vars(node)) + + with V.set_kernel_handler(kernel): + src_code = render() + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with debug_printer_manager: + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() + + @staticmethod + def _unwrap_epilogue_nodes( + fused_node: FusedSchedulerNode, + ) -> list[BaseSchedulerNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + assert all(n.node is not None for n in nodes), ( + "All epilogue nodes should have an IRNode" + ) + # pyrefly: ignore [redundant-cast] + return cast( + list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node] + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + existing_epilogue_nodes: list[BaseSchedulerNode], + node_to_fuse: BaseSchedulerNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. + + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes. + node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue. + Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. + + """ + why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name()) + + scheduler_nodes_to_fuse = node_to_fuse.get_nodes() + + assert isinstance(cuda_template_buffer, CUDATemplateBuffer) + + # Checks on constituent nodes + for s_node in scheduler_nodes_to_fuse: + node = s_node.node + + if not isinstance(node, ComputedBuffer): + why(f"{node} is not a ComputedBuffer") + return False + elif not isinstance(node.data, Pointwise): + why(f"{node} is not a Pointwise op") + return False + elif not node.get_computed_buffer_name(): # type: ignore[attr-defined] + why(f"{node} does not have a computed buffer name") + return False + + name = node.get_computed_buffer_name() # type: ignore[attr-defined] + # dtype can differ, and strides can differ as long as they are broadcastable + if node.get_size() != cuda_template_buffer.get_size(): + why( + f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \ +size: {cuda_template_buffer.get_size()}" + ) + return False + + assert len( + existing_epilogue_nodes + ) or cuda_template_buffer.get_name() in OrderedSet( + [rd.name for rd in node_to_fuse.read_writes.reads] + ), "First epilogue node must read from cuda template buffer" + + if node_to_fuse.has_aliasing_or_mutation(): + why(f"{node_to_fuse.get_name()} has aliasing or mutation") + return False + elif node_to_fuse.is_reduction(): + why( + f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT" + ) + return False + elif ( + not config.cuda.cutlass_epilogue_fusion_enabled + or not config.epilogue_fusion + ): + why("cutlass epilogue fusion is not enabled") + return False + elif not cuda_template_buffer.supports_epilogue_fusion: + why("epilogue fusion is only supported for TMA-enabled gemm ops") + return False + + try: + from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + ) + + CutlassEVTCodegen.ir_to_evt_python_code( + cuda_template_buffer.get_name(), + existing_epilogue_nodes + list(node_to_fuse.get_nodes()), + OrderedSet(), + ) + + except NotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \ +likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: # Likely due to unsupported dtype. + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \ +Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + + return True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca3afbd9ca57a7aed17ffe69d074c667dd2c09f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,55 @@ +import functools +import logging +import shutil +from typing import Optional + +import torch +from torch._inductor.utils import clear_on_fresh_cache + +from ... import config + + +log = logging.getLogger(__name__) + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception: + log.exception("Error getting cuda arch") + return None + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def is_datacenter_blackwell_arch() -> bool: + arch = get_cuda_arch() + if arch is None: + return False + arch_number = int(arch) + return arch_number >= 100 and arch_number < 110 + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception: + log.exception("Error getting cuda version") + return None + + +@functools.cache +def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool: + return nvcc_path is not None and shutil.which(nvcc_path) is not None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..97643ef00a7bd63aa6887c5ee6645f1c788e45fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,687 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal, Optional, TYPE_CHECKING, Union + +from sympy import Expr, symbols + +import torch._inductor.config as config +from torch import dtype as torch_dtype +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder +from torch.utils._sympy.value_ranges import ValueRanges + +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from .cuda_template import ArgInfo + +from ...autotune_process import CUDABenchmarkRequest +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + ShapeAsConstantBuffer, + TensorBox, +) +from ...utils import sympy_product +from ...virtualized import V +from ..common import ( + CSEVariable, + IndentedBuffer, + Kernel, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] +ValidLayoutAttrs = Literal["size", "stride"] + + +@dataclass(frozen=True) +class LayoutArg: + node: IRNode + symbol: ValidLayoutSymbols + attr: ValidLayoutAttrs + dim: int + + def matches(self, node, attr, dim) -> bool: + return self.node == node and self.attr == attr and self.dim == dim + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list) + self.size_args: list[Union[Expr, int]] = [] + # Mapping from arg name to IRNode. + self.named_nodes: dict[str, IRNode] = {} + + def find_symbol( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[str]: + arg = self.find_layout_arg(node, attr, dim) + return arg.symbol if arg else None + + def find_layout_arg( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[LayoutArg]: + matches = [ + arg + for arg in itertools.chain.from_iterable(self.layout_args.values()) + if arg.matches(node, attr, dim) + ] + if len(matches) >= 1: + # Verify all matches have the same node, attribute, and dimension + # And if they come from the same node, whichever symbol we use is fine. + # if in runtime the logic changes, this would trigger guard + first_match = matches[0] + if not all( + match.node == first_match.node + and match.attr == first_match.attr + and match.dim == first_match.dim + for match in matches + ): + raise AssertionError("All matching layout args should be identical") + return first_match + return None + + def add_layout_arg( + self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int + ): + arg = LayoutArg(node, symbol, attr, dim) + self.layout_args[symbol].append(arg) + + def init_layout_args(self) -> None: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + x_mdim = _normalize_idx(-2, len(X.get_size())) + x_kdim = _normalize_idx(-1, len(X.get_size())) + w_kdim = _normalize_idx(-2, len(W.get_size())) + w_ndim = _normalize_idx(-1, len(W.get_size())) + y_mdim = _normalize_idx(-2, len(Y.get_size())) + y_ndim = _normalize_idx(-1, len(Y.get_size())) + self.add_layout_arg("M", X, "size", x_mdim) + self.add_layout_arg("K", X, "size", x_kdim) + self.add_layout_arg("K", W, "size", w_kdim) + self.add_layout_arg("N", W, "size", w_ndim) + self.add_layout_arg("M", Y, "size", y_mdim) + self.add_layout_arg("N", Y, "size", y_ndim) + if len(X.get_size()) > 2: + self.add_layout_arg("B", X, "size", 0) + + lda_dim = self.find_ld_idx(X) + ldb_dim = self.find_ld_idx(W) + ldc_dim = self.find_ld_idx(Bias) if Bias else None + ldd_dim = self.find_ld_idx(Y) + self.add_layout_arg("lda", X, "stride", lda_dim) + self.add_layout_arg("ldb", W, "stride", ldb_dim) + if Bias is not None and ldc_dim is not None: + self.add_layout_arg("ldc", Bias, "stride", ldc_dim) + self.add_layout_arg("ldd", Y, "stride", ldd_dim) + + def get_layout_args(self) -> tuple[Union[Expr, int], ...]: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + + def get_ld(node) -> Union[Expr, int]: + dim = self.find_ld_idx(node) + return node.get_stride()[dim] + + M = X.get_size()[mdim] + N = W.get_size()[ndim] + K = X.get_size()[kdim] + B = X.get_size()[0] if len(X.get_size()) > 2 else 1 + LDA = get_ld(X) + LDB = get_ld(W) + LDC = get_ld(Bias) if Bias else 0 + LDD = get_ld(Y) + return (M, N, K, B, LDA, LDB, LDC, LDD) + + def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: + return [*self.get_layout_args(), *self.size_args] + + def get_offset_args(self) -> list[Expr]: + return [node.get_layout().offset for node in self.named_nodes.values()] + + @staticmethod + def find_ld_idx(node: IRNode) -> int: + strides = node.get_stride() + # Handle 1D tensor case + if V.graph.sizevars.statically_known_equals(strides[-1], 1): + return _normalize_idx(-2, len(strides)) + + assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2] + return _normalize_idx(-1, len(strides)) + + +class CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. + """ + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def get_signature(self) -> str: + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + additional_size_args: Additional size arguments for epilogue inputs + """ + # NB: name order matters here, it's used to match up offsets + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + free_symbols: OrderedSet[Expr] = OrderedSet() + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + # NB: named nodes must be populated in the order of names + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + if name not in ( + "X", + "W", + "Bias", + "Y", + ): # we handle these symbolic shapes explicitly + for expr in itertools.chain(node.get_size(), node.get_stride()): + if isinstance(expr, Expr): + for s in expr.free_symbols: + free_symbols.add(s) # type: ignore[arg-type] + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + + self.init_layout_args() + size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] + size_vars.extend(str(s) for s in free_symbols) + self.size_args.extend(free_symbols) + size_args = [f"const int {s}" for s in size_vars] + offset_args = [f"const int {name}_offset" for name in self.named_nodes] + runtime_arg_decls = ",".join( + [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + ) + if runtime_arg_decls: + runtime_arg_decls += ", " + + signature = ( + f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\ + {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + ) + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "CUDATemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # We always originally initialize name with "KERNEL_NAME". So, we + # we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + dynamic_shape_args = self.get_dynamic_shape_args() + offset_args = self.get_offset_args() + call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + call_args.extend(offset_args) # type: ignore[arg-type] + for arg in self.runtime_arg_values: + call_args.append(str(arg)) + arg_types.extend("const int" for _ in dynamic_shape_args) + arg_types.extend("const int" for _ in offset_args) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + elif isinstance(arg_types[i], torch_dtype): + call_args[i] = ( + call_args[i] + if V.graph.cpp_wrapper + else f"c_void_p({call_args[i]}.data_ptr())" + ) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + # workspace_size is here. + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + workspace = str(ws.outer_name) + call_args.append( + workspace + if V.graph.cpp_wrapper + else f"c_void_p({workspace}.data_ptr())" + ) + else: + ws = None + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + + wrapper.generate_kernel_call( + name, + call_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default_dtype + from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate + + return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + return f"{arg_name} + {arg_name}_offset" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + sizes = [ + self.find_symbol(node, "size", dim=i) or node.get_size()[i] + for i in range(start_index, end_index + 1) + ] + if len(sizes) == 0: + return str(default_value) + + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] + val = sympy_product(sizes) + return val + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + if V.graph.sizevars.statically_known_leq(stride, 1): + return str(stride) + return self.find_symbol(node, "stride", dim=index) or str(stride) + + def batch_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the batch stride of an arg. + Returns 0 if batch dim is not present. + + This method assumes that batch stride is the largest stride. + """ + + if node is None: + return str(default_value) + + if len(node.get_size()) < 3: + return str(default_value) + + batch_stride = node.get_stride()[0] + if V.graph.sizevars.statically_known_leq(batch_stride, 1): + return str(batch_stride) + + return "{}*{}".format( + self.find_symbol(node, "size", dim=1) or node.get_size()[1], + self.find_symbol(node, "size", dim=2) or node.get_size()[2], + ) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable: + """ + Mock load function for memory planning to optimize allocations properly. + """ + return self.create_cse_var(name, bounds=ValueRanges.unknown()) + + def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None: + """ + Mock store function for memory planning to optimize allocations properly. + """ + self.store_buffer_names.add(name) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]], + tuple[CUDATemplateKernel, functools.partial[str]], + ], + bmreq: CUDABenchmarkRequest, + supports_epilogue_fusion: bool, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + description: str, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.supports_epilogue_fusion = supports_epilogue_fusion + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def kernel_hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + swizzle_str: str = ( + str(self.info_kwargs.get("swizzle")) + if isinstance(self.info_kwargs, dict) + else "None" + ) + return "-".join( + [ + self.category, + self.bmreq.hash_key, + swizzle_str, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """ + Information returned here is logged to the autotune log file when that is enabled. + + In general, we should avoid calling this function as it is expensive to compute, + and can add up very fast. + """ + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + "swizzle": str(self.info_kwargs["swizzle"]), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + self.bmreq.update_workspace_size() + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + supports_epilogue_fusion=self.supports_epilogue_fusion, + template=self.template, + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..79dfa9c6c391fe10ce2c4a657aea83b1639f4f5d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,394 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import itertools +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import override +from unittest.mock import patch + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.utils import clear_on_fresh_cache, Placeholder +from torch._logging import getArtifactLogger + +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from ...scheduler import BaseSchedulerNode # noqa: TC004 +else: + BaseSchedulerNode = Any + +GemmOperation = Any + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +@clear_on_fresh_cache +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + # dict of cache key to (code, size_args) + code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {} + cache_clear = staticmethod(code_cache.clear) + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + Baseclass for CUDA C++ Templates, derived from KernelTemplate. + Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies + the order of the input nodes. + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + @classmethod + @functools.lru_cache(None) + # pyrefly: ignore [bad-override] + def _template_from_string(cls, source: str) -> Any: + return KernelTemplate._template_from_string(source) + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return False + + def make_key(self, name: str, input_key: str, layout_repr: str) -> str: + """ + Make a key for the code cache. The idea of the method is to cache + everything that matters but doesn't include runtime param values, i.e., + self.get_runtime_arg_values(). + + Args: + kwargs: Additional keyword arguments. Including op (GemmOperation). + """ + return hashlib.sha256( + str( + ( + input_key, + self.input_reorder, + # output layout, same as self.output_node.get_layout() + layout_repr, + self.get_runtime_arg_info(), + name, + ) + ).encode("utf-8") + ).hexdigest() + + def generate_code_and_args( + self, name: str, input_key: str, layout_repr: str, **kwargs + ) -> tuple[str, tuple[int, ...]]: + """ + Generate code and args with caching. We cache the code even if runtime + args are different. + """ + key: Optional[str] = None + if config.cuda.enable_caching_codegen: + key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) + + if key is not None and key in self.code_cache: + code, size_args, offset_args = self.code_cache[key] + extra_args = tuple( + list(size_args) + + list(offset_args) + + list(self.get_runtime_arg_values(**kwargs)) + ) + return code, extra_args + + kernel_name = str(Placeholder.KERNEL_NAME) + kernel = CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + autotuning_log.debug("Generated Code:\n%s", code) + autotuning_log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) + size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args()) + + if key is not None: + self.code_cache[key] = code, size_args, offset_args + + # extra args has runtime params, which shouldn't be cached + extra_args = tuple( + list(size_args) + list(offset_args) + self.get_runtime_arg_values(**kwargs) + ) + + return code, extra_args + + def generate( # type: ignore[override] + self, + name: str, + description: str, + input_key: str, + layout_repr: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. + This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel + in a standalone manner to enable Autotuning. + + Args: + description: op name followed by swizzle. + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. + """ + code, extra_args = self.generate_code_and_args( + name=name, + input_key=input_key, + layout_repr=layout_repr, + **kwargs, + ) + + # not caching since kernel name is needed below + kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] + kernel_name = f"cutlass_{kernel_hash}" + code = code.replace(self.name, kernel_name) + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + extra_args=extra_args, + source_code=code, + ) + + # kwargs has "op" argument in case of CUTLASSGemmTemplate + op = kwargs["op"] + if not op: + supports_epilogue_fusion = False + else: + # epilogue fusion is only supported for TMA kernels + supports_epilogue_fusion = self.supports_epilogue_fusion(op) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + ) -> tuple[CUDATemplateKernel, functools.partial[str]]: + assert supports_epilogue_fusion or not epilogue_nodes, ( + "epilogue fusion is not supported for this kernel" + ) + kernel = CUDATemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_name, + "cutlass_gemm", + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + supports_epilogue_fusion, + self, + kwargs, + description, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. + """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in ("1", "1L"): + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + torch.float8_e4m3fn: "cutlass::float_e4m3_t", + } + + _DTYPE_TO_CUTLASS_SPARSE_META = { + torch.int32: "uint32_t", + torch.int16: "uint16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" + + def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return ( + f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})" + ) + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("swizzle", "const uint8_t")] + + @override + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..66db98867b4131631540d262b4e7eb4c932cc02a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import inspect +import json +import logging +import os +import time +from typing import Any, Optional + +import torch._inductor.config as config +from torch._inductor.codecache import cutlass_key +from torch._inductor.codegen.cuda import cutlass_utils, serialization +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import clear_on_fresh_cache + + +log = logging.getLogger(__name__) + + +CONFIG_PREFIX: str = "configs" + + +def get_config_request_key( + arch: str, + cuda_version: str, + instantiation_level: str, +) -> str: + """ + Return a key for the full ops, based on cutlass key, arch, cuda version, instantiation level, and serialization.py file hash. + """ + + # Get hash of serialization.py and cutlass_utils.py files using their module file paths + def get_file_hash(file_module): + file_path = inspect.getfile(file_module) + with open(file_path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + serialization_hash = get_file_hash(serialization) + cutlass_utils_hash = get_file_hash(cutlass_utils) + + hash_target = "-".join( + [ + cutlass_key().hex(), + arch, + cuda_version, + instantiation_level, + serialization_hash, + cutlass_utils_hash, + ] + ) + return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8] + + +def _generate_config_filename(request_key: str) -> str: + """ + Generate a filename for the full ops. + """ + return f"{CONFIG_PREFIX}_{request_key}.json" + + +@clear_on_fresh_cache +@functools.cache +def maybe_fetch_ops() -> Optional[list[Any]]: + """ + Fetch ops from databases. + """ + if config.force_disable_caches: + return None + + # setup + arch: str = get_cuda_arch() + # get_cuda_version might return "12.4.0" or "12.4" + # but we want to use "12.4" + version: str = ".".join(get_cuda_version().split(".")[:2]) + instantiation_level: str = config.cuda.cutlass_instantiation_level + + # filename and filepath + request_key: str = get_config_request_key(arch, version, instantiation_level) + filename: str = _generate_config_filename(request_key) + filepath: str = os.path.join(cache_dir(), filename) + + # try fetch + serialized_ops: Optional[list[str]] = None + start_time = time.time() + if os.path.isfile(filepath): + # locally + try: + with open(filepath) as f: + serialized_ops = json.load(f) + + assert isinstance(serialized_ops, list), ( + f"Expected serialized ops is a list, got {type(serialized_ops)}" + ) + except Exception: + log.warning( + "Failed to load CUTLASS config %s from local cache", + filename, + exc_info=True, + ) + serialized_ops = None + elif config.is_fbcode(): + from torch._inductor.fb.cutlass_remote_cache import ( + maybe_fetch_cutlass_configs_from_remote, + ) + + # from remote + serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath) + + if serialized_ops is None: + return None + + # deserialize + serializer = get_cutlass_operation_serializer() + full_ops = [serializer.deserialize(x) for x in serialized_ops] # type: ignore[union-attr] + log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time) + return full_ops diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..472438fec90e302b362f315fe58bd0062d89d94d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -0,0 +1,276 @@ +from collections.abc import Callable +from typing import Any, Union + +from sympy import Expr + +from torch._inductor.ir import ( + ComputedBuffer, + InputBuffer, + is_contiguous_strides_for_shape, +) +from torch.utils._ordered_set import OrderedSet + +from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass + + +EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace +Buffer = Union[ComputedBuffer, InputBuffer] +CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_..TupleType +CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory..VisitorType +CutlassArgType = ( + Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p +) + + +if try_import_cutlass(): + import ast + import ctypes + import textwrap + from typing import Union + + from cutlass_cppgen.backend.c_types import ( # type: ignore[import-not-found] + EmptyByte, + ) + from cutlass_cppgen.backend.epilogue import ( # type: ignore[import-not-found] + dtype2ctype, + ) + from cutlass_cppgen.backend.evt import ( # type: ignore[import-not-found] + EpilogueFunctorVisitor, + ) + from cutlass_cppgen.backend.evt.backend.emitter_base import ( # type: ignore[import-not-found] + FusionCallbacks, + ) + from cutlass_cppgen.backend.evt.backend.sm90_emitter import ( # type: ignore[import-not-found] + CollectiveEpilogue, + ) + from cutlass_cppgen.backend.evt.frontend import ( # type: ignore[import-not-found] + PythonASTFrontend, + ) + from cutlass_cppgen.backend.evt.ir.tensor import ( # type: ignore[import-not-found] + Tensor as CutlassTensor, + ) + from cutlass_library import ( + DataType, + EpilogueScheduleType, + LayoutType, + TileDescription, + ) + + from torch._inductor.codegen.cuda import cuda_env + from torch._inductor.utils import IndentedBuffer + + _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated] + + class EVTArgRenames: + """Handles mapping buffer names to variable names in the cpp kernel signature and body""" + + def __init__(self) -> None: + self.buf_renames: dict[str, str] = {} + + def new_name(self, name: str) -> str: + if name in self.buf_renames: + return self.buf_renames[name] + else: + new_name = f"ptr_{len(self.buf_renames)}" + self.buf_renames[name] = new_name + return new_name + + def get(self, name: str) -> str: + return self.buf_renames.get(name, name) + + def create_example_tensors( + var_name_to_buffer_name: dict[str, str], + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> dict[str, CutlassTensor]: + def cutlass_tensor_from_buffer( + buffer: Buffer, + ) -> CutlassTensor: + shape = buffer.get_layout().size + stride = buffer.get_layout().stride + shape = tuple(size_hint_fn(x) for x in shape) + stride = tuple(size_hint_fn(x) for x in stride) + + is_row_major = is_contiguous_strides_for_shape(stride, shape) + is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1]) + + if not is_row_major and not is_column_major: + raise RuntimeError( + f"Cannot create example tensor for {buffer.get_name()} with \ +non-contiguous layout, received stride: {stride} and shape: {shape}" + ) + + return CutlassTensor( + shape=shape, + layout_tag=( + LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor + ), + element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype), + ) + + return { + key: cutlass_tensor_from_buffer(name_to_buffer[name]) + for key, name in var_name_to_buffer_name.items() + } + + def trace( + fn_src: str, + example_tensors: dict[str, CutlassTensor], + accum_type: DataType, + output_type: DataType, + tile_description: TileDescription, + epilogue_schedule: EpilogueScheduleType, + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + **kwargs: dict[str, Any], + ) -> tuple[str, str, str, EVTArgRenames]: + cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type] + assert cuda_arch >= 90, "Only SM90+ is supported for EVT" + epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs) + visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor) + fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False) + collective_epilogue = CollectiveEpilogue( + tile_description, + epilogue_schedule, + accum_type, + output_type, + fusion_callbacks, + ) + evt_name, evt_code = collective_epilogue.emit() + evt_args, arg_renames = _render_argument_type( + epilogue_functor, name_to_buffer, size_hint_fn + ) + return evt_name, evt_args, evt_code, arg_renames + + # Based off of + # https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117 + # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function + # The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval + def _trace( + fn_src: str, + example_tensors: dict[str, CutlassTensor], + cc: int, + **kwargs: Any, + ) -> EpilogueFunctor: + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc: int, **kwargs: Any): + self.source = textwrap.dedent(fn_src) + super().__init__(cc, **kwargs) + + def parse( + self, + example_inputs: dict[str, CutlassTensor], + ) -> None: + self.example_inputs = example_inputs + self.ast = ast.parse(self.source) + # pyrefly: ignore [missing-attribute] + self.visit(self.ast) + + cc = int(cuda_env.get_cuda_arch()) + epilogue_functor = EpilogueFunctor(cc=cc, **kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + + def _render_argument_type( + epilogue_functor: EpilogueFunctor, + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> tuple[str, EVTArgRenames]: + epilogue_thread_type = epilogue_functor.epilogue_thread_type + arg_renames = EVTArgRenames() + + # Fragile, but this is the only way to guarantee t is expected type because t is a local class + def is_nested_visitor_type(t: type) -> bool: + return ( + ".".join([t.__module__, t.__qualname__]) + == "cutlass_cppgen.backend.c_types.visitor_factory..VisitorType" + ) + + buffer = IndentedBuffer() + with buffer.set_tabwidth(2): + + def render_argument_type(name: str, t: CutlassArgType) -> None: + if issubclass(t, ctypes.c_byte): + buffer.writeline(f"{{}}, /* {name} */") + else: + fields = [ + ( + fname, + _get_arg_from_node( + ty, name_to_buffer[name], size_hint_fn, arg_renames + ), + ) + for fname, ty in t._fields_ + ] + field_strs = [ + f"/* {fname} */ {str(field)}" for fname, field in fields + ] + buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */") + + def render_thread_type(name: str, t: CutlassArgType) -> None: + if is_nested_visitor_type(t): + buffer.writeline(f"{{ /* {name} */") + with buffer.indent(): + for name, inner_t in t._fields_: + render_thread_type(name, inner_t) + buffer.writeline("},") + else: + render_argument_type(name, t) + + # unroll the recursion once to address special case formatting + # namely, no ending comma and no indentation for the outermost thread type + buffer.writeline("{ /* thread */") + with buffer.indent(3): + if is_nested_visitor_type(epilogue_thread_type): + with buffer.indent(): + for name, inner_t in epilogue_thread_type._fields_: + render_thread_type(name, inner_t) + else: + render_argument_type("thread", epilogue_thread_type) + buffer.writeline("}") + + return buffer.getvalue(), arg_renames + + def _get_arg_from_node( + arg_ty: type, + node: Buffer, + size_hint_fn: Callable[[Union[Expr, int]], int], + arg_renames: EVTArgRenames, + ) -> str: + from ..cuda_template import CUTLASSTemplate + + # Today, arguments are either a pointer to the + # node's memory, a stride tuple, the datatype + # Once again, need to check for local class type for stride tuple + if ( + str(arg_ty) + == ".TupleType'>" + ): + DEFAULT_STRIDE_LEN = 3 + assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN + stride = [size_hint_fn(x) for x in node.get_layout().stride] + for _ in range(DEFAULT_STRIDE_LEN - len(stride)): + stride.append(0) + + def render_stride(x: int) -> str: + # Handle EBO for 0 and 1 + if x == 0: + return "_0{}" + elif x == 1: + return "_1{}" + else: + return str(x) + + return f"{{{', '.join([render_stride(x) for x in stride])}}}" + + elif issubclass(arg_ty, ctypes.c_void_p): + name = arg_renames.new_name(node.get_name()) + return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) ({name} + {name}_offset)" + elif ( + arg_ty in _CUTLASS_C_DTYPES + ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently + return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)" + elif issubclass(arg_ty, EmptyByte): + return "{}" + + raise NotImplementedError(f"Unsupported arg type: {arg_ty}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..95af1a968a97ce4de5db33a2752056369ecff94c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -0,0 +1,411 @@ +# mypy: ignore-errors +from ..cutlass_utils import try_import_cutlass + + +# copied / modified from original at +# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658 + +if try_import_cutlass(): + import enum + + from cutlass_library.gemm_operation import * # noqa: F401, F403 + from cutlass_library.library import * # noqa: F401, F403 + + _LOGGER = logging.getLogger(__name__) + + class EmitGemmUniversal3xInstanceWithEVT: + """Responsible for emitting a CUTLASS 3.x template definition""" + + def __init__(self, operation_suffix="", evt_name=None): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + ] + self.builtin_epilogue_functor_template = """${epilogue_functor}< + ${element_d}, + ${element_epilogue}, + ${element_c}, + ${element_epilogue} + >""" + + self.evt_name = evt_name + self.gemm_template = """ +using ${operation_name}_epilogue = +typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class_epi}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${epi_tile_mn}, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule}, + ${epilogue_functor} +>::CollectiveOp; + +${mixed_dtype_prepare_code} + +using ${operation_name}_mainloop = +typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class_main}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${stages}, + ${kernel_schedule} +>::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + ${problem_shape}, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + +// Define named type +struct ${operation_name} : +public ${operation_name}_base { }; + + """ + + # + def instance_template(self): + return """ +${compile_guard_start} +{ + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); +} +${compile_guard_end} + """ + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + "epi_vs": str(operation.ScaleFactorVectorSize), + "element_d": str(DataTypeTag[operation.D.element]), + "element_sfd": str(DataTypeTag[operation.ScaleFactorD.element]), + "layout_sfd": LayoutTag[operation.ScaleFactorD.layout], + "epilogue_functor": EpilogueFunctor3xTag[ + EpilogueFunctor3x.LinearCombinationBlockScaleFactor + ], + "element_accumulator": str(DataTypeTag[operation.accumulator_type()]), + "element_scalar": str(DataTypeTag[operation.accumulator_type()]), + "element_c": str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if not is_grouped(operation.gemm_kind) else layout + "* " + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = ( + "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + ) + + return ( + gemm_shape_type + if not is_grouped(operation.gemm_kind) + else grouped_gemm_shape_type + ) + + def emit(self, operation): + """Given a gem operation, emits a template definition of the operation""" + + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + instruction_shape = ( + operation.tile_description.math_instruction.instruction_shape + ) + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + tile_shape_m, tile_shape_n, tile_shape_k = tile_shape + + # account for static/dynamic cluster shapes + cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0] + cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + + # Shape passed to epilogue builder + is_sm100_kernel = operation.arch == 100 + if is_sm100_kernel: + cta_m_per_mma_instruction = ( + 2 if "2sm" in operation.procedural_name() else 1 + ) + if cluster_m <= 0: + cta_m = cta_m // cta_m_per_mma_instruction + + if opcode_class_main in [ + OpcodeClass.TensorOp, + OpcodeClass.BlockScaledTensorOp, + ]: + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] + + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<\ +{str(operation.tile_description.stages)}>" + else: + stage_count_string = ( + f"cutlass::gemm::collective::StageCountAutoCarveout(\ +sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" + ) + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + + ( + instance_layout_A, + instance_layout_B, + instance_layout_C, + instance_layout_D, + ) = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + operation.D.layout, + ) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctor3xTag[ + operation.epilogue_functor + ], + } + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template, values + ) + + if ( + is_block_scaled(operation.gemm_kind) + and operation.ScaleFactorD.element != DataType.void + ): + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + if ( + is_block_scaled(operation.gemm_kind) + and operation.ScaleFactorD.element != DataType.void + ): + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + # + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, + # e.g. cute::tuple, Transform : cute::identity / cute::conjugate. + element_a = ( + DataTypeTag[operation.A.element] + if not operation.is_complex() + else f"cute::tuple<{str(DataTypeTag[operation.A.element])},\ +{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + ) + element_b = ( + DataTypeTag[operation.B.element] + if not operation.is_complex() + else f"cute::tuple<{str(DataTypeTag[operation.B.element])},\ +{str(ComplexTransformTag3x[operation.B.complex_transform])}>" + ) + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + is_no_smem_epilogue = operation.epilogue_schedule in [ + EpilogueScheduleType.NoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm, + ] + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule( + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped + ): + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[ + to_grouped_schedule( + EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped + ) + ] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule( + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped + ): + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[ + to_grouped_schedule( + EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped + ) + ] + element_a = f"cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>" + element_b = f"cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>" + + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode is not None: + A_dtype = operation.A.element + B_dtype = operation.B.element + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, \ +cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, \ +cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = ( + f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + ) + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = ( + f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + ) + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" + using {operation_name_str}_StrideNarrow = {stride_narrow_str}; + using {operation_name_str}_ValueShuffle = {value_shuffle_str}; + static constexpr int {operation_name_str}_NumShuffleAtoms = 1; + using {operation_name_str}_MmaAtomShape = \ +cute::Layout>>; + using {operation_name_str}_LayoutAtomQuant = \ +decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, \ +{operation_name_str}_ValueShuffle>()); + using {operation_name_str}_LayoutNarrowReordered = \ +decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, \ +cute::Layout, {operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>", + } + narrow_element = mixed_input_modes_to_element.get( + operation.mixed_input_mode, narrow_tag + ) + + if narrow_dtype == DataType.s4 and ( + wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2 + ): + narrow_element = ( + f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + ) + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + + if self.evt_name: + epilogue_functor = self.evt_name + + values = { + "operation_name": operation_name_str, + "operation_suffix": self.operation_suffix, + "problem_shape": self.problem_shape(operation), + "element_a": element_a, + "layout_a": self.pointerize_if_grouped(operation, layout_a_str), + "element_b": element_b, + "layout_b": self.pointerize_if_grouped(operation, layout_b_str), + "element_c": DataTypeTag[operation.C.element], + "layout_c": self.pointerize_if_grouped( + operation, LayoutTag[instance_layout_C] + ), + "element_d": DataTypeTag[operation.D.element], + "layout_d": self.pointerize_if_grouped( + operation, LayoutTag[instance_layout_D] + ), + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class_main": OpcodeClassTag[opcode_class_main], + "opcode_class_epi": OpcodeClassTag[opcode_class_epi], + "arch": f"cutlass::arch::Sm{operation.arch}", + "tile_shape_m": str(tile_shape_m), + "tile_shape_n": str(tile_shape_n), + "tile_shape_k": str(tile_shape_k), + "cluster_shape_m": "cute::_" + + str(operation.tile_description.cluster_shape[0]) + if operation.tile_description.cluster_shape[0] > 0 + else "int", + "cluster_shape_n": "cute::_" + + str(operation.tile_description.cluster_shape[1]) + if operation.tile_description.cluster_shape[1] > 0 + else "int", + "cluster_shape_k": "cute::_" + + str(operation.tile_description.cluster_shape[2]) + if operation.tile_description.cluster_shape[2] > 0 + else "int", + "instruction_shape_m": str(instruction_shape[0]), + "instruction_shape_n": str(instruction_shape[1]), + "instruction_shape_k": str(instruction_shape[2]), + "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), + "epilogue_schedule": str(epilogue_schedule_type), + "epi_tile_mn": epi_tile_mn, + "epilogue_functor": epilogue_functor, + "stages": stage_count_string, + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.D.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), + "mixed_dtype_prepare_code": mixed_dtype_prepare_code, + } + + return SubstituteTemplate(self.gemm_template, values) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b7d2afe6c39e27e81c3b78d2c411f3cdf7193e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py @@ -0,0 +1,326 @@ +import itertools +from collections.abc import Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from os import linesep +from typing import Any, Optional + +import sympy + +import torch +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, Pointwise +from torch._inductor.ops_handler import DefaultHandler, WrapperHandler +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import DelayReplaceLine, IndentedBuffer, OrderedSet +from torch._inductor.virtualized import OpsValue + +from ...virtualized import V + + +_ACCUMULATOR_ARG_NAME = "accum" + + +def scaled_mm_evt( + scale_A_name: str, scale_B_name: str, bias_name: Optional[str], output_name: str +) -> tuple[list[str], dict[str, Any], str]: + evt_read_names = [scale_A_name, scale_B_name] + var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]} + var_name_to_buffer_name["D"] = output_name + var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name + expr = f"accum * {scale_A_name} * {scale_B_name}{linesep}" + if bias_name: + expr = f"({expr}) + {bias_name}" + evt_read_names.append(bias_name) + var_name_to_buffer_name[bias_name] = bias_name + + evt_py_code = f"def fn(accum, {','.join(evt_read_names)}):{linesep}\ + D = {expr}{linesep}\ + return D{linesep}" + + return evt_read_names, var_name_to_buffer_name, evt_py_code + + +class CutlassEVTOpsMixIn: + @staticmethod + def _infix_bin_op(op: str, a: str, b: str) -> str: + return f"{a} {op} {b}" + + @staticmethod + def _prefix_bin_op(op: str, a: str, b: str) -> str: + return f"{op}({a}, {b})" + + @staticmethod + def _prefix_un_op(op: str, a: str) -> str: + return f"{op}({a})" + + @staticmethod + def to_dtype( + x: str, + dtype: Any, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> str: + return x + + @staticmethod + def constant(value: Any, dtype: Any) -> str: + raise NotImplementedError + + @staticmethod + def mul(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("*", x0, x1) + + @staticmethod + def truediv(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("/", x0, x1) + + @staticmethod + def ge(x0: str, x1: str) -> str: + raise NotImplementedError + + @staticmethod + def add(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("+", x0, x1) + + @staticmethod + def relu(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("relu", x0) + + @staticmethod + def sigmoid(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("sigmoid", x0) + + @staticmethod + def sub(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1) + + @staticmethod + def tanh(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("tanh", x0) + + @staticmethod + def exp(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("exp", x0) + + +class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler): + """Passthrough handler for cutlass ops, used for running epilogue nodes for memory planning""" + + +class _AssignmentFormatter(DefaultHandler): + def __init__(self, parent_handler: "CutlassEVTCodegen"): + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # Handle op dispatch here + if hasattr(self.parent_handler, name): + fn = getattr(self.parent_handler, name) + line = fn(*args, **kwargs) + if name in ("load", "store"): + return OpsValue(line) + else: + var = self.parent_handler._tmp_var() + line = DelayReplaceLine( + var, + lambda: "D" + if var == self.parent_handler.last_stored_var_name + else var, + f"{var} = {line}", + ) + self.parent_handler.body.writeline(line) + return OpsValue(var) + else: + raise NotImplementedError(name) + + +class CutlassEVTCodegen(CutlassEVTOpsMixIn): + """ + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTCodegen.ir_to_evt_python_code(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + """ + + def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTCodegen.ir_to_evt_python_code static method. + + Args: + accumulator_node_name: The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + epilogue_nodes: The list of scheduler nodes to be fused into the epilogue + """ + self.accumulator_node_name: str = accumulator_node_name # + self.body: IndentedBuffer = IndentedBuffer(1) # The body buffer for codegen + self.var_counter: Iterator[int] = itertools.count() + self.store_name_to_value: dict[str, OpsValue] = ( + dict() + ) # Aliases for subexpression functors + self.reads: OrderedSet[str] = OrderedSet([]) + # Used for creating example tensors + self.var_name_to_buffer_name: dict[str, str] = { + _ACCUMULATOR_ARG_NAME: accumulator_node_name + } + self.removed_buffers: OrderedSet[str] = removed_buffers + self.cur_node: Optional[ComputedBuffer] = None + self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants: + self.name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + self.is_D_assigned = False + self.D_var_name = None + + if accumulator_node_name not in removed_buffers: + # cannot return accumulator directly, so alias it + var = self._tmp_var() + self.body.writeline(f"{var} = {_ACCUMULATOR_ARG_NAME}") + self.store(accumulator_node_name, value=OpsValue(var)) + + @staticmethod + def ir_to_evt_python_code( + cuda_template_node_name: str, + epilogue_nodes: list[BaseSchedulerNode], + removed_buffers: OrderedSet[str], + ) -> tuple[list[str], list[str], dict[str, Any], str]: + codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers) + handler = _AssignmentFormatter(codegen) + + with virtualized.V.set_ops_handler(handler): + for s_node in epilogue_nodes: + node = s_node.node + assert isinstance(node, ComputedBuffer) + with codegen.set_cur_node(node): + index_vars = CutlassEVTCodegen.get_index_vars(node) + node.get_store_function()(index_vars) + + codegen.finalize() + + return ( + codegen.get_reads(), + codegen.get_writes(), + codegen.get_renames(), + codegen.get_value(), + ) + + def get_value(self) -> str: + return linesep.join( + [ + self._render_input_signature(), + self.body.getvalue(), + self._render_return_statement(), + ] + ) + + def finalize(self) -> None: + # Rename the last store to D + # no other code references this store + # to workaround https://github.com/NVIDIA/cutlass/issues/2288 + # Note: the delayed line will automatically rewrite the last assignment to + # be to D + buffer_name = self.var_name_to_buffer_name[self.last_stored_var_name] + self.var_name_to_buffer_name.pop(self.last_stored_var_name) + self.var_name_to_buffer_name["D"] = buffer_name + self.store_name_to_value[buffer_name] = OpsValue("D") + + @contextmanager + def set_cur_node(self, node: ComputedBuffer) -> Generator[None, Any, Any]: + prev_node = self.cur_node + try: + self.cur_node = node + yield + finally: + self.cur_node = prev_node + + def get_renames(self) -> dict[str, str]: + return dict(self.var_name_to_buffer_name) + + def get_reads(self) -> list[str]: + return list(self.reads.difference(self.store_name_to_value.keys())) + + def get_writes(self) -> list[str]: + return list(self.store_name_to_value.keys()) + + def load(self, name: str, index: Any) -> str: + self._check_indexing(name, index) + if name in self.store_name_to_value: + return self.store_name_to_value[name].value + elif name == self.accumulator_node_name: + return _ACCUMULATOR_ARG_NAME + else: + self.reads.add(name) + self.var_name_to_buffer_name[name] = name + return name + + def store( + self, name: Any, index: Any = None, value: Any = None, mode: Any = None + ) -> None: + if name not in self.removed_buffers: + if index: + self._check_indexing(name, index) + assert value.value != _ACCUMULATOR_ARG_NAME, ( + "Cannot store accumulator arg name" + ) + self.var_name_to_buffer_name[value.value] = name + self.store_name_to_value[name] = value + self.last_stored_var_name = value.value + return None + + def _get_cur_node(self) -> ComputedBuffer: + assert self.cur_node + return self.cur_node + + @staticmethod + def get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]: + data = node.data + # TODO mlazos: relax this, cutlass supports reductions and other ops + assert isinstance(data, Pointwise) + return data._index(data.ranges) + + def _get_current_index_vars(self) -> Sequence[sympy.Expr]: + return self.get_index_vars(self._get_cur_node()) + + def _check_indexing(self, name: str, index: sympy.Expr) -> None: + # We only support indexing that matches the layout today because + # CUTLASS doesn't support arbitrary indexing + buffer_name = ( + self.accumulator_node_name if name == _ACCUMULATOR_ARG_NAME else name + ) + buffer = self.name_to_buffer[buffer_name] + index_strides = V.graph.sizevars.stride_vars( + index, self._get_current_index_vars() + ) + stride = buffer.get_layout().stride + if not self._stride_compatible(stride, index_strides): + raise NotImplementedError( + f"Unsupported indexing for {name} with index {index}, index strides {index_strides}, and layout stride {stride}" + ) + + def _stride_compatible( + self, left: Iterable[sympy.Expr], right: Iterable[sympy.Expr] + ) -> bool: + return all( + sympy.Eq(l, r) or sympy.Eq(l, 0) or sympy.Eq(r, 0) + for l, r in (zip(left, right)) + ) + + def _render_input_signature(self) -> str: + arguments = ", ".join( + [_ACCUMULATOR_ARG_NAME] + + [name for name in self.reads if name != self.accumulator_node_name] + ) + return f"def fn({arguments}):" + + def _render_return_statement(self) -> str: + return_vars = OrderedSet( + op_v.value for op_v in self.store_name_to_value.values() + ) + assert "D" in return_vars + return f"return {', '.join(return_vars)}" + + def _tmp_var(self) -> str: + return f"tmp_{next(self.var_counter)}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa46e8766cd5819b41af4c5269945119722d2251 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -0,0 +1,493 @@ +# mypy: allow-untyped-defs +import atexit +import functools +import logging +import os +import shutil +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional +from typing_extensions import TypeIs + +import sympy + +import torch +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._inductor.utils import clear_on_fresh_cache +from torch.utils._ordered_set import OrderedSet + +from ... import config +from ...ir import Layout +from ...runtime.runtime_utils import cache_dir +from ...virtualized import V +from ..cpp_utils import DTYPE_TO_CPP +from .cuda_env import get_cuda_arch, get_cuda_version + + +log = logging.getLogger(__name__) + +CUTLASS_OPERATION_KIND: str = "gemm" +ACCUMULATOR_DTYPES: OrderedSet[torch.dtype] = OrderedSet([torch.float, torch.int32]) +XW_DTYPES: OrderedSet[torch.dtype] = OrderedSet( + [torch.half, torch.bfloat16, torch.float8_e4m3fn, torch.int8] +) + + +@atexit.register +def move_cutlass_compiled_cache() -> None: + """Move CUTLASS compiled cache file to the cache directory if it exists.""" + if not try_import_cutlass.cache_info().currsize > 0: + return + + import cutlass_cppgen # type: ignore[import-not-found] + + # Check if the CACHE_FILE attribute exists in cutlass_cppgen and if the file exists + if not hasattr(cutlass_cppgen, "CACHE_FILE") or not os.path.exists( + cutlass_cppgen.CACHE_FILE + ): + return + + try: + filename = os.path.basename(cutlass_cppgen.CACHE_FILE) + shutil.move(cutlass_cppgen.CACHE_FILE, os.path.join(cache_dir(), filename)) + log.debug("Moved CUTLASS compiled cache file to %s", cache_dir()) + except OSError: + log.warning("Failed to move CUTLASS compiled cache file", exc_info=True) + + +def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +@functools.cache +def try_import_cutlass() -> bool: + """ + We want to support three ways of passing in CUTLASS: + 1. fbcode, handled by the internal build system. + 2. User specifies cutlass_dir. The default is ../third_party/cutlass/, + which is the directory when developers build from source. + """ + if config.is_fbcode(): + try: + import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401 + import cutlass_library # type: ignore[import-not-found] + except ImportError as e: + log.warning( # noqa: G200 + "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", + str(e), + ) + return False + + return True + + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. + # This is a temporary hack to avoid CUTLASS module naming conflicts. + # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. + + # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass, + # but will be moved to python/cutlass_library in the future (later 2025) + def path_join(path0, path1): + return os.path.abspath(os.path.join(path0, path1)) + + # contains both cutlass and cutlass_library + # we need cutlass for eVT + cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") + torch_root = os.path.abspath(os.path.dirname(torch.__file__)) + mock_src_path = os.path.join( + torch_root, + "_inductor", + "codegen", + "cuda", + "cutlass_lib_extensions", + "cutlass_mock_imports", + ) + + cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library") + cutlass_cppgen_src_path = path_join(cutlass_python_path, "cutlass_cppgen") + pycute_src_path = path_join(cutlass_python_path, "pycute") + + tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass")) + + dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library") + dst_link_cutlass_cppgen = path_join(tmp_cutlass_full_path, "cutlass_cppgen") + dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute") + + # mock modules to import cutlass + mock_modules = ["cuda", "scipy", "pydot"] + + if os.path.isdir(cutlass_python_path): + if tmp_cutlass_full_path not in sys.path: + + def link_and_append(dst_link, src_path, parent_dir): + if os.path.lexists(dst_link): + assert os.path.islink(dst_link), ( + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." + ) + assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( + src_path, + ), f"Symlink at {dst_link} does not point to {src_path}" + else: + os.makedirs(parent_dir, exist_ok=True) + os.symlink(src_path, dst_link) + + if parent_dir not in sys.path: + sys.path.append(parent_dir) + + link_and_append( + dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path + ) + link_and_append( + dst_link_cutlass_cppgen, cutlass_cppgen_src_path, tmp_cutlass_full_path + ) + link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path) + + for module in mock_modules: + link_and_append( + path_join(tmp_cutlass_full_path, module), # dst_link + path_join(mock_src_path, module), # src_path + tmp_cutlass_full_path, # parent + ) + + try: + import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401, F811 + import cutlass_library.generator # noqa: F401 + import cutlass_library.library # noqa: F401 + import cutlass_library.manifest # noqa: F401 + import pycute # type: ignore[import-not-found] # noqa: F401 + + return True + except ImportError as e: + log.debug( # noqa: G200 + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", + str(e), + ) + else: + log.debug( + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", + cutlass_python_path, + ) + return False + + +@functools.lru_cache(8) +def _normalize_cuda_arch(arch: str) -> str: + if int(arch) >= 100: + log.warning( + "Detected CUDA architecture >= 100: %s. We will generate operations with " + "GenerateSM100 (if available) and GenerateSM90. Please file an " + "issue for any problems and feedback. ", + arch, + ) + + if int(arch) >= 100: + return "100" + elif int(arch) >= 90: + return "90" + elif int(arch) >= 80: + return "80" + elif int(arch) >= 75: + return "75" + elif int(arch) >= 70: + return "70" + else: + raise NotImplementedError(f"Unsupported cuda arch: {arch}") + + +@dataclass +class CUTLASSArgs: + """ + CUTLASS args used to initialize a CUTLASS Manifest. + """ + + architectures: Optional[str] = None + cuda_version: Optional[str] = None + instantiation_level: Optional[str] = None + operations: Optional[str] = None + + build_dir = "" + curr_build_dir = "" + generator_target = "" + kernels = "all" + ignore_kernels = "" + exclude_kernels = "" + # TODO: these three look dead? + kernel_filter_file: None = None + selected_kernel_list: None = None + interface_dir: None = None + filter_by_cc = True + disable_full_archs_compilation = False + + def __post_init__(self): + if self.architectures is None or self.cuda_version is None: + raise RuntimeError( + f"{self.architectures=} or {self.cuda_version=} is None!" + ) + self.architectures = _normalize_cuda_arch(self.architectures) + + +@clear_on_fresh_cache +@functools.cache +def _gen_ops_cached(arch, version) -> dict[Any, Any]: + # Note: Cache needs to be specific for cuda architecture and version + + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library.generator as cutlass_generator + import cutlass_library.manifest as cutlass_manifest + + if arch is None or version is None: + log.error( + "Cannot detect cuda arch %s or cuda version %s. " + "Will discard all cutlass ops. " + "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", + arch, + version, + ) + return {} + arch = _normalize_cuda_arch(arch) + instantiation_level: str = config.cuda.cutlass_instantiation_level + args = CUTLASSArgs( + architectures=arch, + cuda_version=version, + instantiation_level=instantiation_level, + operations=CUTLASS_OPERATION_KIND, + ) + manifest = cutlass_manifest.Manifest(args) + + start_time = time.time() + if arch == "100": + if hasattr(cutlass_generator, "GenerateSM100"): + cutlass_generator.GenerateSM100(manifest, args.cuda_version) + cutlass_generator.GenerateSM90(manifest, args.cuda_version) + else: + try: + func = getattr(cutlass_generator, "GenerateSM" + arch) + func(manifest, args.cuda_version) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) from e + + log.info( + "CUTLASS library generated a dict of %d operation kinds in %.2f seconds", + len(manifest.operations), + time.time() - start_time, + ) + return manifest.operations + + +def gen_ops() -> dict[Any, Any]: + """ + Generates all supported CUTLASS operations. + """ + with dynamo_timed("cutlass_utils.gen_ops"): + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) + + +DTYPE_TO_CUTLASS_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "__half", + torch.bfloat16: "__nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", +} + + +@functools.lru_cache(32) +def torch_dtype_to_cutlass_type( + torch_dtype: torch.dtype, +) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library # type: ignore[import] + + if torch_dtype == torch.float: + return cutlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_library.library.DataType.bf16 + else: + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") + + +@functools.lru_cache(32) +def dtype_match( + torch_dtype: Optional[torch.dtype], + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.float: + return ( + cutlass_dtype == cutlass_library.library.DataType.f32 + or cutlass_dtype == cutlass_library.library.DataType.tf32 + ) + elif torch_dtype == torch.half: + return cutlass_dtype == cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.int8: + return cutlass_dtype == cutlass_library.library.DataType.s8 + elif torch_dtype == torch.uint8: + return cutlass_dtype == cutlass_library.library.DataType.u8 + elif torch_dtype == torch.int32: + return cutlass_dtype == cutlass_library.library.DataType.s32 + elif torch_dtype == torch.float8_e4m3fn: + return cutlass_dtype == cutlass_library.library.DataType.e4m3 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: list[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + assert OrderedSet(input_torch_dtypes) <= XW_DTYPES, ( + f"{input_torch_dtypes=} is not supported" + ) + + if len(input_torch_dtypes) != 2: + return None + + torch_dtype = None + if input_torch_dtypes[0] == input_torch_dtypes[1]: + torch_dtype = input_torch_dtypes[0] + else: + size0 = torch.tensor([], dtype=input_torch_dtypes[0]).element_size() + size1 = torch.tensor([], dtype=input_torch_dtypes[1]).element_size() + if size0 > size1: + dtype0, dtype1 = input_torch_dtypes + else: + dtype1, dtype0 = input_torch_dtypes + if dtype0 in [torch.half, torch.bfloat16] and dtype1 in [ + torch.int8, + torch.uint8, + ]: + torch_dtype = dtype0 + + if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn): + accumulator_dtype = torch.float + elif torch_dtype == torch.int8: + accumulator_dtype = torch.int32 + else: + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") + + assert accumulator_dtype in ACCUMULATOR_DTYPES, ( + f"{accumulator_dtype=} is not supported" + ) + return accumulator_dtype + + +@functools.lru_cache(32) +def get_alignments(torch_dtype: torch.dtype) -> list[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment. + """ + + if torch_dtype in (torch.half, torch.bfloat16): + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + elif torch_dtype in (torch.uint8, torch.int8, torch.float8_e4m3fn): + return [16, 8, 4, 2] + elif torch_dtype == torch.int32: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. + """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number: object) -> TypeIs[int | sympy.Integer]: + return isinstance(number, (int | sympy.Integer)) + + def a_factor_of(x, alignment): + if is_static_int(x) and is_static_int(alignment): + return x % alignment == 0 + rem = sympy.Mod(x, alignment) + return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0)) + + try: + contiguous_dim = inductor_layout.stride.index(1) + except ValueError: + # No dim with stride 1 found, return 1 + return 1 + alignments = get_alignments(dtype) + for alignment in alignments: + if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of( + offset, alignment + ): + continue + if all( + (dim == contiguous_dim) + or a_factor_of(inductor_layout.stride[dim], alignment) + for dim in range(len(size)) + ): + return alignment + return 1 + + +class CUDACompileSourceCapturingContext: + # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. + # Can be used to capture the sourcecode passed to CUDACodeCache.compile + + def __init__(self): + self.sources = [] + self._compile_patch = None + + def __enter__(self, *args, **kwargs): + import unittest.mock as mock + + import torch._inductor.codecache + + _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile + + def my_compile( + source_code, dst_file_ext, extra_args: Optional[list[str]] = None + ): + self.sources.append(source_code) + return _compile_method_orig(source_code, dst_file_ext) + + # pyrefly: ignore [bad-assignment] + self._compile_patch = mock.patch( + "torch._inductor.codecache.CUDACodeCache.compile", my_compile + ) + self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] + return self + + def __exit__(self, *args, **kwargs): + self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] + + +def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path): + # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run + # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. + from torch._inductor.codecache import cuda_compile_command + + extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"] + compile_command = cuda_compile_command( + [str(srcpath)], str(exepath), "exe", extra_args=extra_args + ) + return compile_command diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..147515e0decfe8f14853e18193fa4ca45501cac8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +from typing import Optional + +import torch + +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + """ + CUDA-specific codegen functions, see DeviceOpOverrides for details + """ + + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.cuda._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::cuda::CUDAGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTICudaGuard" + + def cpp_stream_guard(self) -> str: + return "at::cuda::CUDAStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTICudaStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::cuda::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + #include + #include + """ + return source_codes + + def kernel_driver(self) -> str: + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + CUresult code_get_error = cuGetErrorString(code, &msg); \\ + if (code_get_error != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string("invalid error code!")); \\ + } \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline CUfunction loadKernel(const void* start, const std::string &funcName, uint32_t sharedMemBytes) { + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoadData(&mod, start)); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + def tma_descriptor_helpers(self) -> str: + """ + CUDA helper functions for initializing TMA Descriptors on host side + """ + if torch.version.hip is not None: + raise RuntimeError("Host-side TMA descriptors not supported on HIP.") + + # helper functions for initializing 1D and 2D TMA descriptors in C++. borrowed from the Triton code here: + # Old APIs (fill(1|2)DTMADescriptor): + # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283 + # New APIs (fillTMADescriptor): + # https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.c#L283 + return """ + #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + [[maybe_unused]] static void init1DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim, + uint32_t blockDim, + uint32_t elementSize) { + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t tensorDims[1] = {blockDim}; + uint32_t elementStrides[1] = {1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + if (elementSize * blockDim < 32) { + throw std::runtime_error("block size too small"); + } + + int rank = 1; + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void init2DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim1, + uint64_t dim0, + uint32_t blockDim1, + uint32_t blockDim0, + uint32_t elementSize) { + uint64_t dims[2] = {dim0, dim1}; + uint32_t tensorDims[2] = {blockDim0, blockDim1}; + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + int rank = 2; + + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void initTMADescriptor( + CUtensorMap* m, + void* globalAddress, + int elemSize, + int rank, + uint32_t* blockSize, + uint64_t* shape, + uint64_t* stride + ) { + uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; + uint32_t blockSizeInt[5]; + uint64_t shapeInt[5]; + uint64_t stridesLL[5]; + + // Reorder blockSize (reverse the order) + for (int i = 0; i < rank; ++i) { + blockSizeInt[rank - i - 1] = blockSize[i]; + } + + // Reorder shape (reverse the order) + for (int i = 0; i < rank; ++i) { + shapeInt[rank - i - 1] = shape[i]; + } + + // Reorder and calculate strides + for (int i = 0; i + 1 < rank; ++i) { + stridesLL[rank - i - 2] = elemSize * stride[i]; + } + stridesLL[rank - 1] = + shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]); + + CUtensorMapDataType type; + // In Triton this is computed ahead of time; but for simplicity + // in the PyTorch version we copied this code from the old + // TMA API handling (i.e. init2DTMADescriptor) + switch (elemSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elemSize must be 1, 2, or 4"); + } + + // Calculate the size of the most contiguous dimension in bytes + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elemSize * blockSizeInt[0]; + if (rank == 1) { + // rank 1 should not be swizzled + swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + } else if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, + shapeInt, stridesLL, blockSizeInt, elementStrides, + CU_TENSOR_MAP_INTERLEAVE_NONE, (CUtensorMapSwizzle)swizzle, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + struct StableTMADescriptor { + CUtensorMap m; + uint32_t block_shape[5]; + uint64_t global_shape[5]; + uint64_t strides[5]; + }; + #endif + """ + + def cpp_stream_type(self) -> str: + return "cudaStream_t" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self) -> str: + return "CUfunction" + + def cpp_device_ptr(self) -> str: + return "CUdeviceptr" + + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None + ) -> Optional[tuple[list[str], str]]: + prefix = f"{prefix}_" if prefix else "" + var_name = f"{prefix}scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b7188bd9e621eb4a2bed773d7a5a116bca9b3e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,1966 @@ +# mypy: allow-untyped-defs +import copy +import enum +import functools +import logging +import re +import time +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._inductor.autotune_process import TensorMeta +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.select_algorithm import create_inputs_key +from torch._inductor.utils import clear_on_fresh_cache + +from ... import ir +from ...config import cuda as inductor_cuda_config +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + FixedLayout, + IRNode, + Layout, + ReinterpretView, +) +from ...utils import is_dynamic, Placeholder +from ...virtualized import V +from ..common import IndentedBuffer +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate +from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt +from .cutlass_utils import ( + ACCUMULATOR_DTYPES, + dtype_match, + torch_dtype_to_cutlass_type, + XW_DTYPES, +) + + +GemmOperation = Any +EVTArgRenames = Any + +log = logging.getLogger(__name__) + +# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below. +GEMM_TEMPLATE_CUTLASS_3X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{epilogue_visitor_tree}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} + +// configuration name: {{op_conf_name}} +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}}, + hw_info + }; + arguments.scheduler.max_swizzle_size = swizzle; +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, +# used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + +# Jinja template for GEMM Kernel, used by the CUTLASS2xGemmTemplate class below. +GEMM_TEMPLATE_CUTLASS_2X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + int B = {{kernel.size(Y, 0, -3, default_value=1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(instance_type, argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Meta, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + +# Jinja template for Cutlass 2.x GEMM Kernel arguments, used by the CUTLASS2xGemmTemplate class below. +GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + +GEMM_ARGS_SPARSE_CUTLASS_2X = r""" + using TensorRefA = cutlass::TensorRef<{{instance_type}}::ElementA, + {{instance_type}}::LayoutA>; + using TensorRefB = cutlass::TensorRef<{{instance_type}}::ElementB, + {{instance_type}}::LayoutB>; + using TensorRefC = cutlass::TensorRef<{{instance_type}}::ElementC, + {{instance_type}}::LayoutC>; + using TensorRefE = cutlass::TensorRef<{{instance_type}}::ElementE, + {{instance_type}}::LayoutE>; + // Note that "X" and "W" names may be misleading here. Namely, for + // sparse GEMM, the first argument is always sparse, while typically + // weight matrix, implied by name "W" will be sparse in + // applications. Thus, just remember that here: "X" refers to first + // argument, that is sparse, and "W" to second, that is dense. + TensorRefA X_ref({{template.cutlass_type_cast(X, kernel.ptr(X))}}, {{kernel.row_or_column_stride(X)}}); + TensorRefB W_ref({{template.cutlass_type_cast(W, kernel.ptr(W))}}, {{kernel.row_or_column_stride(W)}}); + TensorRefC Y_ref({{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, {{kernel.row_or_column_stride(Y)}}); + TensorRefE Meta_ref({{template.cutlass_sparse_meta_type_cast(Meta, kernel.ptr(Meta))}}, + TensorRefE::Layout::packed({ {{kernel.size(Meta, 0)}}, {{kernel.size(Meta, 1)}} })); + // Initialize GemmSparse arguments. + arguments = { + { + static_cast(M), + static_cast(N), + static_cast(2 * K), + }, // GemmCoord problem_size + X_ref, // TensorRef ref_A + W_ref, // TensorRef ref_B + Y_ref, // TensorRef ref_C + Y_ref, // TensorRef ref_D + Meta_ref, // TensorRef ref_E + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue, + }; +""" + +# Additional includes which are necessary if the standalone test / debug runner is generated as well +GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r""" +#ifdef GENERATE_STANDALONE_RUNNER +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include +#endif +""" + +# Jinja template for the standalone runner that may be generated as part of the code. +GEMM_STANDALONE_RUNNER_TEMPLATE = r""" +#ifdef GENERATE_STANDALONE_RUNNER +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed, float max=1.0, float min=-1.0) { + if (block.size()<=0) return false; + Element scope_max(static_cast(max)), scope_min(static_cast(min)); + cutlass::reference::device::BlockFillRandomUniform( + (Element*)block.get(), block.size(), seed, scope_max, scope_min); + + return true; +} + +{% if Meta is defined and Meta is not none %} +template +bool initialize_block_meta( + cutlass::DeviceAllocation& block, + uint64_t seed) { + if (block.size()<=0) return false; + cutlass::reference::device::BlockFillRandomSparseMeta( + (Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits); + return true; +} +{% endif %} + +extern "C" int run_standalone(uint64_t seed, int repetitions) { + std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; + size_t workspace_size = 0; + size_t* workspace_size_ptr = &workspace_size; + + int M = {{kernel.get_layout_args()[0]}}; + int N = {{kernel.get_layout_args()[1]}}; + int K = {{kernel.get_layout_args()[2]}}; + int B = {{kernel.get_layout_args()[3]}}; + int lda = {{kernel.get_layout_args()[4]}}; + int ldb = {{kernel.get_layout_args()[5]}}; + int ldc = {{kernel.get_layout_args()[6]}}; + int ldd = {{kernel.get_layout_args()[7]}}; + uint8_t swizzle = {{kernel.runtime_arg_values[0]}}; + + using ElementA = {{kernel.cutlass_dtype(X)}}; + using ElementB = {{kernel.cutlass_dtype(W)}}; + using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void + using ElementD = {{kernel.cutlass_dtype(Y)}}; + {% if Meta is defined and Meta is not none %} + using ElementE = {{kernel.cutlass_dtype(Meta)}}; + {% endif %} + + cutlass::DeviceAllocation X_data({{kernel.max_valid_index(X)+1}}); + initialize_block(X_data, seed++); + cutlass::DeviceAllocation W_data({{kernel.max_valid_index(W)+1}}); + initialize_block(W_data, seed++); + cutlass::DeviceAllocation Bias_data({{kernel.max_valid_index(Bias)+1}}); + initialize_block(Bias_data, seed++); + cutlass::DeviceAllocation Y_data({{kernel.max_valid_index(Y)+1}}); + {% if Meta is defined and Meta is not none %} + cutlass::DeviceAllocation Meta_data({{kernel.max_valid_index(Meta)+1}}); + initialize_block_meta(Meta_data, seed++); + {% endif %} + + cutlass::DeviceAllocation workspace_data; + // Call once with workspace_size_ptr set to get workspace size + + std::cout << "Calling once to get workspace size" << std::endl; + {{test_call_statement}}; + // Allocate workspace if necessary + if (workspace_size > 0) { + workspace_data.reset(workspace_size); + std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; + } + std::cout << "Calling Kernel as {{test_call_statement}};" << std::endl; + workspace_size_ptr = nullptr; + for (int i=0; i None: + """ + Args: + input_nodes (List[Buffer]): List of input nodes of the GEMM kernel. + layout (Layout): Layout type of the resulting output node. + alpha (float): The scaling factor for the product of the inputs in the GEMM operation. + beta (float): The scaling factor applied to the output matrix. + input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided, + no reordering is performed. Defaults to None. + """ + super().__init__( + str(Placeholder.KERNEL_NAME), input_nodes, layout, input_reorder + ) + self.alpha = alpha + self.beta = beta + self.use_fast_accum = use_fast_accum + assert 2 <= len(input_nodes) <= 5 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + + self.cache_key: str = create_inputs_key(self.input_nodes) + + @staticmethod + @abstractmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError + + @staticmethod + @abstractmethod + def _has_tma_epilogue(self) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_template(self) -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + raise NotImplementedError + + @abstractmethod + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + raise NotImplementedError + + @abstractmethod + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + raise NotImplementedError + + def _add_cutlass_gemm_choices( + self, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Cutlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it. + + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. + + """ + + ops = self.gen_ops() + + # pre-computation + layout_repr: str = str(layout) + input_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.input_nodes) + ) + output_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.output_node) + ) + + with dynamo_timed("CUTLASSGemmTemplate.maybe_append_choice"): + for name, op in ops: + for ( + swizzle + ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, + op=op, + name=name, + description=description, + input_key=self.cache_key, + layout_repr=layout_repr, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + swizzle=swizzle, + ) + + if len(ops) == 0: + log.info( + "No suitable Cutlass GEMM configs found, fallbacks used " + "( len(ops)=%d, output_layout=%s, input_layouts=%s, input_strides=%s )", + len(ops), + layout, + [node.get_layout() for node in input_nodes], + [node.get_stride() for node in input_nodes], + ) + log.debug( + "Added %d Cutlass gemm configs.", + len(ops), + ) + + def header(self) -> IndentedBuffer: + """ + Returns a buffer containing CUDA C++ code for the header section of the CUTLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated CUDA C++ header code. + """ + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/device/gemm_sparse.h" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/epilogue/thread/activation.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + if inductor_cuda_config.generate_test_runner and not is_dynamic( + *self.input_nodes, self.output_node + ): + res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) + return res + + @staticmethod + def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return cutlass_lib.LayoutType.RowMajor + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1): + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + """Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor + to ColumnMajor or vice versa""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + @functools.lru_cache(32) + def layout_match( + torch_layout: ir.Layout, + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Cutlass layout""" + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_layout(tensor_desc: "TensorDescription", torch_layout: ir.Layout) -> None: # type: ignore[name-defined] # noqa: F821 + """ + Helper method: Sets the layout of a given tensor description to match the given torch layout + """ + if CUTLASSGemmTemplate.layout_match(torch_layout, tensor_desc.layout): + return + tensor_desc.layout = CUTLASSGemmTemplate.cutlass_layout(torch_layout) + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + """ + Helper method to update the alignment of a given CUTLASS GEMM op operand's element. + + This method modifies the alignment of the given Cutlass GEMM op operand's element to match the + layout of the corresponding ir.Buffer node. + + Args: + torch_layout: The layout of the corresponding ir.Buffer node. + op_element: The Cutlass GEMM op operand's element whose alignment is to be updated. + + Returns: + bool: True if the alignment was successfully updated, False otherwise. + """ + alignment = cutlass_utils.get_max_alignment(torch_layout) + cuda_arch = cutlass_utils.get_cuda_arch() + if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def should_swap_XW( + bias: IRNode, + ) -> bool: + """ + Helper method to determine whether we should do an explicit transpose by switching the order of the + matmul operands. This might be necessary when we can't otherwise arrive at the right memory + layout for the given Bias operand. + + Note: This method is a workaround for CUDA Errors that seemingly non-deterministically + occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term. + it might make sense to check on newer Cutlass releases whether it makes sense to keep + returning True in certain cases or whether it becomes unnecessary. + """ + # If bias is row major, swap all M and N dimensions + if ( + bias is not None + and len(bias.get_stride()) >= 2 + and bias.get_stride()[-1] in (0, 1) + ): + log.debug("GEMM Layout swapped X and W -> explicit transpose") + return True + return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Swap operands X and W (aka operans A and B) of the GEMM operation. This + requires transposing the operands, which is done by swapping the strides. + Note that we don't change the apparent external layout, just the operand layout. + this is intentional. + """ + new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def fix_op_layout( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + X: Buffer, + W: Buffer, + Bias: Optional[Buffer], + Y: Union[Buffer, ReinterpretView], + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # This is a workaround to deal with cases where the input layouts have changed + # between autotuning and rendering. This happens if the inputs layout + # are FlexibleLayout instances. In this case, we need to update the + # op's input layouts. It is a hack, because now the op + # we benchmarked is not the same as the op we render, + # but there is no simple way to fix this in the autotuner, since that would + # potentially disable other optimizations. + a_layout = X.get_layout() + b_layout = W.get_layout() + c_layout = Bias.get_layout() if Bias is not None else None + + d_layout = copy.deepcopy(Y.get_layout()) + match_list = [ + CUTLASSGemmTemplate.layout_match(buf.get_layout(), op_layout) + for buf, op_layout in zip( + (X, W, Bias, Y), + (op.A.layout, op.B.layout, op.C.layout, op.D.layout), + ) + if buf is not None + ] + all_match = all(match_list) + if all_match: + return op + log.warning( + f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + ) + new_op = copy.deepcopy(op) + + if a_layout is not None: + new_op.A.layout = CUTLASSGemmTemplate.cutlass_layout(a_layout) + if b_layout is not None: + new_op.B.layout = CUTLASSGemmTemplate.cutlass_layout(b_layout) + if c_layout is not None: + new_op.C.layout = CUTLASSGemmTemplate.cutlass_layout(c_layout) + new_op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(c_layout.dtype) + if d_layout is not None: + new_op.D.layout = CUTLASSGemmTemplate.cutlass_layout(d_layout) + return new_op + + def _dtype_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """ + Checking dtypes of A, B, acc, D here. + + Empirically speaking, CUTLASS2x ops have same dtype for C and D. + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.D.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return False + + return True + + @classmethod + def global_filter_ops( + cls, + ops: list["cutlass_library.gemm_op.GemmOperation"], # type: ignore[name-defined] # noqa: F821 + ) -> list["cutlass_library.gemm_op.GemmOperation"]: # type: ignore[name-defined] # noqa: F821 + """ + Filter ops without using information about the torch op, input nodes and output node. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib # type: ignore[import] + + # Skip simt kernels + ops = [ + op + for op in ops + if op.tile_description.math_instruction.opcode_class + != cutlass_lib.OpcodeClass.Simt + ] + + # only keep the set of row x column ops + # for other layout, we modify in place in filter_op, after deepcopy + ops = [ + op + for op in ops + if op.A.layout.name == "RowMajor" and op.B.layout.name == "ColumnMajor" + ] + + # filter by supported accumulator types + ops = [ + op + for op in ops + if any( + dtype_match(torch_dtype, op.accumulator_type()) + for torch_dtype in ACCUMULATOR_DTYPES + ) + ] + + # check if dtypes of A and B are supported + ops = [ + op + for op in ops + if any(dtype_match(torch_dtype, op.A.element) for torch_dtype in XW_DTYPES) + and any(dtype_match(torch_dtype, op.B.element) for torch_dtype in XW_DTYPES) + ] + + return ops + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Cutlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. + + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + if op.gemm_kind not in self._get_supported_ops(): + return None + + X = self.input_nodes[0] + W = self.input_nodes[1] + + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + + # Filter ops by dtypes. + if not self._dtype_match(op): + return None + + # Filter ops by alignment. + if not self._alignment_match(op): + log.debug( + "Skipping due to alignment mismatch. op: %s", op.configuration_name() + ) + return None + + # only use stream k for static shape + if op.tile_scheduler.name == "StreamK": + static_shape = PythonWrapperCodegen.statically_known_list_of_ints_or_none( + tuple(X.get_size()) + tuple(W.get_size()) + ) + if not static_shape: + return None + + # Update op. + op = copy.deepcopy(op) + + # set layouts for X and W + self.set_layout(op.A, X.get_layout()) + self.set_layout(op.B, W.get_layout()) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. + status = ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ) + if not status: + log.debug( + "Skipping due to alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + + if self.use_fast_accum is not None: + is_op_fast_accum = "fastaccum" in op.configuration_name() + if self.use_fast_accum ^ is_op_fast_accum: + return None + + # Set bias layout and alignment. + status = self._set_bias_layout_and_alignment(op) + if not status: + log.debug( + "Skipping due to bias layout and alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + # Apply regex filters at the end when configuration name doesn't change anymore + if inductor_cuda_config.cutlass_op_allowlist_regex: + if not re.search( + inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name() + ): + return None + if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if re.search( + inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + ): + return None + + return op + + def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. + + Returns: + List[tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + + if self.cache_key in self.filtered_ops_cache: + log.debug("Using cached ops for %s", self.cache_key) + return self.filtered_ops_cache[self.cache_key] + + with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"): + maybe_ops = maybe_fetch_ops() + if maybe_ops is None: + log.debug("Cannot fetch ops from cache, generating ops from scratch") + full_ops = cutlass_utils.gen_ops() + ops = pytree.tree_flatten(full_ops)[0] + else: + log.debug("Using cached ops from cache") + ops = maybe_ops + + ops = self.global_filter_ops(ops) + + res: dict[str, cutlass_gemm_op.GemmOperation] = {} + start_time = time.time() + for op in ops: + # if changed, need to also change CUTLASS_OPERATION_KIND + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + log.info( + "Got cutlass configs: total number of ops: %d. Filtering took %.2f seconds", + len(res), + time.time() - start_time, + ) + sorted_res = sorted(res.items()) + ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] + if len(self.filtered_ops_cache) < 50: + self.filtered_ops_cache[self.cache_key] = ret_res + else: + log.debug("Not caching ops since filtered_ops_cache has reached size 50.") + return ret_res + + def gemm_mode(self) -> str: + """ + Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements + a batched GEMM or a simple GEMM without batch dimension. + + Returns: + str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions, + "cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise + "cutlass::gemm::GemmUniversalMode::kGemm" is returned. + """ + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (CUDATemplateKernel): The kernel to be rendered. + op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Cutlass based CUDA C++ code fragment as a string, to be used by the current + CUDATemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance(op, cutlass_gemm_op.GemmOperation), ( + "op argument is required and has to be an instance of GemmOperation" + ) + + if epilogue_nodes and not self._has_tma_epilogue(op): + raise NotImplementedError( + "Non-TMA epilogue visitor tree is not supported in Cutlass." + ) + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + for input_node in self.input_nodes: + if not isinstance(X.layout, FixedLayout): + input_node.freeze_layout() + + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op) + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + + # The layouts might have changed between autotuning and this call if they were FlexibleLayout + # we need to adapt, which might lead to suboptimal performance. + op = self.fix_op_layout(op, X, W, Bias, Y) + + # to make op mutable without affecting others + op = copy.deepcopy(op) + is_scaled_mm = len(self.input_nodes) in (4, 5) + if Bias is not None and not is_scaled_mm: + assert Bias.get_dtype() == X.get_dtype() + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + assert op.C.element == op.D.element, ( + f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}" + ) + + argument_template, epilogue_template = self._get_template_args(op) + should_swap_xw: bool = False + if Bias is not None and self._has_tma_epilogue(op): + if ( + op.epilogue_schedule + != cutlass_lib.EpilogueScheduleType.EpilogueTransposed + and self.should_swap_XW(Bias) + ): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + + name_to_buffer = {node.get_name(): node for node in self.input_nodes} + # handle the fake output buffer during lowering + name_to_buffer[Y.get_name()] = Y # type: ignore[assignment] + + if epilogue_nodes or is_scaled_mm: + if epilogue_nodes: + ( + input_names, + output_names, + var_name_to_buffer_name, + evt_py_code, + ) = CutlassEVTCodegen.ir_to_evt_python_code( + Y.get_name(), epilogue_nodes, V.kernel.removed_buffers + ) + + # TODO: mlazos remove this by returning buffer metadata from + # ir_to_evt_python code + for name, buf in ( + V.graph.name_to_buffer | V.graph.graph_inputs + ).items(): + if name not in name_to_buffer: + name_to_buffer[name] = buf # type: ignore[assignment] + + D_output_name = var_name_to_buffer_name["D"] + D_output_buffer = name_to_buffer[D_output_name] + Y = D_output_buffer # type: ignore[assignment] + # Interestingly, I don't think the rest of the layout matters here since we + # use the properties of the Y buffer to fill in D's properties in the epilogue + # args. This is needed though because it defines types expected in the epilogue args. + op.D.element = cutlass_utils.torch_dtype_to_cutlass_type( + D_output_buffer.get_dtype() + ) + + assert output_names, "There should be at least one write" + + epilogue_inputs = [name_to_buffer[name] for name in input_names] + outputs = [name_to_buffer[name] for name in output_names] + else: # Scaled MM, we read the two scale matrices (and optional bias) and write a single output + bias = None if len(self.input_nodes) < 5 else self.input_nodes[4] + bias_name = bias.get_name() if bias else None + + ( + evt_read_names, + var_name_to_buffer_name, + evt_py_code, + ) = scaled_mm_evt( + self.input_nodes[2].get_name(), # scale_A + self.input_nodes[3].get_name(), # scale_B + bias_name, + Y.get_name(), + ) + + input_names = list(evt_read_names) + output_names = [] # We only need Y + epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]] + if bias: + epilogue_inputs.append(bias) + outputs = [] + + acc_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()] + ) + assert acc_dtype, "Could not determine accumulator dtype" + + evt_name, evt_args, evt_code, evt_arg_renames = self._render_evt( + op, + evt_py_code, + var_name_to_buffer_name, + name_to_buffer, + Y.get_dtype(), + acc_dtype, + ) + + inputs = [ + X, + W, + Bias, + *epilogue_inputs, # type: ignore[list-item] + Y, + *extra_inputs, + ] + input_names = [evt_arg_renames.get(name) for name in input_names] + output_names = [evt_arg_renames.get(name) for name in output_names] + + names_str = ",".join( + ["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names] + ) + else: + evt_name = None + outputs = [Y] + evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + evt_code = "" + + kernel_call_signature = kernel.def_kernel( + inputs=inputs, # type: ignore[arg-type] + outputs=outputs, # type: ignore[arg-type] + names_str=names_str, + input_reorder=input_reorder, + ) + + test_call_statement = self.test_call_statement(kernel, inputs, names_str) + + instance_definition, instance_type = self._define_gemm_instance(op, evt_name) + + options = { + "alpha": self.alpha, + "beta": self.beta, + "X": X, + "W": W, + "Y": Y, + "kernel_call_signature": kernel_call_signature, + "Bias": Bias, + "epilogue_template": epilogue_template, + "argument_template": argument_template, + "should_swap_xw": should_swap_xw, + "template": self, + "kernel": kernel, + "instance_definition": instance_definition, + "instance_type": instance_type, + "input_reorder": self.input_reorder, + "epilogue_args": evt_args, + "test_call_statement": test_call_statement, + "op_conf_name": op.configuration_name(), + "epilogue_visitor_tree": evt_code, + } + options.update(dict(zip(extra_names, extra_inputs))) + res = self._template_from_string(self._get_template()).render(**options) + if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): + test_runner_code = self._template_from_string( + GEMM_STANDALONE_RUNNER_TEMPLATE + ).render(**options) + res += "\n\n" + test_runner_code + + # splice to remove trailing spaces in each line + buf = IndentedBuffer() + buf.splice(res) + return buf.getvalue() + + def test_call_statement( + self, + kernel, + input_nodes, + names_str: str = "", + ) -> str: + """ + Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone + test runner that might also be generated along with the rest of the code, if the corresponding config is + enabled. + + Returns a C++ statement that calls the GEMM operation with the correct arguments. + """ + _, __, arg_types = kernel.args.cpp_argdefs(cutlass_utils.DTYPE_TO_CUTLASS_TYPE) + arg_names = [name.strip() for name in names_str.strip().split(",")] + arg_names = self._update_arg_names_for_test_call_statement( + arg_names, input_nodes + ) + arguments = [ + f"(({arg_type}){arg_name}_data.get())" + for arg_type, arg_name in zip(arg_types, arg_names) + ] + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + buffer_renames: dict[str, str], + name_to_buffer: dict[str, Buffer], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str, EVTArgRenames]: # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented") + + +class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): + """ + CUTLASS 3x GEMM Template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + template = CUTLASS3xGemmTemplate( + input_nodes, + layout, + alpha, + beta, + input_reorder, + use_fast_accum, + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + @functools.lru_cache(1) + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal3x] + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_3X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) + + @staticmethod + def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821 + ) -> bool: # type: ignore[name-defined] + """Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return CUTLASS3xGemmTemplate._has_tma_epilogue(op) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM). + + This function checks compatibility of A, B, and possibly C operand layouts for + a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'. + It verifies requirements such as matching data types, minimum rank, and suitability + for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.aten.mm`, + `addmm`, `bmm`, `baddbmm`, etc. + + Args: + layouts (List[Layout]): List containing 2 or 3 Layout objects representing + the input matrices A, B, and possibly C. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert 2 <= len(layouts) <= 5 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) < 1: + return False + if len(B_layout.size) < 1: + return False + A_size = list(V.graph.sizevars.size_hints(A_layout.size)) + B_size = list(V.graph.sizevars.size_hints(B_layout.size)) + if len(A_size) < 2: + A_size.insert(0, 1) + if len(B_size) < 2: + A_size.insert(1, 1) + # Are batch dims broadcastable? + while len(A_size) < len(B_size): + A_size.insert(0, 1) + while len(B_size) < len(A_size): + B_size.insert(0, 1) + K = max(A_size[-1], B_size[-2]) + M = A_size[-2] + N = B_size[-1] + if K != A_size[-1] and A_size[-1] != 1: + return False + if K != B_size[-2] and B_size[-1] != 1: + return False + # check batch dim broadcastable + for i in range(len(A_size) - 2): + if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1: + return False + if len(layouts) == 3: + C_layout = layouts[2] + C_size = [V.graph.sizevars.size_hint(i) for i in C_layout.size] + while len(C_size) < len(A_size): + C_size.insert(0, 1) + # check batch dims + for i in range(len(A_size) - 2): + bd = max(A_size[i], B_size[i]) + if bd != C_size[i] and C_size[i] != 1: + return False + if len(C_size) > len(A_size): + # This may happen if the last elements of C are contiguous and + # their multiplied size equals the last dim size of B + if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1: + return False + remaining_size = 1 + for i in range(len(A_size) - 1, len(C_size)): + remaining_size *= C_size[i] + if N != remaining_size and remaining_size != 1: + return False + return True + assert len(C_size) == len(A_size) + if M != C_size[-2] and C_size[-2] != 1: + return False + if N != C_size[-1] and C_size[-1] != 1: + return False + return True + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + var_name_to_buffer_name: dict[str, str], + name_to_buffer: dict[str, Buffer], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str, EVTArgRenames]: + from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace + + acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) + output_dtype = torch_dtype_to_cutlass_type(output_dtype) + + examples = create_example_tensors( + var_name_to_buffer_name, + name_to_buffer, # type: ignore[arg-type] + V.graph.sizevars.size_hint, + ) + evt_name, evt_args, evt_code, arg_renames = trace( + evt_py_code, + examples, + acc_dtype, + output_dtype, + op.tile_description, # type: ignore[attr-defined] + op.epilogue_schedule, # type: ignore[attr-defined] + {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc] + V.graph.sizevars.size_hint, + ) + + return ( + evt_name, + evt_args, + evt_code, + arg_renames, + ) + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + has_bias = len(self.input_nodes) == 3 and self.input_nodes[2] is not None + if has_bias: + Bias = self.input_nodes[2] + # bias dtype + op.C.element = cutlass_utils.torch_dtype_to_cutlass_type( + Bias.get_layout().dtype + ) + + # Bias layout + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + op.C.layout = bias_layout + + # Bias alignment + status = self.set_alignment(Bias.get_layout(), op.C) + if not status: + return False + else: + op.C.element = cutlass_lib.DataType.void + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions + + emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] + + if not hasattr(op, "epilogue_functor") or not isinstance( + op.epilogue_functor, enum.Enum + ): + op = copy.deepcopy(op) + op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination + + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + Bias = self.input_nodes[2] if len(self.input_nodes) == 3 else None + inputs: list[Optional[Buffer]] = [] + names: list[str] = [] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[2] is None: + del arg_names[2] + else: + # Reorder them as Bias, A, B + if self.input_reorder is not None: + arg_names[0 : len(self.input_reorder)] = [ + arg_names[i] for i in self.input_reorder + ] + return arg_names + + def render_gemm_arguments( + self, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. + """ + options = { + "alpha": alpha, + "beta": beta, + "X": X, + "W": W, + "Y": Y, + "Bias": Bias, + "template": self, + "kernel": kernel, + "M": "M", + "N": "N", + "epilogue_args": epilogue_args, + } + assert epilogue_template is not None + + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) # type: ignore[union-attr] + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + assert old_layout.device is not None + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), # type: ignore[union-attr] + new_stride, + old_layout.offset, # type: ignore[union-attr] + ) + return Buffer(name=node.get_name(), layout=new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments + + +class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = False, + **extra_kwargs, + ) -> None: + template = CUTLASS2xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal, cutlass_lib.GemmKind.Sparse] + + @staticmethod + def _has_tma_epilogue(self) -> bool: + return False + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_2X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return (GEMM_ARGS_SPARSE_CUTLASS_2X, None) + + return (GEMM_ARGS_CUTLASS_2X, None) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for set of operations supported by this class. + + Args: + layouts (List[Layout]): List containing Layout objects representing + the input matrices. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) != 2: + return False + if len(B_layout.size) != 2: + return False + A_size = [int(i) for i in A_layout.size] + B_size = [int(i) for i in B_layout.size] + K = max(A_size[1], B_size[0]) + return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0] + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + X, W = self.input_nodes[0], self.input_nodes[1] + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return X.get_size()[1] * 2 == W.get_size()[0] + + return X.get_size()[1] == W.get_size()[0] + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + return True + + # SparseGemm in CUTLASS has specific alignment check that for + # small k could make some of the choices throw kMisalignedOperand + # CUTLASS error when run, see: + # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950 + # So, let's skip these choices if that would be the case. + X = self.input_nodes[0] + return (X.get_size()[1] * 2) % op.tile_description.tile_shape[2] == 0 + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + op.C.layout = op.D.layout + return True + + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return False + if not self.set_alignment(Bias.get_layout(), op.C): + return False + else: + op.C.layout = op.D.layout + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + emitter = cutlass_gemm_op.EmitSparseGemmInstance() + else: + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + Bias = None + Meta = self.input_nodes[2] + else: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + Meta = None + inputs = [Meta] + names = ["Meta"] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[3] is None: + del arg_names[3] + if input_nodes[2] is None: + del arg_names[2] + return arg_names + + def render_gemm_arguments( + self, + instance_type: str, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Meta: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + instance_type (str): GEMM instance type. + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Meta (IRNode): The meta tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. + """ + options = { + "instance_type": instance_type, + "alpha": alpha, + "beta": beta, + "X": X, + "W": W, + "Y": Y, + "Bias": Bias, + "Meta": Meta, + "template": self, + "kernel": kernel, + "M": "M", + "N": "N", + "epilogue_args": epilogue_args, + } + + if epilogue_template is None: + arguments = self._template_from_string(argument_template).render( + split_k=1, **options + ) + return arguments + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a17f04b0a1b5a25ee623880eac8daf56a63e8ef4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py @@ -0,0 +1,507 @@ +# mypy: allow-untyped-defs +import functools +import json +from enum import Enum +from typing import Any, Optional + +from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass + + +class CUTLASSOperationSerializer: + """Serializes and deserializes CUTLASS GEMM operations to/from JSON. + + Handles GemmOperation objects and their nested components (TileDescription, TensorDescription). + """ + + # not used, but keeping in case we want to generalize the serializer + _SUPPORTED_CLASSES: list[str] = [ + "GemmOperation", + "GemmKind", + "TileDescription", + "TensorDescription", + "DataType", + "EpilogueFunctor", + "EpilogueFunctor3x", + "SwizzlingFunctor", + "KernelScheduleType", + "EpilogueScheduleType", + "TileSchedulerType", + ] + + @classmethod + def serialize(cls, operation: "GemmOperation") -> str: # type: ignore[name-defined] # noqa: F821 + """Serialize a GEMM operation to JSON string. + + Args: + operation: GemmOperation object + + Returns: + str: JSON string representation of the operation + """ + assert operation.__class__.__qualname__ == "GemmOperation", ( + "Only GemmOperation objects are supported via the main API" + ) + return json.dumps(cls._gemm_operation_to_json(operation)) + + @classmethod + def deserialize(cls, json_str: str) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + """Deserialize JSON string to a GEMM operation. + + Args: + json_str: JSON string of a GEMM operation + + Returns: + GemmOperation: Reconstructed operation + """ + json_dict = json.loads(json_str) + return cls._json_to_gemm_operation(json_dict) + + @classmethod + def _gemm_operation_to_json(cls, operation: "GemmOperation") -> dict[str, Any]: # type: ignore[name-defined] # noqa: F821 + """Convert GemmOperation to JSON-serializable dict. + + Args: + operation: GemmOperation object + + Returns: + dict: Dictionary representation + """ + from cutlass_library.library import TensorDescription + + # Create the main dictionary with required and optional parameters + result = { + # Required parameters + "gemm_kind": cls._enum_to_json(operation.gemm_kind), + "arch": operation.arch, + "tile_description": cls._tile_description_to_json( + operation.tile_description + ), + "A": cls._tensor_description_to_json(operation.A), + "B": cls._tensor_description_to_json(operation.B), + "C": cls._tensor_description_to_json(operation.C), + "element_epilogue": cls._enum_to_json(operation.element_epilogue), + # Optional parameters + "epilogue_functor": cls._enum_to_json(operation.epilogue_functor), + "swizzling_functor": cls._enum_to_json(operation.swizzling_functor), + "D": cls._tensor_description_to_json(operation.D) if operation.D else None, + "kernel_schedule": cls._enum_to_json(operation.kernel_schedule), + "epilogue_schedule": cls._enum_to_json(operation.epilogue_schedule), + "tile_scheduler": cls._enum_to_json(operation.tile_scheduler), + } + + # Process optional attributes + optional_attrs = [ + "mixed_input_mode", + "mixed_input_shuffle", + "ScaleFactorA", + "ScaleFactorB", + "ScaleFactorD", + "ScaleFactorMVecSize", + "ScaleFactorNVecSize", + "ScaleFactorKVecSize", + "ScaleFactorVectorSize", + "is_3x", + ] + + for attr in optional_attrs: + if not hasattr(operation, attr): + continue + + value = getattr(operation, attr) + + if isinstance(value, TensorDescription): + result[attr] = cls._tensor_description_to_json(value) + elif isinstance(value, Enum): + result[attr] = cls._enum_to_json(value) + else: + result[attr] = value + + return result + + @classmethod + def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + """Convert JSON dict to GemmOperation object. + + Args: + json_dict: Dictionary representation + + Returns: + GemmOperation: Reconstructed object + """ + from cutlass_library import DataType + from cutlass_library.gemm_operation import GemmKind, GemmOperation + from cutlass_library.library import ( + EpilogueFunctor, + EpilogueFunctor3x, + EpilogueScheduleType, + KernelScheduleType, + MixedInputMode, + SwizzlingFunctor, + TileSchedulerType, + ) + + # Extract constructor parameters from the JSON dictionary + gemm_kind = cls._json_to_enum(json_dict["gemm_kind"], GemmKind) + arch = json_dict["arch"] + tile_description = cls._json_to_tile_description(json_dict["tile_description"]) + A = cls._json_to_tensor_description(json_dict.get("A"), "A") + B = cls._json_to_tensor_description(json_dict.get("B"), "B") + C = cls._json_to_tensor_description(json_dict.get("C"), "C") + element_epilogue = cls._json_to_enum(json_dict["element_epilogue"], DataType) + + # Get optional parameters with defaults + epilogue_functor = cls._json_to_enum( + json_dict.get("epilogue_functor"), + EpilogueFunctor3x if json_dict.get("is_3x") else EpilogueFunctor, + ) + swizzling_functor = cls._json_to_enum( + json_dict.get("swizzling_functor"), SwizzlingFunctor + ) + D = cls._json_to_tensor_description(json_dict.get("D"), "D") + kernel_schedule = cls._json_to_enum( + json_dict.get("kernel_schedule"), KernelScheduleType + ) + epilogue_schedule = cls._json_to_enum( + json_dict.get("epilogue_schedule"), EpilogueScheduleType + ) + tile_scheduler = cls._json_to_enum( + json_dict.get("tile_scheduler"), TileSchedulerType + ) + + mixed_input_mode = cls._json_to_enum( + json_dict.get("mixed_input_mode"), MixedInputMode + ) + mixed_input_shuffle = json_dict.get("mixed_input_shuffle", False) + + # Scale factors + ScaleFactorA = cls._json_to_enum(json_dict.get("ScaleFactorA"), DataType) + ScaleFactorB = cls._json_to_enum(json_dict.get("ScaleFactorB"), DataType) + + ScaleFactorD = None + if "ScaleFactorD" in json_dict and "ScaleFactorVectorSize" in json_dict: + ScaleFactorD = { + "tensor": cls._json_to_tensor_description( + json_dict.get("ScaleFactorD"), "ScaleFactorD" + ), + "vector_size": json_dict.get("ScaleFactorVectorSize"), + } + + ScaleFactorMVecSize = json_dict.get("ScaleFactorMVecSize") + ScaleFactorNVecSize = json_dict.get("ScaleFactorNVecSize") + ScaleFactorKVecSize = json_dict.get("ScaleFactorKVecSize") + + # Create the GemmOperation with the extracted parameters + operation = GemmOperation( + gemm_kind=gemm_kind, + arch=arch, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + mixed_input_mode=mixed_input_mode, + mixed_input_shuffle=mixed_input_shuffle, + ScaleFactorA=ScaleFactorA, + ScaleFactorB=ScaleFactorB, + ScaleFactorD=ScaleFactorD, + ScaleFactorMVecSize=ScaleFactorMVecSize, + ScaleFactorNVecSize=ScaleFactorNVecSize, + ScaleFactorKVecSize=ScaleFactorKVecSize, + ) + + return operation + + @classmethod + @functools.lru_cache(None) + def _tile_description_to_json(cls, tile_desc: "TileDescription") -> str: # type: ignore[name-defined] # noqa: F821 + """ + Convert TileDescription to JSON string. + + Args: + tile_desc: TileDescription object + + Returns: + str: JSON string representation + """ + + # Create the main dictionary with field names matching TileDescription constructor parameters + result = { + "threadblock_shape": tile_desc.threadblock_shape, + "stages": tile_desc.stages, + "warp_count": tile_desc.warp_count, + "math_instruction": cls._math_instruction_to_json( + tile_desc.math_instruction + ), + "min_compute": tile_desc.minimum_compute_capability, # Store as min_compute for constructor + "max_compute": tile_desc.maximum_compute_capability, # Store as max_compute for constructor + "cluster_shape": tile_desc.cluster_shape, + "explicit_vector_sizes": tile_desc.explicit_vector_sizes, + } + + # Add tile_shape if it exists and differs from threadblock_shape + if ( + hasattr(tile_desc, "tile_shape") + and tile_desc.tile_shape != tile_desc.threadblock_shape + ): + result["tile_shape"] = tile_desc.tile_shape + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_tile_description( + cls, json_dict: Optional[str] + ) -> Optional["TileDescription"]: # type: ignore[name-defined] # noqa: F821 + """ + Convert JSON dict to TileDescription object. + + Args: + json_dict: Dictionary representation + + Returns: + TileDescription: Reconstructed object + """ + if json_dict is None: + return None + + tile_dict = json.loads(json_dict) + + from cutlass_library.library import TileDescription + + math_instruction = cls._json_to_math_instruction(tile_dict["math_instruction"]) + + # Get compute capability values, checking both naming conventions + min_compute = tile_dict.get( + "min_compute", tile_dict.get("minimum_compute_capability") + ) + max_compute = tile_dict.get( + "max_compute", tile_dict.get("maximum_compute_capability") + ) + + # Get cluster shape with default value + cluster_shape = tile_dict.get("cluster_shape", [1, 1, 1]) + + # Create the TileDescription object + tile_desc = TileDescription( + threadblock_shape=tile_dict["threadblock_shape"], + stages=tile_dict["stages"], + warp_count=tile_dict["warp_count"], + math_instruction=math_instruction, + min_compute=min_compute, + max_compute=max_compute, + cluster_shape=cluster_shape, + explicit_vector_sizes=tile_dict.get("explicit_vector_sizes"), + ) + + # Set tile_shape if it exists and differs from threadblock_shape + if ( + "tile_shape" in tile_dict + and tile_dict["tile_shape"] != tile_dict["threadblock_shape"] + ): + tile_desc.tile_shape = tile_dict["tile_shape"] + + return tile_desc + + @classmethod + @functools.lru_cache(None) + def _math_instruction_to_json( + cls, + math_instruction: Optional["MathInstruction"], # type: ignore[name-defined] # noqa: F821 + ) -> Optional[str]: + """Convert MathInstruction to JSON string. + + Args: + math_instruction: MathInstruction object + + Returns: + Optional[str]: JSON string representation or None + """ + if math_instruction is None: + return None + + result = { + "instruction_shape": math_instruction.instruction_shape, + "element_a": cls._enum_to_json(math_instruction.element_a), + "element_b": cls._enum_to_json(math_instruction.element_b), + "element_accumulator": cls._enum_to_json( + math_instruction.element_accumulator + ), + "opcode_class": cls._enum_to_json(math_instruction.opcode_class), + "math_operation": cls._enum_to_json(math_instruction.math_operation), + "element_scale_factor": cls._enum_to_json( + math_instruction.element_scale_factor + ), + } + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_math_instruction( + cls, json_dict: Optional[str] + ) -> Optional["MathInstruction"]: # type: ignore[name-defined] # noqa: F821 + """Convert JSON string to MathInstruction object. + + Args: + json_dict: JSON string representation + + Returns: + Optional[MathInstruction]: Reconstructed object or None + """ + if json_dict is None: + return None + + from cutlass_library import DataType + from cutlass_library.library import MathInstruction, MathOperation, OpcodeClass + + mi_dict = json.loads(json_dict) + + # Convert string enum names back to enum values + element_a = cls._json_to_enum(mi_dict["element_a"], DataType) + element_b = cls._json_to_enum(mi_dict["element_b"], DataType) + element_acc = cls._json_to_enum(mi_dict["element_accumulator"], DataType) + + # Get the opcode_class enum + opcode_class = cls._json_to_enum(mi_dict["opcode_class"], OpcodeClass) + + # Get the math_operation enum + math_op = cls._json_to_enum(mi_dict["math_operation"], MathOperation) + + # Create the MathInstruction object + math_instruction_obj = MathInstruction( + instruction_shape=mi_dict["instruction_shape"], + element_a=element_a, + element_b=element_b, + element_accumulator=element_acc, + opcode_class=opcode_class, + math_operation=math_op, + ) + + # Add element_scale_factor if it exists + if ( + "element_scale_factor" in mi_dict + and mi_dict["element_scale_factor"] is not None + ): + math_instruction_obj.element_scale_factor = cls._json_to_enum( + mi_dict["element_scale_factor"], DataType + ) + + return math_instruction_obj + + @classmethod + @functools.lru_cache(None) + def _tensor_description_to_json( + cls, + tensor_desc: Optional["TensorDescription"], # type: ignore[name-defined] # noqa: F821 + ) -> Optional[str]: + """Convert TensorDescription to JSON string. + + Args: + tensor_desc: TensorDescription object + + Returns: + Optional[str]: JSON string representation or None + """ + if tensor_desc is None: + return None + + result = { + "element": cls._enum_to_json(tensor_desc.element), + "layout": cls._enum_to_json(tensor_desc.layout), + "alignment": tensor_desc.alignment, + "complex_transform": cls._enum_to_json(tensor_desc.complex_transform), + } + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_tensor_description( + cls, + json_dict: Optional[str], + tensor_name: Optional[str] = None, + ) -> Optional["TensorDescription"]: # type: ignore[name-defined] # noqa: F821 + """Convert JSON string to TensorDescription object. + + Args: + json_dict: JSON string representation + tensor_name: Name of the tensor to avoid cache in the same op + + Returns: + Optional[TensorDescription]: Reconstructed object or None + """ + if json_dict is None: + return None + + tensor_dict = json.loads(json_dict) + + from cutlass_library import DataType + from cutlass_library.library import ( + ComplexTransform, + LayoutType, + TensorDescription, + ) + + element = cls._json_to_enum(tensor_dict["element"], DataType) + layout = cls._json_to_enum(tensor_dict["layout"], LayoutType) + alignment = tensor_dict["alignment"] + complex_transform = cls._json_to_enum( + tensor_dict["complex_transform"], ComplexTransform + ) + + return TensorDescription(element, layout, alignment, complex_transform) + + @classmethod + @functools.lru_cache(None) + def _enum_to_json(cls, enum_value: Optional[Enum]) -> Optional[str]: + """Convert enum value to JSON string. + + Args: + enum_value: Enum value + + Returns: + Optional[str]: JSON string representation or None + """ + if enum_value is None: + return None + + result = { + "type": enum_value.__class__.__name__, + "name": enum_value.name, + } + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_enum(cls, json_dict: Optional[str], enum_class: Any) -> Optional[Enum]: + """Convert JSON string to enum value. + + Format: {name: "EnumName", value: 1} + + Args: + json_dict: JSON string representation + enum_class: Target enum class + + Returns: + Optional[Enum]: Reconstructed enum value or None + """ + if json_dict is None: + return None + + enum_dict = json.loads(json_dict) + + return enum_class[enum_dict["name"]] + + +@functools.lru_cache(1) +def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]: + if not try_import_cutlass(): + return None + return CUTLASSOperationSerializer() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..8779a9e86cda65cd7859a4a693ff2fb6a1ddba70 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING, Union + +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from .cuda.cuda_cpp_scheduling import CUDACPPScheduling +from .cutedsl.cutedsl_scheduling import CuteDSLScheduling +from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling +from .triton import TritonScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import TypeAlias + + from sympy import Expr + + import torch + from torch.utils._ordered_set import OrderedSet + + from .common import BackendFeature + + _IntLike: TypeAlias = Union[int, Expr] + + +class CUDACombinedScheduling(BaseScheduling): + """ + Scheduler for CUDA Kernels, which delegates calls as appropriate + to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code, + this would also be the place to do it. + """ + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) + self._cutedsl_scheduling = CuteDSLScheduling(scheduler) + + def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: + return self._triton_scheduling.get_backend_features(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling + if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): + return self._rocm_cpp_scheduling + if self._cutedsl_scheduling.is_cutedsl_template(node): + return self._cutedsl_scheduling + return self._triton_scheduling + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + elif self._cuda_cpp_scheduling.is_cuda_cpp_template( + node1 + ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2): + return False + # CuteDSL doesn't support vertical fusion currently + elif self._cutedsl_scheduling.is_cutedsl_template( + node1 + ) or self._cutedsl_scheduling.is_cutedsl_template(node2): + return False + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + if self._cutedsl_scheduling.is_cutedsl_template(node): + return self._cutedsl_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[_IntLike]] + ) -> tuple[tuple[_IntLike, ...], ...]: + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + assert not prologue_nodes + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node): + assert not epilogue_nodes + assert not prologue_nodes + return self._rocm_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + elif self._cutedsl_scheduling.is_cutedsl_template(template_node): + # TODO remove this when we add epilogue support + assert not epilogue_nodes + assert not prologue_nodes + return self._cutedsl_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + + def codegen_mix_order_reduction(self, node): + return self._triton_scheduling.codegen_mix_order_reduction(node) + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self) -> None: + return self._triton_scheduling.codegen_sync() + + def flush(self) -> None: + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None: + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def benchmark_codegened_module(self, module): + return self._triton_scheduling.benchmark_codegened_module(module) + + def generate_kernel_code_from_nodes( + self, + nodes: Sequence[Any], + benchmark_kernel: bool = False, + hint_override: Optional[int] = None, + ) -> str: + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel, hint_override=hint_override + ) + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + return self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f12fa963fd60c00deb9f36f9515e3e794c9529ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__init__.py @@ -0,0 +1,8 @@ +# mypy: allow-untyped-defs +from .cutedsl_template import CuteDSLTemplate, CuteDSLTemplateCaller + + +__all__ = [ + "CuteDSLTemplate", + "CuteDSLTemplateCaller", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..17f850c8078c8d058bad8007e9cf14b69599003b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py @@ -0,0 +1,29 @@ +# mypy: disable-error-code=import-not-found +# pyrefly: ignore [import-error] +import cutlass.cute as cute + + +@cute.jit # type: ignore[misc] +def ssa_to_indexable(ssa_value: cute.TensorSSA, dtype: str) -> cute.Numeric: + """ + Convert SSA form to indexable non-SSA form. + + Workaround for lack of gather support: SSA values cannot be used directly + as indices in tensor loads. This converts SSA → fragment → scalar for indexing. + """ + frag = cute.make_rmem_tensor(1, dtype) + frag.store(ssa_value) + return frag[0] + + +@cute.jit # type: ignore[misc] +def result_to_ssa(value: cute.Numeric, dtype: str) -> cute.TensorSSA: + """ + Convert non-SSA result back to SSA form. + + After performing operations with non-SSA values (like indexed loads), + convert the result back to SSA form for further computation. + """ + frag = cute.make_rmem_tensor(1, dtype) + frag[0] = value + return frag.load() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..4772ee1541726ec6b016a39f8974d15e676da6c8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -0,0 +1,599 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import logging +import textwrap +from collections.abc import Callable +from typing import Any, Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.common import ( + CSE, + CSEVariable, + IndentedBuffer, + Kernel, + ValueRanges, +) +from torch._inductor.ir import ( + BaseView, + Buffer, + ComputedBuffer, + ExternKernel, + InputBuffer, + MutableBox, + ReinterpretView, +) +from torch._inductor.ops_handler import StoreMode +from torch._inductor.utils import OrderedSet +from torch._inductor.virtualized import V + +from ...utils import sympy_index_symbol +from .cutedsl_op_overrides import CuteDSLOpOverrides + + +# TODO setting the 'main' kernel w/ this suffix. We have 3 should probably just auto generate this +MAIN_SUFFIX = "main" + + +log = logging.getLogger(__name__) +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class CuteDSLKernelWrapper: + """Wrapper to provide .run() interface for CuteDSL kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("CuteDSL kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the CuteDSL kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +@dataclasses.dataclass +class CuteDSLSubgraphInfo: + """Minimal subgraph info for CuteDSL kernels.""" + + body: IndentedBuffer + template_mask: Optional[str] = None + template_out: Optional[str] = None + cse: Optional[CSE[Any]] = None + + def __post_init__(self): + self.only_copy_if_non_none_fields = ("cse",) + + def to_dict(self): + return { + field.name: getattr(self, field.name) for field in dataclasses.fields(self) + } + + +class CuteDSLTemplateKernel(Kernel): + """ + Template kernel implementation for CuteDSL (CUTLASS Python DSL). + Handles code generation and argument management for CuteDSL CUDA kernels. + Provides CuteDSL-specific functionality for tensor conversion and kernel configuration. + """ + + def __init__( + self, + kernel_name: str, + input_nodes: list[Buffer], + output_node: Buffer, + subgraphs: Optional[list[Buffer]] = None, + ) -> None: + # Call parent Kernel constructor + super().__init__() + self.kernel_name = kernel_name + self.input_nodes = input_nodes + self.output_node = output_node + self.subgraphs = subgraphs + self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {} + + # Template attributes + self.body: IndentedBuffer = IndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None + self.template_indices: Optional[list[Any]] = None + self.render_hooks: dict[str, Any] = {} + + # TODO Additional attributes needed by template system + self.prologue_fused_inputs: OrderedSet[str] = OrderedSet() + self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet() + self.named_input_nodes: dict[str, Buffer] = {} + + # Create named input nodes mapping + for i, input_node in enumerate(input_nodes): + node_name = getattr(input_node, "name", f"input_{i}") + self.named_input_nodes[node_name] = input_node + + self.cse = CSE(name_prefix="tmp") + + # Track all tensor buffers added during modification processing + self.collected_tensor_buffers: list[str] = [] + + def kexpr(self, expr: sympy.Expr) -> str: + """Convert sympy expression to CuteDSL string representation.""" + return str(expr) + + def gen_imports(self) -> str: + """Generate common imports for CuteDSL templates.""" + imports = IndentedBuffer() + imports.splice( + """ + import torch + import cutlass + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + import cuda.bindings.driver as cuda + from cutlass._mlir.dialects import math as mlir_math + import operator + from torch._inductor.codegen.cutedsl._cutedsl_utils import ssa_to_indexable, result_to_ssa + """ + ) + return imports.getvalue() + + def gen_defines(self, **kwargs) -> str: + """Generate CuteDSL parameter definitions from kwargs, similar to Triton's gen_defines.""" + params = IndentedBuffer() + for name, val in kwargs.items(): + params.writeline(f"{name}: cutlass.Constexpr = {val}") + return params.getvalue() + + def render(self, template, **kwargs): + from torch._inductor.select_algorithm import PartialRender + + """Render the kernel using the template, returning PartialRender object with hooks.""" + # Available {{}} hooks for jinja rendering + template_env = { + "def_kernel": self.def_kernel, + "gen_defines": lambda: self.gen_defines(**kwargs), + "get_output": self.get_output, + "get_tensor_buffers": self.get_tensor_buffers, + "unpack_buffers": self.unpack_buffers, + "modification": self.modification, + "set_cute_hash": self.set_cute_hash, + } + + # Render the template with the environment and provided kwargs + rendered_code = template.render( + kernel_name=self.kernel_name, + input_nodes=self.input_nodes, + output_node=self.output_node, + **template_env, + **kwargs, + ) + + # Always prepend the common imports + imports = self.gen_imports() + full_code = imports + rendered_code + + return PartialRender(full_code, self.render_hooks) + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + """Set the active subgraph body for template processing.""" + assert all( + hasattr(self, field.name) + for field in dataclasses.fields(CuteDSLSubgraphInfo) + ) + old_state = { + key.name: getattr(self, key.name) + for key in dataclasses.fields(CuteDSLSubgraphInfo) + } + + if body_name not in self.subgraph_bodies: + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + body=IndentedBuffer(), + template_mask=None, + template_out=None, + cse=None, + ) + + subgraph = self.subgraph_bodies[body_name] + for key, value in subgraph.to_dict().items(): + if value is None and key in getattr( + subgraph, "only_copy_if_non_none_fields", () + ): + continue + setattr(self, key, value) + + try: + yield + finally: + # Save current state back to subgraph + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + **{ + key.name: getattr(self, key.name) + for key in dataclasses.fields(CuteDSLSubgraphInfo) + } + ) + # Restore old state + for key, value in old_state.items(): + setattr(self, key, value) + + @contextlib.contextmanager + def create_subgraph_body(self, body_name: str, *, clear_cse: bool = False): + """Create a new subgraph body for template processing.""" + assert body_name not in self.subgraph_bodies, ( + f"Subgraph body '{body_name}' already exists" + ) + new_cse = self.cse.clone() if clear_cse else None + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + body=IndentedBuffer(), + template_mask=None, + template_out=None, + cse=new_cse, + ) + with self.set_subgraph_body(body_name): + yield + + def _get_reinterpret_view(self, node) -> ReinterpretView | None: + """Extract or convert to ReinterpretView from a node, handling all views.""" + while isinstance(node, MutableBox): + node = node.data + if isinstance(node, BaseView): + return ExternKernel.convert_to_reinterpret_view(node) + return None + + def def_kernel(self, *argnames): + """Define kernel function signature for CuteDSL templates. + + When inputs are ReinterpretViews of the same underlying buffer (e.g., Q/K/V + from fused QKV projection), we generate separate arguments for each input + even though they share the same underlying buffer. + """ + renames = IndentedBuffer(initial_indent=1) + + # Track template input args - each input gets its own arg even if buffers are shared + self._template_input_args: list[tuple[str, Buffer]] = [] + self._seen_input_args: OrderedSet[str] = OrderedSet() + + for i, input_node in enumerate(self.input_nodes): + buf_name = input_node.get_name() + # Register with args system (may deduplicate, but we track separately) + self.args.input(buf_name) + + if i < len(argnames): + template_name = argnames[i] + arg_name = f"arg_{template_name}" + self.args.input_buffers[buf_name] = arg_name + renames.writeline(f"{template_name} = {arg_name}") + self._template_input_args.append((arg_name, input_node)) + self._seen_input_args.add(arg_name) + + if self.output_node: + self.args.output(self.output_node.get_name()) + + def hook(): + # Generate signature with template input args plus additional args (output, sizevars) + code = IndentedBuffer() + code.writeline(f"# Kernel function signature: {self.kernel_name}") + + # Start with template input args + params = [arg_name for arg_name, _ in self._template_input_args] + + # Get additional args from python_argdefs (output, sizevars, etc.) + arg_defs, _, _, _ = self.args.python_argdefs() + for arg_def in arg_defs: + if arg_def.full_name() not in self._seen_input_args: + params.append(arg_def.full_name()) + + params.append("stream") + code.writeline( + f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(params)}):" + ) + with code.indent(): + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + # Placeholder-based rendering: hook will be called when template encounters "" + self.render_hooks[""] = hook + return "" + + def get_output(self): + """Get the actual argument name for the output buffer.""" + assert self.output_node, "Output node must exist to get output buffer name" + buf_name = self.output_node.get_name() + output = self.args.output_buffers.get(buf_name, None) + if output is None: + raise ValueError(f"Output buffer '{buf_name}' not found in args") + return output + + def set_cute_hash(self, func_name: str, suffix: str = ""): + """Generate code to set __cute_hash__ on a codegen function. + + This allows hash_callable in flash_attn to skip expensive runtime hashing + for Inductor-generated functions. The hash is based on the kernel name + which already contains a unique hash suffix. + """ + hash_value = f"{self.kernel_name}_{suffix}" if suffix else self.kernel_name + return f'{func_name}.__cute_hash__ = "{hash_value}"' + + def get_tensor_buffers(self): + """Get list of tensor buffer names that were collected during modifications.""" + return self.collected_tensor_buffers + + def unpack_buffers(self, buffer_list_name: str, *, indent_width: int = 4): + """Generate buffer unpacking code via render hook.""" + + def hook(): + tensor_buffers = self.get_tensor_buffers() + if not tensor_buffers: + return "" + + # Generate unpacking assignments: in_ptr4 = buffers[0], etc. + unpacking_lines = [] + for i, buffer_name in enumerate(tensor_buffers): + # pyrefly: ignore [bad-argument-type] + unpacking_lines.append(f"{buffer_name} = {buffer_list_name}[{i}]") + + indent = " " * indent_width + return "\n" + indent + ("\n" + indent).join(unpacking_lines) + + # Register the hook and return placeholder + placeholder = "" + # TODO: I think double invoking is fine for this specific hook + # assert placeholder not in self.render_hooks + self.render_hooks[placeholder] = hook + return placeholder + + def call_kernel(self, name: str, node=None): + """Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel. + + For inputs that are ReinterpretViews (e.g., Q/K/V slices from fused QKV), + we generate reinterpret_tensor() calls to properly handle the views. + """ + wrapper = V.graph.wrapper_code + + # Build call args matching the signature generated in `def_kernel` + call_args = [] + arg_types = [] + + for _, input_node in self._template_input_args: + reinterpret_view = self._get_reinterpret_view(input_node) + if reinterpret_view is not None: + call_args.append(reinterpret_view.codegen_reference()) + else: + call_args.append(input_node.get_name()) + arg_types.append(V.graph.get_dtype(input_node.get_name())) + + # Add additional args from python_argdefs (output, sizevars, ..) + orig_arg_defs, orig_call_args, _, orig_arg_types = self.args.python_argdefs() + for arg_def, call_arg, arg_type in zip( + orig_arg_defs, orig_call_args, orig_arg_types + ): + # dedupe + if arg_def.full_name() not in self._seen_input_args: + call_args.append(call_arg) + arg_types.append(arg_type) + + # TODO this karg really should not be called `triton` + wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types) + + def _get_subgraph(self, subgraph_number: int): + """Get subgraph by number for modification processing.""" + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len(self.subgraphs), ( + f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + ) + assert self.body.getvalue() == "", ( + "Body should be clear before adding a modification" + ) + return self.subgraphs[subgraph_number] + + def modification( + self, + subgraph_number: int, + output_name: Optional[str], + mask: Optional[str] = None, + **fixed_inputs, + ) -> str: + """Generate CuteDSL code for a subgraph modification.""" + # Find unique name to avoid collisions between multiple modifications of same subgraph + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}", clear_cse=True): + subgraph = self._get_subgraph(subgraph_number) + modification_handler = ModificationWrapperCuteDSL( + self, subgraph_number, fixed_inputs, mask + ) + with V.set_kernel_handler(self), V.set_ops_handler(modification_handler): + assert isinstance(subgraph, (ComputedBuffer, list)), ( + f"Expected ComputedBuffer or List[ComputedBuffer], got {type(subgraph)}" + ) + + if isinstance(subgraph, list): + raise NotImplementedError( + "Scatter graphs are not supported for CuteDSL" + ) + + if isinstance(subgraph.data, InputBuffer): + # grad_score_mod can be InputBuffers + out = subgraph.data.make_loader()(()) + else: + # Inline a pointwise lowering into the template + out = subgraph.data.inner_fn(()) + + if output_name is not None: + assert out is not None, ( + f"Expected computation result for named output {output_name}" + ) + self.body.writeline(f"{output_name} = {out.value}") + else: + # Side-effect only: no output assignment (currently only for scatter operations) + raise NotImplementedError( + "Side-effect only modifications not yet supported for CuteDSL" + ) + + # Add Buffers that were added during modification + self.collected_tensor_buffers.extend(modification_handler.tensor_buffers) + + return self.body.getvalue() + + +class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined] + """ + Wrapper handler that enables CuteDSL code generation during subgraph modifications. + + This class sits between the PyTorch IR and CuteDSL code generation, providing: + 1. Operation substitution: converts PyTorch ops to CuteDSL equivalents via CuteDSLOpOverrides + 2. Placeholder handling: resolves fixed_inputs during template processing + 3. Limited operation support: currently restricted to pointwise operations + + """ + + def __init__( + self, + kernel, + subgraph_number: int, + fixed_inputs: dict[str, Any], + mask: Optional[str], + ): + cutedsl_ops = CuteDSLOpOverrides() + super().__init__(cutedsl_ops) + self.name = f"CuteDSLPlaceholderSubstitution_{subgraph_number}" + self.kernel = kernel + self.fixed_inputs = fixed_inputs + self.mask = mask + # Track tensor buffers that get added during modification processing + self.tensor_buffers: list[str] = [] + + def _get_input_dtype(self, name: str) -> torch.dtype: + """Get the dtype for an input from the kernel's named_input_nodes.""" + if name in self.kernel.named_input_nodes: + return self.kernel.named_input_nodes[name].dtype + # TODO: Fallback for common dimension names - should be replaced with proper dtype tracking + return torch.float32 if name not in ("b", "h", "m", "n") else torch.int32 + + def load(self, name: str, index: sympy.Expr): + """Handle loading from tensor or fixed(template args) input for CuteDSL.""" + if name not in self.fixed_inputs: + var = self._add_kernel_input(name) + buffer = V.graph.get_buffer(name) + var_dtype = buffer.dtype + + cute_dtype = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get( + var_dtype, "cutlass.Float32" + ) + renamed_index = self.kernel.rename_indexing(index) + + idx_var = self._emit_scalar_fragment( + self.kernel.kexpr(renamed_index), "cutlass.Int32", torch.int32 + ) + + val_frag = self.kernel.cse.newvar(dtype=var_dtype) + self.kernel.body.writeline( + f"{val_frag} = cute.make_rmem_tensor(1, {cute_dtype})" + ) + + self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{idx_var}])") + + final_expr = f"{val_frag}.load()" + + if ( + var_dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + final_expr = f"({final_expr}).to(cutlass.Float32)" + var_dtype = torch.float32 + + out = self.kernel.cse.generate( + self.kernel.body, + final_expr, + dtype=var_dtype, + bounds=ValueRanges.unknown(), + ) + return out + + value = self.fixed_inputs[name] + dtype = self._get_input_dtype(name) + + return self.kernel.cse.generate( + self.kernel.body, value, bounds=ValueRanges.unknown(), dtype=dtype + ) + + def _emit_scalar_fragment( + self, expr_str: str, cute_dtype: str, torch_dtype: torch.dtype + ) -> str: + """ + Convert SSA expression to indexable scalar for tensor loads. + + Workaround for lack of gather support: SSA values cannot be used directly + as indices. This generates code to convert SSA → indexable scalar. + """ + result = self.kernel.cse.newvar(dtype=torch_dtype) + self.kernel.body.writeline( + f"{result} = ssa_to_indexable({expr_str}, {cute_dtype})" + ) + return str(result) + + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): + """Convert index variable to symbolic form.""" + return sympy_index_symbol(str(index_var)) + + # pyrefly: ignore [bad-override] + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> str: + raise NotImplementedError( + "Store operations not supported - CuteDSL limited to read-only operations" + ) + + def _add_kernel_input(self, name: str): + """Add name as input to kernel and return input ref.""" + # Get the remapped name that will be used in the kernel + remapped_name = self.kernel.args.input(name) + # Track the remapped name for later collection + if remapped_name not in self.tensor_buffers: + self.tensor_buffers.append(remapped_name) + return remapped_name + + def _process_indexing(self, index): + """Process and rename indexing, adding symbols as kernel inputs.""" + renamed = self.kernel.rename_indexing(index) + return self.kernel.kexpr(renamed) + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + try: + return getattr(self._inner, name)(*args, **kwargs) + except NotImplementedError as e: + bar = "=" * 80 + msg = textwrap.dedent(f""" + {bar} + UNSUPPORTED CUTEDSL OPERATION: '{name}' + {bar} + This operation is not yet implemented in Inductor. + + Please open an issue at: https://github.com/pytorch/pytorch/issues + with the following information: + + Operation: {name} + Args: {args!r} + Kwargs: {kwargs!r} + + Title your issue: [CuteDSL] Missing operation: {name} + {bar} + """).strip() + raise NotImplementedError(msg) from e diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3ca75c52adcf96bd1f8e4270eff933b953c1c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py @@ -0,0 +1,360 @@ +# mypy: allow-untyped-defs +""" +CuteDSL-specific operation overrides for pointwise operations. + +This module provides CuteDSL implementations of common operations used in +template kernels, particularly for flex attention modifications. +""" + +import math +from typing import Optional, Union + +import sympy + +import torch +from torch._inductor.codegen.common import CSEVariable, OpOverrides +from torch._inductor.virtualized import OpsValue, V +from torch.utils._sympy.value_ranges import ValueRanges + + +CuteDSLArg = Union[CSEVariable, str] + + +def upcast_compute_type(dtype: torch.dtype) -> torch.dtype: + """Maybe upcast [b]float16 to float32""" + if dtype in (torch.float16, torch.bfloat16): + return torch.float32 + return dtype + + +class CuteDSLOpOverrides(OpOverrides): + """ + CuteDSL-specific operation overrides that generate code using CuteDSL syntax. + + CuteDSL TensorSSA objects have built-in operator overloads (__add__, __mul__, etc.) + and math functions (cute.math.exp, cute.math.sqrt, etc.) + """ + + TORCH_TO_CUTE_DTYPE = { + torch.float16: "cutlass.Float16", + torch.bfloat16: "cutlass.BFloat16", + torch.float32: "cutlass.Float32", + torch.float64: "cutlass.Float64", + torch.int8: "cutlass.Int8", + torch.int16: "cutlass.Int16", + torch.int32: "cutlass.Int32", + torch.int64: "cutlass.Int64", + torch.bool: "cutlass.Boolean", + torch.float8_e4m3fn: "cutlass.Float8E4M3FN", + torch.float8_e5m2: "cutlass.Float8E5M2", + } + + # Math constants + LOG2_E = 1.4426950408889634 # 1/ln(2) for converting natural exp to base-2 exp + + @staticmethod + def _ensure_tensor_ssa(arg: CuteDSLArg, template_tensor: CuteDSLArg) -> str: + """ + Convert scalar arguments to TensorSSA using cute.full_like if needed. + + Args: + arg: The argument to check (CSEVariable for tensors, str for scalars, or OpsValue wrapper) + template_tensor: A tensor argument to use as template for full_like + + Returns: + String representation suitable for CuteDSL operations + """ + if isinstance(arg, CSEVariable): + return str(arg) + + if isinstance(arg, OpsValue) and isinstance(arg.value, CSEVariable): + return str(arg.value) + + if isinstance(template_tensor, CSEVariable): + return f"cute.full_like({template_tensor}, {arg})" + + return str(arg) + + @staticmethod + def _extract_dtype_and_bounds( + *args: CuteDSLArg, + ) -> tuple[Optional[torch.dtype], ValueRanges[sympy.Expr]]: + """Extract dtype and bounds from CSEVariable arguments.""" + for arg in args: + if isinstance(arg, CSEVariable): + return arg.dtype, arg.bounds + return None, ValueRanges.unknown() + + @staticmethod + def _apply_binary_op(a: CuteDSLArg, b: CuteDSLArg, op_format: str) -> CuteDSLArg: + """ + Apply a binary operation with automatic scalar-to-tensor conversion. + + CuteDSL requires both operands to be TensorSSA objects for tensor operations. + This helper automatically converts scalar arguments to TensorSSA using + cute.full_like when at least one argument is a tensor (CSEVariable). + + Args: + a: First operand (CSEVariable for tensors, str for scalars) + b: Second operand (CSEVariable for tensors, str for scalars) + op_format: Format string with {a} and {b} placeholders for the operation + + Returns: + CSEVariable if at least one operand is a CSEVariable, otherwise string + """ + tensor_arg = ( + a + if isinstance(a, CSEVariable) + else b + if isinstance(b, CSEVariable) + else None + ) + if tensor_arg is not None: + a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg) + b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg) + result_expr = op_format.format(a=a_ssa, b=b_ssa) + + dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds(a, b) + + # Create and return CSEVariable using CSE generation for caching + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=bounds, dtype=dtype + ) + + return op_format.format(a=a, b=b) + + @staticmethod + def _apply_unary_op(x: CuteDSLArg, op_format: str) -> CuteDSLArg: + """ + Apply a unary operation, returning CSEVariable if input is CSEVariable. + + Args: + x: Input operand (CSEVariable for tensors, str for scalars) + op_format: Format string with {x} placeholder for the operation + + Returns: + CSEVariable if input is a CSEVariable, otherwise string + """ + if isinstance(x, CSEVariable): + result_expr = op_format.format(x=str(x)) + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=x.bounds, dtype=x.dtype + ) + + return op_format.format(x=x) + + @staticmethod + def constant(value: Union[bool, float, int], dtype: torch.dtype) -> str: + """Generate CuteDSL constant representation.""" + if value == float("-inf"): + return "float('-inf')" + elif value == float("inf"): + return "float('inf')" + elif math.isnan(value): + return "float('nan')" + return repr(value) + + @staticmethod + def add(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} + {b})") + + @staticmethod + def mul(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} * {b})") + + @staticmethod + def sub(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} - {b})") + + @staticmethod + def truediv(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} / {b})") + + @staticmethod + def mod(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})") + + @staticmethod + def remainder(a, b): + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})") + + @staticmethod + def exp(x: CuteDSLArg) -> CuteDSLArg: + """Exponential using CuteDSL cute.math.exp function.""" + return CuteDSLOpOverrides._apply_unary_op( + x, f"cute.math.exp2({{x}} * {CuteDSLOpOverrides.LOG2_E})" + ) + + @staticmethod + def sqrt(x: CuteDSLArg) -> CuteDSLArg: + """Square root using CuteDSL cute.math.sqrt function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sqrt({x})") + + @staticmethod + def log(x: CuteDSLArg) -> CuteDSLArg: + """Natural logarithm using CuteDSL cute.math.log function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.log({x})") + + @staticmethod + def cos(x: CuteDSLArg) -> CuteDSLArg: + """Cosine using CuteDSL cute.math.cos function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.cos({x})") + + @staticmethod + def sin(x: CuteDSLArg) -> CuteDSLArg: + """Sine using CuteDSL cute.math.sin function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sin({x})") + + @staticmethod + def erf(x: CuteDSLArg) -> CuteDSLArg: + """Error function using CuteDSL cute.math.erf function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.erf({x})") + + @staticmethod + def maximum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + raise NotImplementedError("TODO: maximum is not supported yet for TensorSSA") + + @staticmethod + def minimum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + raise NotImplementedError("TODO: minimum is not supported yet for TensorSSA") + + @staticmethod + def where( + condition: CuteDSLArg, + a: CuteDSLArg, + b: CuteDSLArg, + ) -> CuteDSLArg: + """Conditional selection - handles both CSEVariable and string inputs.""" + # Find a tensor argument to use as template for full_like + # Priority: use 'a' if it's a tensor, else use 'b', else condition + tensor_arg = ( + a + if isinstance(a, CSEVariable) + else ( + b + if isinstance(b, CSEVariable) + else condition + if isinstance(condition, CSEVariable) + else None + ) + ) + + if tensor_arg is not None: + a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg) + b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg) + result_expr = f"cute.where({condition}, {a_ssa}, {b_ssa})" + + dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds( + a, b, condition + ) + + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=bounds, dtype=dtype + ) + + return f"cute.where({condition}, {a}, {b})" + + @staticmethod + def pow(a: CuteDSLArg, b: CuteDSLArg): + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} ** {b})") + + @staticmethod + def abs(x: CuteDSLArg) -> CuteDSLArg: + """Absolute value using CuteDSL cute.math.abs function.""" + if isinstance(x, CSEVariable): + x_dtype = x.dtype + elif isinstance(x, OpsValue) and isinstance(x.value, CSEVariable): + x_dtype = x.value.dtype + else: + x_dtype = torch.float32 + + abs_op = ( + "mlir_math.absf" + if x_dtype in (torch.float16, torch.bfloat16, torch.float32) + else "mlir_math.absi" + ) + return CuteDSLOpOverrides._apply_unary_op( + # pyrefly: ignore [bad-argument-type] + x, + f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)", + ) + + @staticmethod + def neg(x: CuteDSLArg) -> CuteDSLArg: + """Negation using CuteDSL TensorSSA __neg__ operator.""" + # TODO: See https://github.com/NVIDIA/cutlass/issues/2584 + return CuteDSLOpOverrides._apply_unary_op( + x, "cute.TensorSSA(-{x}, {x}.shape, {x}.dtype)" + ) + + @staticmethod + def to_dtype( + x: CuteDSLArg, dtype: torch.dtype, src_dtype=None, use_compute_types=True + ) -> CuteDSLArg: + """Type conversion using CuteDSL TensorSSA.to(Type[Numeric]). + + Maps torch dtypes to cutlass.cute.typing numeric types and emits + `{x}.to(cute.typing.)`. + + Raises NotImplementedError for unsigned integer and unsupported dtypes. + """ + # Always convert up from bf16 and fp16 TODO on configuring + dtype = upcast_compute_type(dtype) + + cute_type = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get(dtype) + if cute_type is None: + raise NotImplementedError( + f"CuteDSL dtype cast not implemented for torch dtype: {dtype}" + ) + + if isinstance(x, CSEVariable): + result_expr = f"{str(x)}.to({cute_type})" + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=x.bounds, dtype=dtype + ) + + return f"{x}.to({cute_type})" + + @staticmethod + def tanh(x0: CuteDSLArg) -> CuteDSLArg: + """Hyperbolic tangent using CuteDSL cute.math.tanh function.""" + return CuteDSLOpOverrides._apply_unary_op(x0, "cute.math.tanh({x})") + + # Logical operations + @staticmethod + def logical_and(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} and {b})") + + @staticmethod + def logical_or(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} or {b})") + + @staticmethod + def logical_not(a): + """Logical NOT.""" + return CuteDSLOpOverrides._apply_unary_op(a, "({x} == 0)") + + # Comparison operations + @staticmethod + def eq(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.eq({a}, {b})") + + @staticmethod + def ne(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ne({a}, {b})") + + @staticmethod + def lt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.lt({a}, {b})") + + @staticmethod + def le(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.le({a}, {b})") + + @staticmethod + def gt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.gt({a}, {b})") + + @staticmethod + def ge(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ge({a}, {b})") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc1089a4082acc02f4b039f2fda9c0a726648d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ... import config +from ...codecache import code_hash, get_path +from ...ir import CuteDSLTemplateBuffer +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, +) +from ...select_algorithm import PartialRender +from ...utils import get_fused_kernel_name, get_kernel_metadata +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class CuteDSLScheduling(BaseScheduling): + """ + Scheduling implementation for CuteDSL (CUTLASS Python DSL) kernels. + This class is intended to be used in combination with other schedulers, + and delegated to by CUDACombinedScheduling. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + @staticmethod + def is_cutedsl_template(node: BaseSchedulerNode) -> bool: + """Check if a node is a CuteDSL template.""" + return isinstance(node, SchedulerNode) and isinstance( + node.node, CuteDSLTemplateBuffer + ) + + def is_cutedsl_fused_template(self, node: BaseSchedulerNode) -> bool: + """Check if a node is a fused CuteDSL template.""" + return isinstance(node, FusedSchedulerNode) and self.is_cutedsl_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + TODO CuteDSL doesn't support vertical fusion yet. + This could be extended in the future for epilogue fusion. + """ + return False + + def define_kernel(self, src_code_str: str, node_schedule) -> str: + """Produce the kernel string + Args: + src_code_str: The finalized kernel code string + node_schedule: List of nodes in the schedule + + Note: + This is a little weird since async_compile.cutedsl() has to write the string to + a file in order to cute compile it. Feels bad to have two... + """ + wrapper = V.graph.wrapper_code + + # Use the string as the key for caching + if src_code_str in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code_str] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + kernel_hash = hashlib.sha256(src_code_str.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"cutedsl_{kernel_hash}" + else: + kernel_name = f"cutedsl_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code_str] = kernel_name + src_code_str = src_code_str.replace( + str(Placeholder.KERNEL_NAME), kernel_name + ) + + _, _, kernel_path = get_path(code_hash(src_code_str), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.cutedsl({kernel_name!r}, r'''") + compile_wrapper.splice(src_code_str, strip=True) + compile_wrapper.writeline("''')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CuteDSL template. Currently doesn't support fusion. + """ + assert self.is_cutedsl_template(template_node), ( + "Template node passed to CuteDSLScheduling.codegen_template must be a " + "SchedulerNode that wraps a CuteDSLTemplateBuffer" + ) + # TODO remove when supported + assert not epilogue_nodes, "CuteDSL doesn't support epilogue fusion yet" + assert not prologue_nodes, "CuteDSL doesn't support prologue fusion yet" + + template_node = cast(SchedulerNode, template_node) + ctb: CuteDSLTemplateBuffer = cast(CuteDSLTemplateBuffer, template_node.node) + + kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] + template_node.mark_run() + src_code = render() + # Finalize PartialRender if needed + if isinstance(src_code, PartialRender): + src_code_str = src_code.finalize_all() + else: + src_code_str = src_code + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code_str, node_schedule) + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bf30480981378daca74cf4ab4b1e4c01e8065e79 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_template.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from collections.abc import Iterable +from typing import Any, Optional, Union +from unittest.mock import patch + +from torch._inductor.ir import ShapeAsConstantBuffer +from torch._inductor.utils import Placeholder +from torch._inductor.virtualized import V +from torch._logging import getArtifactLogger + +from ...autotune_process import CuteDSLBenchmarkRequest, TensorMeta +from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, IRNode, Layout, TensorBox +from ..common import KernelTemplate +from .cutedsl_kernel import CuteDSLTemplateKernel + + +log = getArtifactLogger(__name__, "output_code") + + +class CuteDSLTemplate(KernelTemplate): + """Template for generating CuteDSL (CUTLASS Python DSL) kernels.""" + + kernel_type: type[Any] = CuteDSLTemplateKernel + index_counter = itertools.count() + all_templates: dict[str, "CuteDSLTemplate"] = {} + + def __init__( + self, + name: str, + source: str, + subgraph_fn: Optional[Any] = None, + mask_fn: Optional[Any] = None, + ) -> None: + super().__init__(name) + self.source = source + self.subgraph_fn = subgraph_fn + self.mask_fn = mask_fn + self.template = CuteDSLTemplate._template_from_string(source) + assert name not in self.all_templates, f"duplicate template name, {name}" + CuteDSLTemplate.all_templates[name] = self + + @staticmethod + @functools.lru_cache(None) + # pyrefly: ignore [bad-override] + def _template_from_string(source: str) -> Any: + return KernelTemplate._template_from_string(source) + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + """ + try: + choices.append(self.generate(**kwargs)) + return None + except NotImplementedError as e: + log.debug("CuteDSL template choice generation failed: %s", e) # noqa: G200 + return e + except Exception as e: + log.debug("CuteDSL template choice generation error: %s", e) # noqa: G200 + return NotImplementedError(f"CuteDSL template failed: {e}") + + def generate(self, **kwargs: Any) -> ChoiceCaller: + """Generate the CuteDSL kernel caller.""" + input_nodes = kwargs.pop("input_nodes") + layout = kwargs.pop("layout") + mutated_inputs = kwargs.pop("mutated_inputs", None) + subgraphs = kwargs.pop("subgraphs", None) + + kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}" + + if self.template is None: + raise RuntimeError("Template compilation failed (Jinja2 required)") + + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Patch V.graph.get_dtype to handle the fake buf_out buffer + with patch.object( + V.graph, "get_dtype", KernelTemplate._fake_get_dtype(self.output_node) + ): + kernel = self.kernel_type( + kernel_name=kernel_name, + input_nodes=input_nodes, + output_node=self.output_node, + subgraphs=subgraphs, + ) + code = kernel.render(self.template, **kwargs) + + log.debug("Generated CuteDSL Code:\n%s", code) + + bmreq = CuteDSLBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=tuple(), + source_code=code, + ) + + def make_kernel_render(out_node, hint_override: Optional[int] = None): + """ + Factory function that creates a kernel renderer for the final output. + + This closure captures the current template and parameters, but allows + the output node to be specified later. This is used during the final + kernel selection phase when the actual output buffer is available. + """ + render_kernel = self.kernel_type( + kernel_name=str(Placeholder.KERNEL_NAME), + input_nodes=input_nodes, + output_node=out_node, + subgraphs=subgraphs, + ) + + def render(): + return render_kernel.render(self.template, **kwargs) + + return render_kernel, render + + return CuteDSLTemplateCaller( + name=kernel_name, + input_nodes=input_nodes, + layout=layout, + make_kernel_render=make_kernel_render, + bmreq=bmreq, + template=self, + mutated_inputs=mutated_inputs, + ) + + +class CuteDSLTemplateCaller(ChoiceCaller): + """Caller for CuteDSL templates that integrates with the autotuning system.""" + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Any, + bmreq: CuteDSLBenchmarkRequest, + template: "CuteDSLTemplate", + mutated_inputs: Optional[Iterable[IRNode]] = None, + ): + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + description=f"CuteDSL template {name}", + ) + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.mutated_inputs = mutated_inputs + + def __str__(self) -> str: + return f"CuteDSLTemplateCaller({self.name})" + + def benchmark(self, *args, out) -> float: + """Benchmark the kernel execution.""" + return self.bmreq.benchmark(*args, out=out) + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + """Create the output node for this template choice.""" + return TensorBox.create( + CuteDSLTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + mutated_inputs=self.mutated_inputs, + ) + ) + + def call_name(self) -> str: + """Return the kernel call name.""" + return self.name + + def to_callable(self) -> Any: + """Return callable that can execute this kernel.""" + return self.make_kernel_render + + def hash_key(self) -> str: + """Return unique hash key for this choice.""" + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def info_dict(self) -> dict[str, Any]: + """Return information about this kernel.""" + return { + "name": self.name, + "backend": "CuteDSL", + "template": self.template.name, + } diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b465e3d1ffab27bf67fca9a54e8eb6da6f9843d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py @@ -0,0 +1,290 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +import os +from enum import Enum +from typing import Optional, TYPE_CHECKING + +import torch +from torch import dtype as torch_dtype + +from .. import config +from ..virtualized import V +from .multi_kernel import MultiKernel + + +if TYPE_CHECKING: + from collections.abc import Callable + + +log = logging.getLogger(__name__) + + +def _print_debugging_tensor_value_info(msg, arg): + # helper for printing debugging stats for intermediate tensor values + # at jit inductor level codegen + max_numel_to_print = 64 + print(msg) + if not isinstance(arg, torch.Tensor): + print("Value: ", arg) + return + numel = arg.float().numel() + # print the debug printing stats + if numel <= max_numel_to_print: + print(arg) + print("Number of elements: ", numel) + print("Size: ", arg.float().size()) + print("Dtype: ", arg.float().mean().item()) + print("Mean: ", arg.float().mean().item()) + print("Min: ", arg.float().min().item()) + print("Max: ", arg.float().max().item()) + print("Std: ", arg.float().std().item()) + + +# AOTI debug printing related configs +class IntermediateValueDebuggingLevel(Enum): + # OFF: No intermediate tensor value debug info will be printed or saved. + OFF = "0" + # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed. + SAVE_ONLY = "1" + # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed. + PRINT_ONLY = "2" + # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed. + # This mode can be helpful in cases when you just want to pinpointing what kernel is running into a CUDA IMA issue, etc. + PRINT_KERNEL_NAMES_ONLY = "3" + + +class DebugPrinterManager: + def __init__( + self, + debug_printer_level, + use_array_ref: bool, + writeline: Optional[Callable[..., None]] = None, + args_to_print_or_save: Optional[list[str]] = None, + kernel_name: str = "", + kernel=None, + arg_signatures: Optional[list[type]] = None, + kernel_type=None, + ): + self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level) + self.use_array_ref = use_array_ref + if args_to_print_or_save is None: + args_to_print_or_save = [] + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures: Optional[list[type]] = None + self.kernel = kernel + self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names() + self.kernel_type = None + + def __enter__(self): + self._perform_debug_print_or_save_helper( + self.args_to_print_or_save, + self.kernel_name, + before_launch=True, + arg_signatures=self.arg_signatures, + ) + + def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures): + self._perform_debug_print_or_save_helper( + args_to_print_or_save, + kernel_name, + before_launch=False, + arg_signatures=arg_signatures, + ) + + def _perform_debug_print_or_save_helper( + self, + args_to_print_or_save, + kernel_name, + before_launch, + arg_signatures: Optional[list[type]] = None, + ): + if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF: + return + if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY: + # by default save all the tensor values before launch + self.codegen_intermediate_tensor_value_save( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # by default print all the tensor values before launch + self.codegen_intermediate_tensor_value_print( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + # Print all kernel names to the console only + self.codegen_intermediate_tensor_value_print( + [], + self.kernel_name, + before_launch, + ) + + @functools.lru_cache # noqa: B019 + def _get_debug_filtered_kernel_names(self) -> list[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] + return [ + x.strip() + for x in config.aot_inductor.filtered_kernel_names.lower().split(",") + ] + + def set_printer_args( + self, + args_to_print_or_save: list[str], + kernel_name: str, + arg_signatures: Optional[list[type]], + kernel, + kernel_type=None, + ): + # Note: MultiKernel debug printing is not supported for now + if isinstance(kernel, MultiKernel): + log.info( + "MultiKernel type is not supported in AOTI debug printer tool yet." + ) + self.debug_printer_level = IntermediateValueDebuggingLevel.OFF + + self.kernel_type = kernel_type + # Note: if the kernel type is an extern kernel (or cpp kernel), we do a special handling to + # get the list of args_to_print_or_save + # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls + if kernel_type == "extern": + args_to_print_or_save_extern = [ + arg + for arg in args_to_print_or_save + if isinstance(arg, str) and arg.startswith(("buf", "arg")) + ] + self.args_to_print_or_save = args_to_print_or_save_extern + elif kernel_type == "cpp": + self.args_to_print_or_save = [ + ( + f"copy_arrayref_tensor_to_tensor({arg})" + if self.use_array_ref + else arg + ) + for arg in args_to_print_or_save + if isinstance(arg, str) and arg.startswith(("buf", "arg")) + ] + else: + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures = arg_signatures + self.kernel = kernel + + def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None: + if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: + return + for arg in input_args_to_print: + if V.graph.cpp_wrapper: + V.graph.wrapper_code.prefix.writeline( + f'aoti_torch_print_tensor_handle({arg}, "aoti_model_inputs - {arg}");' + ) + + def codegen_intermediate_tensor_value_save( + self, + args_to_save, + kernel_name, + before_launch=True, + arg_signatures: Optional[list[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_save): + if arg_signatures is not None and not isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + continue + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) + else: + cwd = os.getcwd() + saved_dir = cwd + "/tmp/jit_inductor/" + if not os.path.exists(saved_dir): + log.info( + "Creating directory to save inductor intermediate tensor values." + ) + os.makedirs(saved_dir) + # Save the model to the directory + saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt" + log.info( + "Saved intermediate tensor %s for %s to %s", + arg, + kernel_name, + saved_path, + ) + line = f"torch.save({arg}, '{saved_path}')" + V.graph.wrapper_code.writeline(line) + + def codegen_intermediate_tensor_value_print( + self, + args_to_print, + kernel_name, + before_launch=True, + arg_signatures: Optional[list[type]] = None, + ) -> None: + launch_prefix = "before_launch" if before_launch else "after_launch" + + # if the debug printing level is PRINT_KERNEL_NAMES_ONLY + # we only print the kernel name to the console + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + if V.graph.cpp_wrapper: + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix}: {kernel_name} ]\\n");' + ) + return + + if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: + return + for i, arg in enumerate(args_to_print): + # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY, + # check if filtered kernel name list is provided + if ( + len(self.filtered_kernel_names_to_print) > 0 + and kernel_name.lower() not in self.filtered_kernel_names_to_print + ): + continue + if V.graph.cpp_wrapper: + if arg_signatures is not None and isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + elif arg_signatures is not None and isinstance( + arg_signatures[i], + ( + type(torch._inductor.codegen.wrapper.SymbolicCallArg), + type(int), + type(float), + type(bool), + ), + ): + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\\\n");' + ) + else: + if arg_signatures is None and self.kernel_type in ("cpp", "extern"): + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + else: + V.graph.wrapper_code.writeline( + f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})' + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py new file mode 100644 index 0000000000000000000000000000000000000000..e47e8e6d7841d4b70b7b41f2298bcd083fe2b8ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py @@ -0,0 +1,1732 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import re +from collections import defaultdict +from math import inf +from typing import Any, cast, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +import torch._logging + +from ..._prims_common import is_integer_dtype +from ...utils._ordered_set import OrderedSet +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. import config, ir +from ..codecache import HalideCodeCache +from ..ir import get_reduction_combine_fn +from ..metrics import is_metric_table_enabled, log_kernel_metadata +from ..ops_handler import AddParenHandler +from ..runtime.hints import HalideInputSpec, HalideMeta +from ..utils import ( + get_bounds_index_expr, + get_kernel_metadata, + parallel_num_threads, + sympy_index_symbol, + sympy_subs, +) +from ..virtualized import _ops as ops, V +from .common import ( + BackendFeature, + CSEVariable, + DeferredLine, + IndentedBuffer, + KernelArgType, + OpOverrides, + PythonPrinter, + SizeArg, + TensorArg, +) +from .cpp import DTYPE_TO_CPP +from .cpp_utils import cexpr +from .simd import constant_repr, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ..ops_handler import ReductionType, StoreMode + from ..shape_propagation import BlockShapeType + +log = logging.getLogger(__name__) + + +def halide_constant(val): + if isinstance(val, int) and not (-2147483648 <= val <= 2147483647): + info = torch.iinfo(torch.int64) + if val == info.min: + return "hl.Int(64).min()" + if val == info.max: + return "hl.Int(64).max()" + return f"hl.i64({val!r})" + if isinstance(val, float): + return f"hl.f64({constant_repr(val)})" + return repr(val) + + +class Unsupported(RuntimeError): + def __init__(self, thing) -> None: + super().__init__(f"halide backend does not support: {thing}") + + +class HalidePrinter(PythonPrinter): + @staticmethod + def cast_index(expr): + return f"hl.cast({V.kernel.index_dtype}, {expr})" + + @staticmethod + def cast_float(expr): + return f"hl.cast(hl.Float(32), {expr})" + + def _print_Float(self, expr): + return f"hl.f32({expr})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"hl.f32({self._print(expr.args[0])})" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.floor({self._print(expr.args[0])})") + + _print_FloorToInt = _print_floor + + def _print_Trunc(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") + + _print_TruncToInt = _print_Trunc + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.ceil({self._print(expr.args[0])})") + + def _helper_sqrt(self, expr): + return f"hl.sqrt({self.cast_float(self._print(expr))})" + + def _print_Where(self, expr): + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"hl.select({c}, {p}, {q})" + + def _print_Min(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Min(*expr.args[:mid])) + b = self._print(sympy.Min(*expr.args[mid:])) + return f"hl.min({a}, {b})" + + def _print_Max(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Max(*expr.args[:mid])) + b = self._print(sympy.Max(*expr.args[mid:])) + + return f"hl.max({a}, {b})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.abs({self._print(expr.args[0])})") + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"hl.cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"hl.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"hl.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"hl.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"hl.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"hl.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"hl.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"hl.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"hl.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr): + raise NotImplementedError("log2") + + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.cast_float(self.doprint(x)) + div = self.cast_float(self.doprint(div)) + return self.cast_index(f"hl.floor({x} / {div})") + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.round({self._print(expr.args[0])})") + + _print_RoundToInt = _print_Round + + def _print_IntTrueDiv(self, expr): + a, b = expr.args + # force a cast to float + return f"({a}) / ({b}+hl.f32(0))" + + def _print_RoundDecimal(self, expr): + val, n = expr.args + val = self._print(val) + n = int(n) + return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))" + + +texpr = HalidePrinter().doprint +pexpr = PythonPrinter().doprint + + +_halide_type = { + torch.bool: "hl.Bool()", + torch.bfloat16: "hl.BFloat(16)", + torch.float16: "hl.Float(16)", + torch.float32: "hl.Float(32)", + torch.float64: "hl.Float(64)", + torch.int8: "hl.Int(8)", + torch.int16: "hl.Int(16)", + torch.int32: "hl.Int(32)", + torch.int64: "hl.Int(64)", + torch.uint8: "hl.UInt(8)", + torch.uint16: "hl.UInt(16)", + torch.uint32: "hl.UInt(32)", + torch.uint64: "hl.UInt(64)", +} + + +def halide_type(dtype): + return _halide_type[dtype] + + +def halide_acc_type(dtype): + if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64: + dtype = torch.int32 + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + return halide_type(dtype) + + +class HalideOverrides(OpOverrides): + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + if dtype == torch.bool: + return f"({x} != 0)" + return f"hl.cast({halide_type(dtype)}, {x})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + if src_dtype in (torch.float16, torch.bfloat16): + x = f"hl.cast({halide_type(src_dtype)}, {x})" # body compute is upcast to fp32 + line = f"hl.reinterpret({halide_type(dtype)}, {x})" + if dtype in (torch.float16, torch.bfloat16): + line = f"hl.cast(hl.Float(32), {line})" + return line + + @classmethod + def constant(cls, value, dtype): + return cls.to_dtype(halide_constant(value), dtype) + + @staticmethod + def abs(x): + return f"hl.abs({x})" + + @staticmethod + def exp(x): + if not hasattr(x, "name"): + return f"hl.exp({x})" + return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})" + + @staticmethod + def sqrt(x): + return f"hl.sqrt({x})" + + @staticmethod + def minimum(a, b): + # return f"hl.min({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.min({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})" + + @staticmethod + def maximum(a, b): + # return f"hl.max({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.max({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})" + + @staticmethod + def where(a, b, c): + if hasattr(b, "name"): + c = f"hl.cast({b.name}.type(), {c})" + return f"hl.select({a}, {b}, {c})" + + @staticmethod + def cos(x): + return f"hl.cos({x})" + + @staticmethod + def sin(x): + return f"hl.sin({x})" + + @staticmethod + def lgamma(x): + raise Unsupported("lgamma") + + @staticmethod + def erf(x): + return f"hl.erf({x})" + + @staticmethod + def cosh(x): + return f"hl.cosh({x})" + + @staticmethod + def sinh(x): + return f"hl.sinh({x})" + + @staticmethod + def acos(x): + return f"hl.acos({x})" + + @staticmethod + def acosh(x): + return f"hl.acosh({x})" + + @staticmethod + def asin(x): + return f"hl.asin({x})" + + @staticmethod + def asinh(x): + return f"hl.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"hl.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"hl.atan({x})" + + @staticmethod + def atanh(x): + return f"hl.atanh({x})" + + @staticmethod + def copysign(x, y): + raise Unsupported("copysign") + + @staticmethod + def erfinv(x): + raise Unsupported("erfinv") + + @staticmethod + def hypot(x, y): + return f"hl.hypot({x}, {y})" + + @staticmethod + def nextafter(x, y): + raise Unsupported("nextafter") + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + return f"halide_helpers.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + return f"halide_helpers.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}" + + @staticmethod + def rsqrt(x): + # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues + return f"1./hl.sqrt({x})" + + @staticmethod + def tan(x): + return f"hl.tan({x})" + + @staticmethod + def tanh(x): + return f"hl.tanh({x})" + + @staticmethod + def signbit(x): + return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0" + + @staticmethod + def fmod(a, b): + # TODO(jansel): find a better way to do this, builtin % has wrong sign + return f"{a} - hl.trunc({a}/{b})*{b}" + + @staticmethod + def pow(a, b): + return f"hl.pow({a}, {b})" # hl.fast_pow fails accuracy + + @staticmethod + def log(x): + return f"hl.log({x})" # hl.fast_log fails accuracy + + @staticmethod + def log2(x): + raise NotImplementedError("log2") + + @staticmethod + def isinf(x): + # workaround https://github.com/halide/Halide/issues/8309 + return f"hl.is_inf(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def isnan(x): + # workaround https://github.com/halide/Halide/issues/8309 + return f"hl.is_nan(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def round(x): + return f"hl.round({x})" + + @staticmethod + def floor(x): + return f"hl.floor({x})" + + @staticmethod + def int_truediv(a, b): + return f"({a}) / ({b} + hl.f32(0))" + + @staticmethod + def floordiv(a, b): + # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work + return ( + f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @classmethod + def sign(cls, x): + left = ops.to_dtype(ops.lt("0", x), torch.int8) + right = ops.to_dtype(ops.lt(x, "0"), torch.int8) + sub = ops.sub(left, right) + return f"hl.cast({x.name}.type(), {sub})" + + @staticmethod + def trunc(x): + return f"hl.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # this causes crashes with floating point exception, see test_div_zero_dim_cpu + # return f"hl.div_round_to_zero({a}, {b})" + return ( + f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @staticmethod + def ceil(x): + return f"hl.ceil({x})" + + @staticmethod + def relu(x): + return f"hl.max({x}, 0)" + + @classmethod + def index_expr(cls, expr, dtype): + index = V.kernel.prepare_indexing(expr) + var = V.kernel.genfunc( + V.kernel.index_to_str(index), + V.kernel.used_dims_from_index(index), + bounds=get_bounds_index_expr(expr), + ) + if dtype not in (torch.int32, torch.int64): + return ops.to_dtype(var, dtype) + return var + + @classmethod + def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True): + # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow + index_var = ops.to_dtype(index_var, torch.int32) + index_var = ops.halide_clamp(index_var, size, check) + index_var.indirect_indexing_size = size + return sympy_index_symbol(str(index_var)) + + @classmethod + def halide_clamp(cls, value, size, check): + end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1) + if not isinstance(size, (int, sympy.Integer)): + end = f"hl.cast({value.name}.type(), {end})" + # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692 + # return f"hl.unsafe_promise_clamped({value}, 0, {end})" + return f"hl.clamp({value}, 0, {end})" + + @staticmethod + def masked(mask, body, other): + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) + + # Take dtype from result to prevent accidental promotion + other = V.kernel.genfunc( + f"hl.cast({result.name}.type(), {halide_constant(other)})", + [], + bounds=ValueRanges.wrap(other), + shape=result.shape, + ) + # TODO(jansel): look into removing the where in the same places triton does + return ops.where(new_mask, result, other) + + @staticmethod + def frexp(x): + raise NotImplementedError("frexp") + + @staticmethod + def device_assert_async(cond, msg): + raise NotImplementedError("device_assert_async") + + @staticmethod + # pyrefly: ignore [bad-override] + def partial_accumulate( + name: str, + reduction_type: str, + value: CSEVariable, + extra_meta: dict[str, Any], + ) -> None: + raise NotImplementedError + + +HalideOverrides._initialize_pointwise_overrides("halide") + + +class HalideCSEVariable(CSEVariable): + undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") + + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + shape: BlockShapeType = None, + ) -> None: + super().__init__(name, bounds, dtype, shape=shape) + self.used_dims: Optional[list[sympy.Symbol]] = None + + def update_on_args(self, name, args, kwargs): + used = OrderedSet(self.used_dims or ()) + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, HalideCSEVariable): + assert arg.used_dims is not None, (name, arg, args) + used.update(arg.used_dims) + self.used_dims = V.kernel.sort_used_dims(used) + + def index_str(self, dims): + if len(dims) == 0: + return f"{self.name}[()]" + # Reversed since Halide is column major + return f"{self.name}[{', '.join(map(str, dims))}]" + + def __str__(self) -> str: + if self.used_dims is None: + # This will get recomputed and replaced in codegen_kernel() + return f"{self.name}[?]" + return self.index_str(self.used_dims) + + def subs_str(self, replacements): + assert self.used_dims is not None and all( + isinstance(x, sympy.Expr) for x in self.used_dims + ) + return self.index_str([replacements.get(n, n) for n in self.used_dims]) + + +@dataclasses.dataclass +class DimensionInfo: + expr: Optional[sympy.Expr] + size: sympy.Expr + stride: sympy.Expr + + def __init__(self, expr, size, stride) -> None: + super().__init__() + if V.graph.sizevars.statically_known_lt(stride, 0): + stride = -stride + expr = -expr + self.expr = expr + self.size = size + self.stride = stride + + def index_str(self, replacements=None, zero_vars=False): + assert self.expr is not None + expr = self.expr + if zero_vars and expr == 0: + return "hl.Var()" + if replacements: + replacements = {**replacements} + # pyrefly: ignore [missing-attribute] + for sym in expr.free_symbols: + if symbol_is_type(sym, SymT.TMP): + assert isinstance(sym, sympy.Symbol) + var = V.kernel.lookup_cse_var(sym.name) + assert isinstance(var, HalideCSEVariable) + replacements[sym] = sympy_index_symbol(var.subs_str(replacements)) + expr = sympy_subs(expr, replacements) + return V.kernel.index_to_str(expr) + + +def eq(left, right): + if V.graph.sizevars.statically_known_equals(left, right): + return True + try: + a = V.graph.sizevars.size_hint_or_throw(left) + b = V.graph.sizevars.size_hint_or_throw(right) + except TypeError: # unbacked symints + return False + if a == b: + V.graph.sizevars.check_equals(left, right) + return a == b + + +def lt(left, right): + if V.graph.sizevars.statically_known_lt(left, right): + return True + try: + a = V.graph.sizevars.size_hint_or_throw(left) + b = V.graph.sizevars.size_hint_or_throw(right) + except TypeError: # unbacked symints + gcd = sympy.gcd(left, right) + if gcd == left: + return left != right + return False + if a < b: + V.graph.sizevars.check_lt(left, right) + return a < b + + +class HalideKernel(SIMDKernel): + overrides = HalideOverrides # type: ignore[assignment] + kexpr: Callable[[sympy.Expr], str] = texpr + + def __init__( + self, + tiling: dict[str, sympy.Expr], + **kwargs, + ) -> None: + super().__init__(tiling, **kwargs) + # For halide, we just write directly to the body + self.compute = self.body + self.loads = self.body + self.stores = self.body + self.indexing_code_dom = IndentedBuffer() + self.needs_dom_indexing = self.inside_reduction + self.has_reduction = self.inside_reduction + self.buffer_dimensions: dict[str, list[DimensionInfo]] = {} + self.buffer_offsets: dict[str, sympy.Expr] = {} + # {h0: size1, h1: size2, ...} + self.halide_vars: dict[sympy.Symbol, sympy.Expr] = {} + # {x0: h0, x1: h1+10*h2, ...} + self.index_replacements: dict[sympy.Expr, sympy.Expr] = {} + # {h1: hr1, ...} + self.reduction_renames: dict[sympy.Symbol, sympy.Symbol] = {} + # {"i": {h0: hi0}, "o": ...} + self.dom_renames: dict[str, dict[sympy.Symbol, sympy.Symbol]] = {} + # {"in_ptr0": ["in_ptr0_view0"], ...} + self.buffer_aliases: dict[str, list[str]] = defaultdict(list) + self.has_indirect_indexing = False + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return halide_type(dtype) + + # pyrefly: ignore [bad-override] + def create_cse_var(self, name, bounds=None, dtype=None, shape=None): + self.body.writeline(f"{name} = hl.Func({name!r})") + # pyrefly: ignore [bad-argument-type] + return HalideCSEVariable(name, bounds, dtype, shape) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]): + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + + This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing + scheme that avoids using divide and modulus. Instead of xindex/yindex/rindex + we base indexing on a larger number of vars whose product combines to those. + + This function populates self.halide_vars, self.index_replacements, and self.reduction_renames + """ + assert not ( + self.index_replacements or self.halide_vars or self.reduction_renames + ) + size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type] + # pyrefly: ignore [bad-assignment] + indices = dict.fromkeys(map(super().prepare_indexing, indices)) + all_used_symbols = OrderedSet[Any]() + sym_to_node = { + n.symbol(): n + for n in itertools.chain.from_iterable( + [tree.nodes.values() for tree in self.range_trees] + ) + } + + def simplify(expr): + return sympy.simplify( + V.graph.sizevars.remove_precomputed_replacements(expr) + ) + + def visit_modular_indexing(base, divisor, modulus): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + V.graph.sizevars.evaluate_min( + modulus, FloorDiv(node.length, divisor) + ), + ).symbol() + ) + + def visit_floor_div(base, divisor): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + FloorDiv(node.length, divisor), + ).symbol() + ) + + # first figure out all_used_symbols to do dead symbol elimination + for index in indices: + if index.has(ModularIndexing): + index.replace( + ModularIndexing( + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), + ), + visit_modular_indexing, + ) + if index.has(FloorDiv): + index.replace( + FloorDiv( + sympy.Wild("base"), + sympy.Wild("divisor"), + ), + visit_floor_div, + ) + all_used_symbols.update(super().prepare_indexing(index).free_symbols) + + self.has_indirect_indexing = any( + symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols + ) + + had_fallback = False + for tree in reversed(self.range_trees): + nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols] + nodes.sort(key=lambda n: size_hint(n.divisor)) + if not nodes: + nodes.append(tree.lookup(1, tree.numel)) + handled_count = 0 + divisor = sympy.S.One + added_sym_size = [] + # decide on a minimal set of symbols and put them in self.halide_vars + while handled_count < len(nodes) and not eq(tree.numel, divisor): + sizes_to_add = [ + simplify(n.length) for n in nodes if eq(n.divisor, divisor) + ] + handled_count += len(sizes_to_add) + assert sizes_to_add, nodes + end = divisor * functools.reduce( + V.graph.sizevars.evaluate_max, sizes_to_add + ) + sizes_to_add.extend( + [ + simplify(n.divisor / divisor) + for n in nodes + if lt(divisor, n.divisor) and lt(n.divisor, end) + ] + ) + while sizes_to_add: + next_size = functools.reduce(sympy.gcd, sizes_to_add) + if eq(next_size, 1): + # sizes share no common factors, e.g [2, 21, 42, 441, 889056] + # TODO(jansel): we should just prevent fusion in cases that hit this + next_size = simplify(tree.numel / divisor) + assert not eq(next_size, 1) + sizes_to_add = [] + handled_count = len(nodes) + had_fallback = True + sym = sympy_index_symbol(f"h{len(self.halide_vars)}") + # pyrefly: ignore [missing-argument] + if tree.is_reduction: + self.reduction_renames[sym] = sympy_index_symbol( + f"hr{len(self.halide_vars)}" + ) + self.halide_vars[sym] = next_size + added_sym_size.append((sym, next_size)) + divisor *= next_size + new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)] + handled_count += len(new_sizes) + prior_len = len(sizes_to_add) + sizes_to_add = [ + sympy.simplify(s / next_size) + for s in sizes_to_add + if not eq(s, next_size) + ] + assert len(sizes_to_add) < prior_len or prior_len == 0 + sizes_to_add.extend(new_sizes) + + # create a mapping to the new set of symbols in self.index_replacements + for node in nodes: + try: + idx = 0 + divisor = 1 + while not eq(node.divisor, divisor): + sym, size = added_sym_size[idx] + idx += 1 + divisor *= size + length = 1 + expr = sympy.S.Zero + while not eq(node.length, length): + sym, size = added_sym_size[idx] + idx += 1 + expr += length * sym + length *= size + self.index_replacements[node.symbol()] = expr + except IndexError: + assert had_fallback + full_index = sympy.S.Zero + stride = sympy.S.One + for sym, size in added_sym_size: + full_index += stride * sym + stride *= size + self.index_replacements[node.symbol()] = ( + V.graph.sizevars.simplify_with_ranges( + ModularIndexing(full_index, node.divisor, node.length), + self.halide_vars, # type: ignore[arg-type] + ) + ) + + # codegen the variable definitions + for sym in self.halide_vars: + self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})") + if self.reduction_renames: + self.codegen_rdom( + "rdom", + {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()}, + ) + + def setup_dom_indexing(self): + """RDom based indexing uses explicit iteration ranges for Func updates""" + prefix = "i" if self.inside_reduction else "o" + if prefix in self.dom_renames: + return self.dom_renames[prefix] + + renames = {} + for var in self.halide_vars: + if not self.inside_reduction and var in self.reduction_renames: + continue + m = re.match(r"^h(\d+)$", var.name) + assert m + renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}") + + self.codegen_rdom( + f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()} + ) + + self.dom_renames[prefix] = renames + return renames + + def codegen_rdom(self, name, vars): + rsizes = [ + f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})" + for size in vars.values() + ] + self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])") + for i, rsym in enumerate(vars.keys()): + self.indexing_code.writeline(f"{rsym} = {name}[{i}]") + + def prepare_indexing( + self, + index: sympy.Expr, + ): + index = super().prepare_indexing(index) + index = sympy_subs(index, self.index_replacements) + return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars) # type: ignore[arg-type] + + def sym_size(self, sym): + """The size of an index symbol""" + if symbol_is_type(sym, SymT.TMP): + return self.lookup_cse_var(sym.name).indirect_indexing_size + return self.halide_vars[sym] + + def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): + """Convert address-based indexing into dimensions using self.halide_vars""" + symbols = [] + for sym in sorted(index.free_symbols, key=lambda x: x.name): # type: ignore[attr-defined] + if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)): + symbols.append(sym) + else: + assert symbol_is_type( + sym, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ), sym + + # group the expression by variables used + offset = sympy.S.Zero + split_expr = dict.fromkeys(symbols, sympy.S.Zero) + split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = [] + index = sympy.expand(self.rename_indexing(index)) + for part in index.args if isinstance(index, sympy.Add) else [index]: + part_vars = [v for v in part.free_symbols if v in split_expr] + if len(part_vars) == 0: + offset += part + elif len(part_vars) == 1: + split_expr[part_vars[0]] += part + else: + new_split_failed = [] + for i in range(len(split_failed)): + assert split_failed[i] is not None + other_vars, other_part = split_failed[i] + if OrderedSet(other_vars) & OrderedSet(part_vars): + part_vars.extend([v for v in other_vars if v not in part_vars]) + part += other_part + else: + new_split_failed.append((other_vars, other_part)) + split_failed = [*new_split_failed, (part_vars, part)] + + def expr_to_dimension(expr, syms): + expr = sympy.factor(expr) + if len(syms) == 1: + stride_wild = sympy.Wild("wild", exclude=symbols) + m = expr.match(stride_wild * syms[0]) + if m: + return DimensionInfo( + syms[0], self.sym_size(syms[0]), m[stride_wild] + ) + assert not is_store, expr + length = sympy.simplify( + sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1 + ) + stride = sympy.S.One + if isinstance(expr, sympy.Mul): + for term in expr.args: + if isinstance(term, sympy.Integer): + stride *= term + expr = sympy.simplify(expr / term) + length = sympy.simplify(sympy.ceiling(length / term)) + return DimensionInfo(expr, length, stride) + + # try to turn each group into a strided access + dims = [] + for syms, expr in split_failed: + for v in syms: + expr += split_expr.pop(v) + dims.append(expr_to_dimension(expr, syms)) + for sym, expr in split_expr.items(): + dims.append(expr_to_dimension(expr, [sym])) + dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf)) # type: ignore[arg-type] + + if not dims: # scalar load/store + if self.has_indirect_indexing: + # workaround https://github.com/halide/Halide/issues/8338 + dims.append(DimensionInfo(sympy.S.Zero, 1, 1)) + elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1): + # Halide assumes dimension 0 is stride == 1, so add a dummy dimension + dims.insert( + 0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1) + ) + + if dims and not is_store: + if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq( + offset, self.buffer_offsets[var] + ): + # reuse the existing offset to avoid needing an input alias + self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var]) + offset = self.buffer_offsets[var] + elif V.graph.sizevars.statically_known_gt( + offset, 0 + ): # TODO(jansel): negative offsets + # roll the offset into the dimensions for cleaner indexing + self.apply_offset_to_dimension(dims, offset) + offset = 0 + + orig_var = var + for i in itertools.count(): + if self.install_dims(var, dims, offset, is_store): + return var, dims + assert not is_store + var = f"{orig_var}_view{i}" + if var not in self.buffer_aliases[orig_var]: + self.buffer_aliases[orig_var].append(var) + + def install_dims(self, var, dims, offset, is_store): + """Try to set self.buffer_dimensions[var], return True on success""" + if var not in self.buffer_dimensions: + self.buffer_dimensions[var] = dims + self.buffer_offsets[var] = offset + return True + if self.buffer_offsets[var] != offset or len( + self.buffer_dimensions[var] + ) != len(dims): + return False + if is_store: + return self.buffer_dimensions[var] == dims + for old, new in zip(self.buffer_dimensions[var], dims): + if old.stride != new.stride: + return False + if old.size != new.size or old.expr != new.expr: + old.size = V.graph.sizevars.evaluate_max(old.size, new.size) + old.expr = None + return True + + def apply_offset_to_dimension(self, dims, offset): + if offset == 0: + return + for i in reversed(range(len(dims))): + if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq( + offset, dims[i].stride + ): + part = FloorDiv(offset, dims[i].stride) + offset -= part * dims[i].stride + dims[i].expr += part + assert offset == 0 + + def used_dims_from_index(self, index: sympy.Expr): + """Detect which range trees are used to populate HalideCSEVariable.used_dims""" + used_dims = OrderedSet[sympy.Symbol]() + for sym in index.free_symbols: + assert isinstance(sym, sympy.Symbol) + if symbol_is_type(sym, SymT.TMP): + # indirect indexing + cse_var = self.lookup_cse_var(sym.name) + assert ( + isinstance(cse_var, HalideCSEVariable) + and cse_var.used_dims is not None + ) + used_dims.update(cse_var.used_dims) + elif symbol_is_type(sym, SymT.HALIDE): + used_dims.add(sym) + elif symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX) + ): + pass + else: + raise NotImplementedError(f"unhandled symbol {sym}") + return self.sort_used_dims(used_dims) + + def sort_used_dims(self, used_dims): + assert all(isinstance(x, sympy.Expr) for x in used_dims) + ordered = [ + sym + for sym in itertools.chain( + self.halide_vars, self.reduction_renames.values() + ) + if sym in used_dims + ] + assert len(ordered) == len(used_dims) + return ordered + + def make_index_str(self, dims, replacements=None, zero_vars=False): + index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims) + if len(dims) == 0: + index_str = "()" + elif len(dims) == 1: + # workaround for https://github.com/halide/Halide/issues/8299 + index_str = f"{index_str}," + return index_str + + def load(self, name: str, index: sympy.Expr): + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, False) + line = f"{var}[{self.make_index_str(dims)}]" + dtype = V.graph.get_dtype(name) + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + line = f"hl.cast(hl.Float(32), {line})" + + if self._load_mask: + assert ( + isinstance(self._load_mask, HalideCSEVariable) + and self._load_mask.used_dims is not None + ) + used_dims = OrderedSet( + (*self.used_dims_from_index(index), *self._load_mask.used_dims) + ) + result = self.newfunc(self.sort_used_dims(used_dims)) + if result.used_dims: + self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])") + self.body.writeline(f"{result.name}_mask.where({self._load_mask})") + other = self.kexpr(self._load_other or 0) # type: ignore[arg-type] + self.body.writeline( + f"{result} = hl.cast({halide_type(dtype)}, {other})" + ) + self.body.writeline( + f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)" + ) + else: + # scalar case + self.body.writeline( + f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))" + ) + return result + else: + return self.genfunc(line, self.used_dims_from_index(index)) + + def lookup_cse_var(self, name: str): + return self.cse.varname_map[re.sub(r"\[.*", "", name)] + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + """Codegen a store to an OutputBuffer""" + assert isinstance(value, HalideCSEVariable) + var = self.args.output(name) + index = self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, True) + if self.is_indirect_indexing(index) or mode is not None: + replacements = self.setup_dom_indexing() + index_str = self.make_index_str(dims, replacements) + value_str = value.subs_str(replacements) + undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()" + self.body.writeline( + DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())") + ) + else: + index_str = self.make_index_str(dims, zero_vars=True) + value_str = str(value) + + dtype = V.graph.get_dtype(name) + if mode is None: + line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})" + elif mode == "atomic_add": + line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})" + else: + raise NotImplementedError(f"store mode={mode}") + self.body.writeline(DeferredLine(name, line)) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """Codegen a reduction operation""" + assert self.inside_reduction + assert not self._load_mask + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + if isinstance(value, tuple): + assert reduction_type == "welford_combine" + self.cse.reduction_cache[cache_key] = result_tuple = ( + self.welford_combine_impl(*value) + ) + return result_tuple + + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + reduction_vars = OrderedSet(self.reduction_renames) + result_var = self.newfunc( + [v for v in value.used_dims if v not in reduction_vars], + ) + if reduction_vars - OrderedSet(value.used_dims): + value = self.genfunc( + f"{value}", + self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))), + shape=value.shape, + ) + value_str = value.subs_str(self.reduction_renames) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + acc_type = halide_acc_type(dtype) + + if reduction_type in ("argmax", "argmin"): + index = f"{result_var.name}_{reduction_type}" + self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})") + # turn the N-D argmax index into a 1-D one + parts = [] + stride = 1 + for i, sym in enumerate(self.reduction_renames): + # pyrefly: ignore [bad-argument-type] + parts.append(f"{index}[{i}]") + if stride != 1: + # pyrefly: ignore [unsupported-operation] + parts[-1] += f"*{stride}" + stride *= self.halide_vars[sym] + self.body.writeline(f"{result_var} = {' + '.join(parts)}") + elif reduction_type == "welford_reduce": + # TODO(jansel): implement welford_reduce without fallback + result_var = self.welford_reduce_fallback(dtype, value) + else: + combine_fn = get_reduction_combine_fn(reduction_type, acc_type) + with V.set_ops_handler(AddParenHandler(HalideOverrides())): + combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type] + default_str = f"hl.cast({acc_type}, {halide_constant(default)})" + self.body.writeline(f"{result_var} = {default_str}") + self.body.writeline(f"{result_var} = {combine_str}") + + self.cse.reduction_cache[cache_key] = result_var + return result_var + + def welford_combine_impl(self, mean, m2, weight): + assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None + assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None + assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None + used_dims = OrderedSet( + (*mean.used_dims, *m2.used_dims, *weight.used_dims) or self.halide_vars + ) + used_dims -= OrderedSet(self.reduction_renames) + result_var = self.newfunc(self.sort_used_dims(used_dims)) + default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)] + pfx = result_var.name + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])") + self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]") + self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]") + self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]") + self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}") + self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}") + self.body.writeline( + f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}" + ) + self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1") + self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2") + self.body.writeline( + f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)" + ) + update = [ + f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w", + f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w", + f"{pfx}_new_weight", + ] + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])") + + unpacked = [] + for i in range(3): + unpacked.append(self.newfunc(result_var.used_dims)) + self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]") + return tuple(unpacked) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] + ], + values_orig: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert len(dtypes) == len(values_orig) + values: list[HalideCSEVariable] = [] + all_used_dims = OrderedSet[sympy.Symbol]() + + for value in values_orig: + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + if OrderedSet(value.used_dims) & OrderedSet(self.reduction_renames): + values.append(value) + else: + values.append( + self.genfunc( + f"{value}", + [*value.used_dims, [*self.reduction_renames][:1]], + shape=value.shape, + ) + ) + all_used_dims.update(value.used_dims) + result_var = self.newfunc(self.sort_used_dims(all_used_dims)) + assert result_var.used_dims and OrderedSet(result_var.used_dims) & OrderedSet( + self.reduction_renames + ) + initial = [ + f"hl.cast({halide_acc_type(dtype)}, {value})" + for dtype, value in zip(dtypes, values) + ] + + length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel)) + scan_dom = f"{result_var.name}_rdom" + scan = f"{scan_dom}.x" + self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])") + + assert len(self.reduction_renames) == 1, ( + "multi-dimensional scan not implemented" + ) + (scan_var,) = [*self.reduction_renames] # type: ignore[misc] + scan_renames_cur = {scan_var: sympy_index_symbol(scan)} + scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1} + + if len(values) == 1: + + def maybe_tuple(x): + return x[0] + + read_left = [result_var.subs_str(scan_renames_pri)] + read_right = [result_var.subs_str(scan_renames_cur)] + else: + + def maybe_tuple(x): + return f"hl.Tuple([{', '.join(x)}])" + + read_left = [ + result_var.subs_str(scan_renames_pri) + f"[{i}]" + for i in range(len(values)) + ] + read_right = [ + result_var.subs_str(scan_renames_cur) + f"[{i}]" + for i in range(len(values)) + ] + + self.body.writeline(f"{result_var} = {maybe_tuple(initial)}") + + # Disable CSE for update fn + with V.set_ops_handler(AddParenHandler(HalideOverrides())): + combine_str = combine_fn(read_left, read_right) # type: ignore[arg-type] + self.body.writeline( + f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}" + ) + + if len(values) == 1: + return (result_var,) + + unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values] + for i, v in enumerate(unpack_vars): + self.body.writeline(f"{v} = {result_var}[{i}]") + return tuple(unpack_vars) + + def genfunc( + self, + line, + used_dims, + *, + bounds=ValueRanges.unknown(), + shape: BlockShapeType = None, + ) -> HalideCSEVariable: + var = self.cse.generate(self.body, line, bounds=bounds, shape=shape) + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable: + var = self.cse.newvar(shape=shape) + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def halide_buffer_numel(self, name: str): + """ + We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch + supports. If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while + PyTorch's numel excludes them. + """ + return V.graph.get_buffer(name).get_layout().storage_size() + + def halide_argdefs(self): + """ + Halide requires scalar inputs before outputs, so need to reorder args. + """ + + def arg_order(arg_tuple): + _call_str, arg = arg_tuple + if isinstance(arg, SizeArg): + return 1 # this would normally be at the end, move it to middle + elif "out_ptr" in arg.name: + return 2 + else: + assert "in_ptr" in arg.name + return 0 + + result: list[tuple[Optional[str], KernelArgType]] = [] + _, a, b, _ = self.args.python_argdefs() + for call_str, arg in sorted(zip(a, b), key=arg_order): + result.append((call_str, arg)) + if isinstance(arg, TensorArg): + assert arg.offset == 0 and arg.alias_of is None + result.extend( + ( + None, + TensorArg( + alias, + arg.buffer, + arg.dtype, + arg.offset, + alias_of=arg.name, + ), + ) + for alias in self.buffer_aliases.get(arg.name, ()) + ) + return result + + def halide_kernel_meta(self) -> HalideMeta: + """Compute metadata required by codecache.py""" + argtypes = [] + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + shape = None + stride = None + offset = None + dtype = "long" + else: + shape = [ + cexpr(self.rename_indexing(x.size)) + for x in self.buffer_dimensions[arg.name] + ] + stride = [ + cexpr(self.rename_indexing(x.stride)) + for x in self.buffer_dimensions[arg.name] + ] + assert len(shape) == len(stride) + offset = cexpr(self.buffer_offsets[arg.name]) + dtype = f"{DTYPE_TO_CPP[arg.dtype]}*" + argtypes.append( + HalideInputSpec( + dtype, + arg.name, + shape=shape, + stride=stride, + offset=offset, + alias_of=arg.alias_of, + ) + ) + + current_device = V.graph.get_current_device_or_throw() + if current_device.type == "cpu": + target = [config.halide.cpu_target] + scheduler = config.halide.scheduler_cpu + scheduler_flags = { + "parallelism": parallel_num_threads(), + } + cuda_device = None + else: + assert current_device.type == "cuda", "only cpu/cuda supported" + assert current_device.index <= 0, "only default device supported" + target = [config.halide.gpu_target] + scheduler = config.halide.scheduler_cuda + capability = torch.cuda.get_device_properties(current_device) + if "cuda_capability" not in target[0]: + for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]: + if capability.major >= major and capability.minor >= minor: + target.append(f"cuda_capability_{major}{minor}") + break + target.append("user_context") + scheduler_flags = { + "parallelism": capability.multi_processor_count, + # TODO(jansel): explore other flags, see: + # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp + } + cuda_device = max(0, current_device.index) + + # strict_float is requires for correctness + target.append("strict_float") + + # without this we will initialize cuda once per kernel and hit errors + target.append("no_runtime") + + if not config.halide.asserts: + target.append("no_asserts") + + if config.halide.debug: + target.append("debug") + + if "64" in self.index_dtype: + # TODO(jansel): it is unclear if this does anything, since input sizes are still int32 + target.append("large_buffers") + + return HalideMeta( + argtypes, + target="-".join(target), + scheduler=scheduler, + scheduler_flags=scheduler_flags, # type: ignore[arg-type] + cuda_device=cuda_device, + ) + + def codegen_kernel(self, name=None): + """Called at the end to generate a final kernel string""" + if self.args.inplace_buffers: + raise Unsupported("inplace_buffers") + meta = self.halide_kernel_meta() # ensure needed args are added early + code = IndentedBuffer() + code.splice( + """ + import halide as hl + from torch._inductor.runtime import halide_helpers + from math import inf, nan + + @hl.generator(name="kernel") + class Kernel: + """, + strip=True, + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})") + else: + assert arg.buffer, arg + argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer" + argtype = halide_type(arg.dtype) + ndim = len(self.buffer_dimensions[arg.name]) + code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})") + code.splice( + """ + def generate(g): + """ + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + code.writeline(f"{arg.name} = g.{arg.name}") + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.indexing_code) + + def update_index(m): + var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)]) + assert var.used_dims is not None, var + return str(var) + + for line in self.body._lines: + if isinstance(line, str): + # fill in missing indices + line = HalideCSEVariable.undefined_re.sub(update_index, line) + code.writeline(line) + code.writeline("") + code.writeline("assert g.using_autoscheduler()") + + for _, arg in self.halide_argdefs(): + # fallback=1 below because halide requires buffers to be at least as large as the estimates + # This causes crashes if our estimate is greater than the vector length + # https://github.com/halide/Halide/issues/3103 + if isinstance(arg, SizeArg): + hint = V.graph.sizevars.size_hint(arg.expr, fallback=1) + code.writeline(f"{arg.name}.set_estimate({hint})") + else: + dims = self.buffer_dimensions[arg.name] + range_hints = [] + for i, dim in enumerate(dims): + hint = self._autoscheduler_workarounds( + V.graph.sizevars.size_hint(dim.size, fallback=1), dims + ) + # pyrefly: ignore [bad-argument-type] + range_hints.append(f"hl.Range(0, {hint})") + if "out" not in arg.name: + code.writeline(f"{arg.name}.dim({i}).set_min(0)") + try: + code.writeline( + f"{arg.name}.dim({i}).set_stride({int(dim.stride)})" + ) + except TypeError: + pass # not integer + try: + code.writeline( + f"{arg.name}.dim({i}).set_extent({int(dim.size)})" + ) + except TypeError: + pass # not integer + code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])") + + code.do_unindent(2) + code.splice( + """ + if __name__ == "__main__": + hl.main() + """.rstrip(), + ) + if meta.scheduler: + code.splice( + f""" + else: + hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r}) + target = hl.Target({meta.target!r}) + autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r}) + with hl.GeneratorContext(target, autoscheduler): + gen = Kernel() + pipeline = gen._build_pipeline() + # gen.compile_to_callable() does not run the autoscheduler + pipeline.apply_autoscheduler(target, autoscheduler) + kernel = pipeline.compile_to_callable([ + gen._get_input_parameter(a.name)._to_argument() + for a in gen._get_arginfos() + if a.dir == hl.ArgInfoDirection.Input + ], target) + """, + strip=True, + ) + else: + code.splice( + f""" + else: + with hl.GeneratorContext(hl.Target({meta.target!r})): + kernel = Kernel().compile_to_callable() + """, + strip=True, + ) + return code.getvalue() + + @staticmethod + def _autoscheduler_workarounds(n, dims): + if ( + len(dims) == 1 + and config.halide.scheduler_cuda == "Anderson2021" + and V.graph.get_current_device_or_throw().type == "cuda" + ): + # workaround https://github.com/halide/Halide/issues/8246 + n = max(2, n) + return n + + def call_kernel(self, name: str, node=None, deallocate_ws: bool = True): + """Codegen a call to this kernel""" + wrapper = V.graph.wrapper_code + call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None] + current_device = V.graph.get_current_device_or_throw() + if current_device.type == "cuda": + stream_name = wrapper.write_get_raw_stream( + current_device.index, V.graph.name + ) + call_args.append(stream_name) + wrapper.generate_kernel_call( + name, + call_args, + device=current_device, + triton=False, + ) + + def generate_assert(self, check): + return False # TODO(jansel): support asserts + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + pass # TODO(jansel): support asserts + + +class HalideScheduling(SIMDScheduling): + kernel_type = HalideKernel # type: ignore[arg-type,assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + result = OrderedSet( + [ + BackendFeature.TUPLE_REDUCTION, + BackendFeature.PREFER_STORE_LOOP_ORDER, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + if config.halide.scan_kernels: + result.add(BackendFeature.SCAN) + return result + + def define_kernel(self, src_code, node_schedule, kernel): + """Codegen kernel definition to go in output wrapper code""" + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}" + wrapper.src_to_kernel[src_code] = kernel_name + wrapper.add_import_once( + "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec" + ) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline( + f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''" + ) + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + if is_metric_table_enabled("kernel_metadata"): + log_kernel_metadata(kernel_name, "", src_code) + + return kernel_name diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..12d7500975e5b93c6c837a48821ef737df6a3f19 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py @@ -0,0 +1,816 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import dataclasses +import itertools +import pprint +from typing import Any, Optional, Protocol, TYPE_CHECKING + +import sympy + +import torch +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer +from ..virtualized import V +from .wrapper import ( + AllocateLine, + BufferLike, + FreeIfNotReusedLine, + MemoryPlanningLine, + NullLine, + ReuseLine, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +@dataclasses.dataclass +class LiveRange: + """ + A range where a given tensor is live. Begin and end are both counters + representing points in the program of grouped memory operations. + Begin is inclusive, end is exclusive. + + Invariant: begin <= end + """ + + begin: float # int | +/-inf + end: float # int | +/-inf + + def contains(self, other: LiveRange): + """Is other entirely within self""" + return self.begin <= other.begin and other.end <= self.end + + def join(self, other: LiveRange): + """Combine two ranges using a union operation""" + return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) + + def __len__(self): + return self.end - self.begin + + +class LiveRanges: + """ + A collection of LiveRange regions, allowing for non-contiguous + live regions. + + Invariant: LiveRanges.ranges is in sorted order and non-overlapping + """ + + def __init__(self, ranges: Iterable[LiveRange]): + ranges = [*sorted(ranges, key=lambda x: x.begin)] + self.ranges = ranges[:1] + for r in ranges[1:]: + assert self.ranges[-1].begin <= r.begin + if self.ranges[-1].end >= r.begin: + self.ranges[-1] = LiveRange.join(self.ranges[-1], r) + else: + self.ranges.append(r) + + def overlaps(self, other: LiveRanges): + """Check if any pair of ranges in self and other overlap""" + left = collections.deque(self.ranges) + right = collections.deque(other.ranges) + while left and right: + if left[0].begin > right[0].begin: + left, right = right, left + assert left[0].begin <= right[0].begin + if left[0].end > right[0].begin: + return True + left.popleft() + return False + + @property + def begin(self): + return self.ranges[0].begin + + @property + def end(self): + return self.ranges[-1].end + + def __repr__(self): + return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])" + + +class AllocationTreeNode: + """ + Abstract base class for nodes in allocation pool. + """ + + def allocate(self, block: Allocation, is_last: bool) -> bool: + """ + Try to assign block to a memory location in this bool. Return True if + an assignment was made. + """ + return False + + def get_live_ranges(self) -> LiveRanges: + """Aggregate LiveRanges for all objects below this in tree""" + raise NotImplementedError + + def get_size_hint(self) -> int: + """Number of bytes used for example inputs""" + raise NotImplementedError + + def get_symbolic_size(self) -> sympy.Expr: + """Number of bytes needed at runtime""" + raise NotImplementedError + + def finalize(self, pool, offset) -> AllocationTreeNode: + """Called after all allocations have been made""" + return self + + def is_empty(self): + return False + + +@dataclasses.dataclass +class Allocation(AllocationTreeNode): + """ + Represents memory allocated to a given node in the allocation pool. + """ + + node: BufferLike + live_range: LiveRange + size_hint: int + symbolic_size: sympy.Expr + allocated: bool = False + pool: Optional[AllocationPool] = None + offset: Optional[sympy.Expr] = None + earliest_available: Optional[float] = None + + def __post_init__(self) -> None: + has_unbacked_sym = False + for s in self.node.get_layout().size: + if free_unbacked_symbols(s): + has_unbacked_sym = True + break + + if has_unbacked_sym: + self.earliest_available = self.get_live_ranges().begin + + @property + def device(self): + return self.node.get_device() + + def get_live_ranges(self): + return LiveRanges([self.live_range]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return self.symbolic_size + + def mark_allocated(self): + assert not self.allocated + self.allocated = True + + def finalize(self, pool, offset): + assert self.pool is None and self.offset is None + self.pool = pool + self.offset = offset + return self + + def codegen_alloc_from_pool(self, wrapper): + assert self.pool + node = self.node + shape = tuple(node.get_size()) + stride = tuple(node.get_stride()) + return wrapper.codegen_alloc_from_pool( + self.pool.name, self.offset, node.get_dtype(), shape, stride + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"node={self.node.get_name()}, " + f"live_range={self.live_range}, " + f"size_hint={self.size_hint}, " + f"symbolic_size={self.symbolic_size}, " + f"pool={self.pool.name if self.pool else None}, " + f"offset={self.offset})" + ) + + def get_earliest_available(self): + return self.earliest_available + + +@dataclasses.dataclass +class Empty(AllocationTreeNode): + """ + Placeholder to represent empty space in the allocation pool. + Only exists to get the size_hint correct in parent nodes. + """ + + size_hint: int + + def get_live_ranges(self): + return LiveRanges([]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return 0 + + def is_empty(self): + return True + + +class MemorySplitProtocol(Protocol): + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] + + def _allocate(self, block: Allocation, is_last: bool) -> bool: ... + + +class ClearCacheOnAllocateMixin(MemorySplitProtocol): + """ + Helper to assist in caching get_live_ranges, get_size_hint, and + get_symbolic_size. + """ + + def allocate(self, block: Allocation, is_last: bool): + is_allocated = self._allocate(block, is_last) + if is_allocated: + self.clear_cache() + return is_allocated + + def clear_cache(self): + self.get_live_ranges.clear_cache(self) + self.get_size_hint.clear_cache(self) + self.get_symbolic_size.clear_cache(self) + + +@dataclasses.dataclass +class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains a list of allocations not overlapping in LiveRanges. + + Invariant: no pair (a,b) in self.allocations will have: + a.get_live_ranges().overlaps(b.get_live_ranges()) + """ + + allocations: list[AllocationTreeNode] + + def _allocate(self, block: Allocation, is_last: bool): + slot_size = self.get_size_hint() + block_size = block.get_size_hint() + if not is_last and block_size > slot_size: + return False # doesn't fit + + block_live = block.get_live_ranges() + overlapping = [ + s for s in self.allocations if s.get_live_ranges().overlaps(block_live) + ] + if len(overlapping) > 1: + # TODO(jansel): we could try harder here by merging overlapping in space + return False + elif len(overlapping) == 1: + return overlapping[0].allocate(block, is_last) + else: + block.mark_allocated() + + if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty): + self.allocations.pop() + + if slot_size == block_size: + # perfect fit + self.allocations.append(block) + elif slot_size > block_size: + self.allocations.append( + SpatialSplit.create(block, slot_size - block_size) + ) + else: # grow this allocation + assert is_last + self.allocations = [ + *( + SpatialSplit.create(a, block_size - slot_size) + for a in self.allocations + ), + block, + ] + return True + + @cache_on_self + def get_live_ranges(self) -> LiveRanges: + return LiveRanges( + itertools.chain.from_iterable( + x.get_live_ranges().ranges for x in self.allocations + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + if not self.allocations: + return 0 + return max(x.get_size_hint() for x in self.allocations) + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + if not self.allocations: + return 0 # type: ignore[return-value] + return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) + + def is_empty(self): + return len(self.allocations) == 1 and self.allocations[0].is_empty() + + def finalize(self, pool, offset): + self.allocations = [block.finalize(pool, offset) for block in self.allocations] + self.clear_cache() + if len(self.allocations) == 1: + return self.allocations[0] + return self + + +@dataclasses.dataclass +class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains two allocations, left and right, that do not overlap in space. + Right will be allocated immediately after left in memory. + """ + + left: TemporalSplit + right: TemporalSplit + + @staticmethod + def create(left, extra_space): + assert isinstance(left, AllocationTreeNode) + assert isinstance(extra_space, int) and extra_space >= 1 + return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) + + def _allocate(self, block: Allocation, is_last: bool): + return self.left.allocate(block, False) or self.right.allocate(block, is_last) + + @cache_on_self + def get_live_ranges(self): + return LiveRanges( + itertools.chain( + self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + return _align(self.left.get_size_hint()) + self.right.get_size_hint() + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size() + + def finalize(self, pool, offset): + self.left = self.left.finalize(pool, offset) + self.right = self.right.finalize( + pool, offset + align(self.left.get_symbolic_size()) + ) + self.clear_cache() + if self.right.is_empty(): + return self.left + return self + + +@dataclasses.dataclass +class AllocationPool: + """ + Represents a pool of allocations that will be generated by a single + call to torch.empty. + """ + + device: torch.device + root: TemporalSplit + can_expand: bool = True + restrict_live_range: Optional[LiveRange] = None + name: Optional[str] = None + names_to_del: list[str] = dataclasses.field(default_factory=list) + creation_cache: dict[str, str] = dataclasses.field(default_factory=dict) + + def __post_init__(self) -> None: + for block in self.root.allocations: + if isinstance(block, Allocation): + self.update_restrict_live_range(block) + + def allocate(self, block: Allocation, is_last: bool): + if ( + self.restrict_live_range is not None + and not self.restrict_live_range.contains(block.live_range) + ): + return False + + block_earliest_available = block.get_earliest_available() + pool_begin = self.root.get_live_ranges().begin + if block_earliest_available and block_earliest_available > pool_begin: + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): + self.update_restrict_live_range(block) + return True + + if is_last: + return self.allocate_at_end(block) + + return False + + def update_restrict_live_range(self, block: Allocation): + if block_earliest_available := block.get_earliest_available(): + if self.restrict_live_range is None: + self.restrict_live_range = LiveRange( + block_earliest_available, float("inf") + ) + else: + self.restrict_live_range = LiveRange( + min(self.restrict_live_range.begin, block_earliest_available), + self.restrict_live_range.end, + ) + + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + self.update_restrict_live_range(block) + return True + + def finalize(self, name): + assert not self.name + self.name = name + self.names_to_del.append(name) + self.root.finalize(self, 0) + + def codegen_create(self, wrapper, code: IndentedBuffer): + assert self.name + nbytes = self.root.get_symbolic_size() + for block in self.root.allocations: + if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): + node = block.node + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=node.get_dtype(), + shape=tuple(node.get_size()), + stride=tuple(node.get_stride()), + ) + ) + return + else: + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=torch.uint8, + shape=(nbytes,), + stride=(1,), + ) + ) + + def codegen_destroy(self, wrapper, code: IndentedBuffer): + code.writeline(wrapper.make_free_by_names(self.names_to_del)) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class AllocationPools: + """ + Collection of many AllocationPool objects grouped by device. + """ + + device_to_pools: dict[torch.device, list[AllocationPool]] = dataclasses.field( + default_factory=dict + ) + + def get_pools(self, block): + if block.device not in self.device_to_pools: + self.device_to_pools[block.device] = [] + return self.device_to_pools[block.device] + + def allocate(self, block: Allocation): + pools = self.get_pools(block) + + for pool in pools: + if pool.allocate(block, is_last=pool is pools[-1]): + return + + # everything is full, make a new pool + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool != "none", + ) + ) + block.mark_allocated() + + def allocate_output(self, block: Allocation): + """Outputs get different pools so memory gets freed properly""" + pools = self.get_pools(block) + if pools and config.memory_pool in ("outputs", "combined"): + pools[-1].allocate_at_end(block) + else: + # create a new pool + block.mark_allocated() + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool == "combined", + ) + ) + + def finalize(self): + """Called at the end of allocation process""" + for i, pool in enumerate( + itertools.chain.from_iterable(self.device_to_pools.values()) + ): + pool.finalize(f"pool{i}") + + def pprint(self): + for pool in itertools.chain.from_iterable(self.device_to_pools.values()): + print() + print(pool.name) + print(pool.root.get_live_ranges()) + pprint.pprint(pool.root) + + +class BufferGroup: + """ + Due to inplace reuse an allocated buffer can have many names. + This tracks these collections of buffers sharing underlying memory. + """ + + def __init__(self, node: BufferLike): + self.node = node + self.names = [node.get_name()] + self.is_output = False + self.allocation: Optional[Allocation] = None + self.live_range = LiveRange(float("inf"), -float("inf")) + + def update_usage(self, timestep: int): + """Expand self.live_range to include timestep""" + self.live_range = LiveRange( + min(timestep, self.live_range.begin), + max(timestep, self.live_range.end), + ) + + def sym_nbytes(self): + return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize + + def make_allocation(self): + assert not self.allocation, "multiple allocations" + assert isinstance(self.live_range.begin, int), "live ranges not computed" + nbytes = self.sym_nbytes() + # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have + # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored. + size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64) + self.allocation = Allocation( + self.node, + self.live_range, + size_hint=size_hint, + symbolic_size=nbytes, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, " + f"live_range={self.live_range}" + ) + + +@dataclasses.dataclass +class PoolMemoryPlanningLine(MemoryPlanningLine): + """Abstract base class for {Alloc,Dealloc}FromPoolLine""" + + group: BufferGroup + timestep: Optional[int] = None + + @property + def node(self): + return self.group.node + + +@dataclasses.dataclass +class AllocFromPoolLine(PoolMemoryPlanningLine): + """Similar to AllocationLine, but takes memory from a pool""" + + is_first_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + allocation = self.group.allocation + assert allocation and allocation.pool + pool = allocation.pool + name = self.node.get_name() + + if self.is_first_pool_usage: + pool.codegen_create(self.wrapper, code) + + pool.names_to_del.extend(self.group.names) + alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool( + self.wrapper + ) + code.writelines(allocation_lines_to_write) + if alloc_from_pool in pool.creation_cache: + code.writeline( + self.wrapper.make_tensor_alias( + name, pool.creation_cache[alloc_from_pool], "alloc" + ) + ) + else: + pool.creation_cache[alloc_from_pool] = name + code.writeline( + f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}" + ) + + +@dataclasses.dataclass +class DeallocFromPoolLine(PoolMemoryPlanningLine): + """Similar to FreeIfNotReusedLine, but takes memory from a pool""" + + is_last_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + if self.is_last_pool_usage: + assert self.group.allocation and self.group.allocation.pool + self.group.allocation.pool.codegen_destroy(self.wrapper, code) + + +@dataclasses.dataclass +class MemoryPlanner: + """ + Coordination object to run memory planning passes during wrapper + codegen. + """ + + wrapper: Any + pools: AllocationPools = dataclasses.field(default_factory=AllocationPools) + buffer_groups: Optional[list[BufferGroup]] = None + + def plan(self, lines: list[Any]) -> list[Any]: + """Call all the memory planning passes in sequence""" + lines = [*lines] + self.drop_removed_buffers(lines) + self.convert_to_pool_lines(lines) + self.compute_live_ranges(lines) + self.allocate_groups() + self.mark_first_last_usage(lines) + return lines + + def drop_removed_buffers(self, lines): + """ + Replace any memory planning lines in V.graph.removed_buffers with NullLine + """ + # drop any removed buffers + for i, line in enumerate(lines): + if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)): + if line.node.get_name() in V.graph.removed_buffers: + lines[i] = NullLine(self.wrapper) + + def compute_buffer_groups(self, lines): + """ + Populates self.buffer_groups with BufferGroup objects that join + allocations with common storage (due to inplace reuse) into a + single object. + """ + name_to_group = {} + for line in lines: + if isinstance(line, AllocateLine): + name = line.node.get_name() + assert name not in name_to_group + name_to_group[name] = BufferGroup(line.node) + elif isinstance(line, ReuseLine): + old_name = line.node.get_name() + new_name = line.reused_as.get_name() + assert new_name not in name_to_group + # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc + if old_name in name_to_group: + name_to_group[old_name].names.append(new_name) + name_to_group[new_name] = name_to_group[old_name] + + outputs = OrderedSet(V.graph.get_output_names()) + unique_groups = [*{id(g): g for g in name_to_group.values()}.values()] + for group in unique_groups: + group.is_output = any(x in outputs for x in group.names) + + assert self.buffer_groups is None + self.buffer_groups = unique_groups + return name_to_group + + def convert_to_pool_lines(self, lines): + """ + Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their + pool-based counterparts. + """ + name_to_group = self.compute_buffer_groups(lines) + for i, line in enumerate(lines): + if isinstance(line, AllocateLine): + if line.node.get_name() in name_to_group: + lines[i] = AllocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, FreeIfNotReusedLine): + assert not line.is_reused + if line.node.get_name() in name_to_group: + lines[i] = DeallocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, ReuseLine): + if line.node.get_name() in name_to_group: + line.delete_old = False + + def compute_live_ranges(self, lines): + """Populate every BufferGroup.live_ranges field based on first/last usage""" + timestep = 0 + worklist = collections.deque(lines) + while worklist: + if isinstance(worklist[0], MemoryPlanningLine): + timestep += 1 + while worklist and isinstance(worklist[0], MemoryPlanningLine): + line = worklist.popleft() + if isinstance(line, PoolMemoryPlanningLine): + line.group.update_usage(timestep) + line.timestep = timestep + else: + worklist.popleft() + + timestep += 1 + assert self.buffer_groups is not None + for group in self.buffer_groups: + if group.is_output: + group.update_usage(timestep) + + def allocate_groups(self): + """ + Assign every allocation to a specific location in a specific AllocationPool. + """ + assert config.memory_pool in ("none", "intermediates", "outputs", "combined") + assert self.buffer_groups is not None + + for group in self.buffer_groups: + group.make_allocation() + + outputs: list[Allocation] = [] + intermediates: list[Allocation] = [] + for group in self.buffer_groups: + assert group.allocation + if group.is_output and config.memory_pool != "combined": + outputs.append(group.allocation) + else: + intermediates.append(group.allocation) + + for block in sorted( + outputs, + key=lambda x: ( + x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate_output(block) + + for block in sorted( + intermediates, + key=lambda x: ( + -x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate(block) + + self.pools.finalize() + + def mark_first_last_usage(self, lines): + """ + Populate the AllocFromPoolLine.is_first_pool_usage and + DeallocFromPoolLine.is_last_pool_usage fields so that pools + are created/destroyed. + """ + seen = OrderedSet[AllocationPool]() + for line in lines: + if isinstance(line, AllocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_first_pool_usage = True + seen.add(pool) + + seen = OrderedSet[AllocationPool]() + for line in reversed(lines): + if isinstance(line, DeallocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_last_pool_usage = ( + pool.root.get_live_ranges().end <= line.timestep + ) + seen.add(pool) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py new file mode 100644 index 0000000000000000000000000000000000000000..84165fea6e3803e6f4feaa33d8bbb5ae4af6be26 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py @@ -0,0 +1,1097 @@ +# This is not a feature-complete compiler backend +# Just an early prototype that shows that one can compile elementwise ops into a Metal shader +from __future__ import annotations + +import functools +import itertools +import logging +import math +from pathlib import Path +from typing import Any, Optional, TYPE_CHECKING + +import sympy +from sympy.printing.precedence import PRECEDENCE + +import torch +from torch.utils._cpp_embed_headers import _embed_headers +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter, ExprPrinter as ExprPrinter_ +from torch.utils._sympy.value_ranges import ValueRanges + +from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata +from ..virtualized import ops, OpsWrapper, V +from .common import ( + CSEVariable, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + OpOverrides, + PythonPrinter, +) +from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from typing import Union + + from ..ops_handler import ReductionType, StoreMode + from ..scheduler import Scheduler, SchedulerNode + from .common import OpVarT + +log = logging.getLogger(__name__) + +DTYPE_TO_METAL = { + torch.bool: "bool", + torch.int8: "char", + torch.int16: "short", + torch.int32: "int", + torch.int64: "long", + torch.uint8: "uchar", + torch.float: "float", + torch.half: "half", + torch.bfloat16: "bfloat", +} + + +def value_to_metal(val: Union[float, int, bool, str, CSEVariable]) -> str: + if isinstance(val, float): + if val == torch.inf: + return "HUGE_VALF" + elif val == -torch.inf: + return "-HUGE_VALF" + elif val != val: # Only float that not equal to self is nan + return "NAN" + return str(val) + elif isinstance(val, bool): + return "true" if val else "false" + return str(val) + + +class MetalExprPrinter(ExprPrinter_): + """Converts sympy expression to Metal code snippet""" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + if expr.is_integer: + return f"c10::metal::floor_divide({x}, {div})" + return f"metal::floor({x}) / ({div})" + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = expr.args + x = self.doprint(x) + if div != 1: + div = self.doprint(div) + if expr.is_integer: + x = f"({x}) / ({div})" + else: + x = f"metal::floor({x}) / ({div})" + mod = self.doprint(mod) + return f"({x}) % ({mod})" + + def _print_Min(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::min only supported for 2 args") + a, b = map(self._print, expr.args) + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::min({typecast_a}, {typecast_b})" + + def _print_Max(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::max only supported for 2 args") + a, b = map(self._print, expr.args) + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::max({typecast_a}, {typecast_b})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"metal::abs({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"static_cast(metal::rint({self._print(expr.args[0])}))" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"static_cast(metal::rint(1e{ndigits} * {number_str}) * 1e{-ndigits})" + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**23 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + x, y = map(self.doprint, expr.args) + return f"metal::pow(static_cast({x}), static_cast({y}))" + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast({x})" + + def _print_Float(self, expr: sympy.Expr) -> str: + if expr.is_integer: + # sympy considers 0.0 to be integer, but Metal doesn't. + # this workaround prints the float as an integer + # xref: https://github.com/sympy/sympy/issues/26620 + return str(int(expr)) + else: + return str(expr) + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast(metal::floor(static_cast({x})))" + + _print_floor = _print_FloorToInt + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast(metal::trunc({x}))" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"metal::log2({x})" + + def _print_Where(self, expr: sympy.Expr) -> str: + c, p, q = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + return f"{c} ? {p} : {q}" + + +class MetalOverrides(OpOverrides): + """Implements Metal-specific overrides for ops. Base class emits Python-friendly overrides.""" + + @staticmethod + def to_dtype( + x: CSEVariable, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> str: + if dtype == torch.double: + log.warning( + "float64 cast requested, probably from tensorify_python_scalars" + ) + return f"static_cast({x})" + return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})" + + @staticmethod + def to_dtype_bitcast( + x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype + ) -> str: + return f"as_type<{DTYPE_TO_METAL[dtype]}>(static_cast<{DTYPE_TO_METAL[src_dtype]}>({x}))" + + @staticmethod + def constant(val: Union[bool, float, int], dtype: torch.dtype) -> str: + return value_to_metal(val) + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str: + idx_str = V.kernel.index_to_str(V.kernel.prepare_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str: + # TODO: Type annotation for other is wrong, it's often float or int + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) # type: ignore[assignment] + + return ops.where(new_mask, result, other) + + @staticmethod + def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str: + return f"{a} ? {b} : {value_to_metal(c)}" + + @staticmethod + def remainder(a: OpVarT, b: OpVarT) -> str: + return f"c10::metal::remainder({a}, {b})" + + @staticmethod + def maximum(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"c10::metal::max({typecast_a}, {typecast_b})" + + @staticmethod + def minimum(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"c10::metal::min({typecast_a}, {typecast_b})" + + @staticmethod + def logical_or(a: CSEVariable, b: CSEVariable) -> str: + return f"{a} || {b}" + + @staticmethod + def logical_and(a: CSEVariable, b: CSEVariable) -> str: + return f"{a} && {b}" + + @staticmethod + def isnan(x: CSEVariable) -> str: + return f"metal::isnan({x})" + + @staticmethod + def isinf(x: CSEVariable) -> str: + return f"metal::isinf({x})" + + @staticmethod + def log(x: CSEVariable) -> str: + return f"metal::log({x})" + + @staticmethod + def exp(x: CSEVariable) -> str: + return f"metal::exp({x})" + + @staticmethod + def abs(x: CSEVariable) -> str: + return f"metal::abs({x})" + + @staticmethod + def signbit(x: CSEVariable) -> str: + return f"metal::signbit({x})" + + @staticmethod + def sin(x: CSEVariable) -> str: + return f"metal::precise::sin({x})" + + @staticmethod + def sinc(x: CSEVariable) -> str: + return f"c10::metal::sinc({x})" + + @staticmethod + def cos(x: CSEVariable) -> str: + return f"metal::precise::cos({x})" + + @staticmethod + def tan(x: CSEVariable) -> str: + return f"metal::tan({x})" + + @staticmethod + def asin(x: CSEVariable) -> str: + return f"metal::asin({x})" + + @staticmethod + def acos(x: CSEVariable) -> str: + return f"metal::acos({x})" + + @staticmethod + def atan(x: CSEVariable) -> str: + return f"metal::atan({x})" + + @staticmethod + def atan2(x: CSEVariable, y: CSEVariable) -> str: + return f"::metal::atan2({x}, {y})" + + @staticmethod + def sqrt(x: CSEVariable) -> str: + return f"metal::sqrt({x})" + + @staticmethod + def neg(x: CSEVariable) -> str: + # TODO: Does it rely on undefined behavior? + # If so, add special logic for unsigned types + return f"static_cast(-{x})" + + @staticmethod + def rsqrt(x: CSEVariable) -> str: + return f"metal::rsqrt({x})" + + @staticmethod + def tanh(x: CSEVariable) -> str: + return f"metal::tanh({x})" + + @staticmethod + def atanh(x: CSEVariable) -> str: + return f"metal::atanh({x})" + + @staticmethod + def floordiv(a: CSEVariable, b: CSEVariable) -> str: + # a and b must be of integer type + return f"c10::metal::floor_divide({a}, {b})" + + @staticmethod + def floor(x: CSEVariable) -> str: + return f"metal::floor({x})" + + @staticmethod + def sign(x: CSEVariable) -> str: + return f"metal::sign({x})" + + @staticmethod + def fmod(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::fmod({typecast_a}, {typecast_b})" + + @staticmethod + def trunc(x: CSEVariable) -> str: + return f"metal::trunc({x})" + + @staticmethod + def truncdiv(a: CSEVariable, b: CSEVariable) -> str: + quot = f"{a} / {b}" + if (a.dtype is not None and a.dtype.is_floating_point) or ( + b.dtype is not None and b.dtype.is_floating_point + ): + return f"metal::trunc({quot})" + return quot + + @staticmethod + def ceil(x: CSEVariable) -> str: + return f"metal::ceil({x})" + + @staticmethod + def rand(seed: CSEVariable, offset: CSEVariable) -> str: + V.kernel.headers.add("random") + return f"c10::metal::rand({seed}, {offset})" + + @staticmethod + def randn(seed: CSEVariable, offset: CSEVariable) -> str: + V.kernel.headers.add("random") + return f"c10::metal::randn({seed}, {offset})" + + @staticmethod + def randint64( + seed: CSEVariable, offset: CSEVariable, low: CSEVariable, high: CSEVariable + ) -> str: + V.kernel.headers.add("random") + return f"c10::metal::randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def round(x: CSEVariable) -> str: + return f"metal::rint({x})" + + @staticmethod + def pow(a: CSEVariable, b: CSEVariable) -> str: + cast_a = f"static_cast({a})" + cast_b = f"static_cast({b})" + return f"metal::pow({cast_a}, {cast_b})" + + def _special_unary(self, a: CSEVariable, name: str) -> str: + V.kernel.headers.add("special_math") + return f"c10::metal::{name}({a})" + + def _special_binary(self, a: CSEVariable, b: CSEVariable, name: str) -> str: + V.kernel.headers.add("special_math") + return f"c10::metal::{name}({a}, {b})" + + @classmethod + def _initialize_special_ops(cls) -> None: + # Unary special ops + for name in [ + "erf", + "erfinv", + "i0", + "i0e", + "i1", + "i1e", + "digamma", + "spherical_bessel_j0", + ]: + setattr(cls, name, functools.partialmethod(cls._special_unary, name=name)) + + cls.lgamma = functools.partialmethod(cls._special_unary, name="log_gamma") # type: ignore[assignment] + + # Unary special ops with forward in method name + for name in [ + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + ]: + setattr( + cls, + name, + functools.partialmethod(cls._special_unary, name=name + "_forward"), + ) + + # Binary special ops + for name in [ + "polygamma", + "igamma", + "igammac", + "zeta", + ]: + setattr(cls, name, functools.partialmethod(cls._special_binary, name=name)) + + # Binary special ops with forward in method name + for name in [ + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "hermite_polynomial_h", + "hermite_polynomial_he", + "shifted_chebyshev_polynomial_t", + "shifted_chebyshev_polynomial_u", + "shifted_chebyshev_polynomial_v", + "shifted_chebyshev_polynomial_w", + ]: + setattr( + cls, + name, + functools.partialmethod(cls._special_binary, name=name + "_forward"), + ) + + +MetalOverrides._initialize_pointwise_overrides("mps") +MetalOverrides._initialize_special_ops() + + +class MetalKernel(SIMDKernel): + """Implement Metal codegen based on the SIMDKernel abstraction""" + + overrides = MetalOverrides # type: ignore[assignment] + suffix = ";" + newvar_prefix = "auto " + max_threadgroup_size = 1024 + simd_group_size = 32 + pexpr = PythonPrinter().doprint + cexpr = CppPrinter().doprint + sexpr = MetalExprPrinter().doprint + kexpr = sexpr + headers: OrderedSet[str] = OrderedSet(["utils"]) + multistage_reduction_entry: list[IterationRangesEntry] = [] + + def __init__( + self, + tiling: dict[str, sympy.Expr], + **kwargs: Any, + ) -> None: + super().__init__(tiling, **kwargs) + self.acc_var_ids = itertools.count() + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return DTYPE_TO_METAL[dtype] + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + dtype = V.graph.get_dtype(name) + line = f"{var}[{self.index_to_str(index)}]" + if dtype in [torch.float16, torch.bfloat16]: + # TODO(NS): Figure out the right balance between optype casts + # op_math_t for half-precision floats should be float32 + # Otherwise it can lead to a correctness issues with eager + line = f"static_cast({line})" + dtype = torch.float32 + return self.cse.generate(self.loads, line, dtype=dtype) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + cast_val = f"static_cast<{dtype_str}>({value})" + if mode is None: + line = f"{var}[{self.index_to_str(index)}] = {cast_val};" + elif mode == "atomic_add": + self.headers.add("atomic") + atomic_type = f"c10::metal::AtomicType<{dtype_str}>" + cast_var = f"reinterpret_cast({var})" + line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});" + else: + raise RuntimeError(f"Unimplemented store mode {mode}") + if self.inside_reduction: + self.compute.writeline(DeferredLine(name, line)) + else: + self.stores.writeline(DeferredLine(name, line)) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + # pyrefly: ignore [missing-argument] + reduction_dim = next(t for t in self.range_trees if t.is_reduction) + # Only one thread in the reduction group needs to store the results + line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});" + line = f"if ({reduction_dim.name} == 0) {line}" + self.stores.writeline(DeferredLine(name, line)) + + def _new_idxvar( + self, + dtype: Union[str | torch.dtype], + elem_count: Optional[int] = None, + default_value: Optional[Any] = None, + is_threadgroup: bool = True, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + ) -> CSEVariable: + if isinstance(dtype, torch.dtype): + dtype = self.dtype_to_str(dtype) + var_name = f"tmp_acc_{next(self.acc_var_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype) + var_def = "threadgroup " if is_threadgroup else "" + var_def += f"{dtype} {var_name}" + if elem_count: + var_def += f"[{self.sexpr(elem_count)}]" + if default_value is not None: + assert not is_threadgroup, "Thread group var can not have default value" + var_def += f" = {default_value}" + self.indexing_code.writeline(var_def + self.suffix) + return var + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + "Caching wrapper around _reduction_nocache" + cache_key = (src_dtype, reduction_type, value) + # Return cached reduction + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + result = self._reduction_nocache(dtype, src_dtype, reduction_type, value) + self.cse.reduction_cache[cache_key] = result # type: ignore[assignment] + return result + + def _reduction_nocache( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """Codegen a reduction operation. + Only sum and prod operations are somewhat reasonable optimized""" + assert self.inside_reduction + assert not self._load_mask + + def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: + # Uwraps vec3 dtype into individual components + return OpsWrapper._unwrap( + [CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"] + ) + + # Establish reduction buffer size and index expression + reduction_idx = "" + acc_buf_size = 1 + for rd in self.range_trees: + # pyrefly: ignore [missing-argument] + if not rd.is_reduction: + continue + if reduction_idx: + reduction_idx += " + " + reduction_idx += f"{rd.name} * {acc_buf_size}" + + if isinstance(rd.numel, sympy.Integer): + acc_buf_size *= rd.numel + else: + acc_buf_size *= sympy.Symbol( + f"{rd.prefix}numel", integer=True, positive=True + ) + + acc_buf_size = sympy.Min(acc_buf_size, self.max_threadgroup_size) + acc_buf_size_str = self.sexpr(acc_buf_size) + shmem_buf_size = ( + ceildiv(acc_buf_size, self.simd_group_size) + if isinstance(acc_buf_size, sympy.Integer) + else self.simd_group_size + ) + + if reduction_type == "any": + acc = self._new_idxvar(dtype) + self.indexing_code.writeline(f"{acc} = false;") + self.indexing_code.writeline( + "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" + ) + self.compute.splice( + f""" + if ({value}) {{ + {acc} = true; + }} + """ + ) + self.stores.writeline( + "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" + ) + return acc + + self.headers.add("reduction_utils") + + if reduction_type in ["prod", "sum"]: + acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype] + acc_buf = self._new_idxvar(acc_dtype, shmem_buf_size) + if not self.multistage_reduction_entry: + val = value + else: + default_val, reduction_op = ( + (0, "+") if reduction_type == "sum" else (1, "*") + ) + val = self._new_idxvar( + acc_dtype, default_value=default_val, is_threadgroup=False + ) + self.compute.splice(f"{val} {reduction_op}= {value};") + + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})", + dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], + ) + if reduction_type in ["max", "min"]: + acc_buf = self._new_idxvar(src_dtype, shmem_buf_size) + src_metal_type = DTYPE_TO_METAL[src_dtype] + cast_value = f"static_cast<{src_metal_type}>({value})" + if not self.multistage_reduction_entry: + val = cast_value # type: ignore[assignment] + else: + lim_fn = "lowest" if reduction_type.endswith("max") else "max" + limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()" + val = self._new_idxvar( + src_dtype, default_value=limit_val, is_threadgroup=False + ) + self.compute.splice( + f"{val} = ::c10::metal::{reduction_type}({val}, {cast_value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})", + dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], + ) + if reduction_type in ["argmin", "argmax"]: + data_acc_buf = self._new_idxvar(src_dtype, shmem_buf_size) + idx_acc_buf = self._new_idxvar(dtype, shmem_buf_size) + src_metal_type = DTYPE_TO_METAL[src_dtype] + cast_value = f"static_cast<{src_metal_type}>({value})" + if not self.multistage_reduction_entry: + val = cast_value # type: ignore[assignment] + idx_val = f"static_cast<{DTYPE_TO_METAL[dtype]}>({reduction_idx})" + else: + lim_fn = "lowest" if reduction_type.endswith("max") else "max" + limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()" + val = self._new_idxvar( + src_dtype, default_value=limit_val, is_threadgroup=False + ) + idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment] + idx_var = next( + t + for t in self.range_tree_nodes.values() + # pyrefly: ignore [missing-argument] + if t.is_reduction + ) + cmp_op = ">" if reduction_type == "argmax" else "<" + nan_suffix = ( + f" || ::metal::isnan({value}) " + if src_dtype.is_floating_point + else "" + ) + self.compute.splice(f""" + if ({value} {cmp_op} {val}{nan_suffix}) {{ + {val} = {value}; + {idx_val} = {idx_var.name}; + }} + """) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({data_acc_buf}, {idx_acc_buf}, " + f"{val}, {idx_val}, {reduction_idx}, {acc_buf_size_str})", + dtype=dtype, + ) + if reduction_type == "welford_reduce": + if not self.multistage_reduction_entry: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") + wf_res = self.cse.generate( + self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));" + ) + wf_res = self.cse.generate( + self.stores, + f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + if reduction_type == "welford_combine": + assert isinstance(value, tuple), "Input to welford combine must be tuple" + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + inp_value = f"float3({value[0]}, {value[1]}, {value[2]})" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + if self.multistage_reduction_entry: + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});" + ) + else: + self.compute.writeline(f"{acc_thread_var} = {inp_value};") + wf_res = self.cse.generate( + self.stores if self.multistage_reduction_entry else self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + raise NotImplementedError(reduction_type) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: + index_expr = self.rename_indexing(entry.expr) + index_str = self.sexpr(index_expr) # type: ignore[misc] + + # pyrefly: ignore [missing-argument] + if not entry.is_reduction or ( + isinstance(entry.root.numel, sympy.Integer) + and entry.root.numel <= self.max_threadgroup_size + ): + self.indexing_code.writeline( + f"{self.index_dtype} {entry.name} = {index_str};" + ) + return + + acc_size = ( + entry.root.numel + if isinstance(entry.root.numel, sympy.Integer) + else sympy.Symbol(f"{entry.root.prefix}numel", integer=True, positive=True) + ) + + self.multistage_reduction_entry.append(entry) + # When reducing the tensor whose size exceeds max threadgroup size + # loop over extra indices per reduction thread and perform part of the operation + # using values in the shared memory + + # Use floats so that it doesn't do integer division + loop_size = (acc_size + float(self.max_threadgroup_size - 1)) // float( + self.max_threadgroup_size + ) + loop_size_str = self.sexpr(loop_size) + + self.body.writeline( + f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size_str}; ++{entry.name}_cnt) {{" + ) + with self.body.indent(): + if isinstance(acc_size, sympy.Symbol): + self.body.writeline( + f"{self.index_dtype} {entry.name} = {self.max_threadgroup_size} * {entry.name}_cnt + {index_str};" + ) + else: + self.body.writeline( + f"{self.index_dtype} {entry.name} = {loop_size_str} * {index_str} + {entry.name}_cnt;" + ) + + # Check that reduction is performed only within tensor boundary + if ( + isinstance(acc_size, sympy.Symbol) + or loop_size * self.max_threadgroup_size != acc_size + ): + self.body.writeline(f"if ({entry.name} >= {acc_size}) break;") + + def codegen_body(self) -> None: + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if self.multistage_reduction_entry: + with self.body.indent(): + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.writeline("}" * len(self.multistage_reduction_entry)) + # Invalidate variables instantiated inside loop + # But results of reduction alive. Reduction cache values can be + # either CSEVariable or tuple of CSEVariables, in which case all + # variables in the tuple must be preserved + self.cse.invalidate( + OrderedSet( + v + for item in self.cse.reduction_cache.values() + for v in (item if isinstance(item, tuple) else (item,)) + ) + ) + # And loop codegen + while self.multistage_reduction_entry: + self.multistage_reduction_entry.pop().cache_clear() + else: + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.loads.clear() + self.compute.clear() + self.stores.clear() + + def codegen_kernel(self, name: Optional[str] = None) -> str: + """Called at the end to generate a final kernel string""" + self.codegen_body() + code = IndentedBuffer() + + if V.graph.cpp_wrapper: + code.writeline('(R"MTL(') + else: + code.writeline("compile_mps_shader('''") + + idx_vars = self.active_range_trees() + with code.indent(): + if not V.graph.cpp_wrapper: + for header in self.headers: + code.writeline(f"#include ") + else: + headers = [ + f"#include " for header in self.headers + ] + header_contents = _embed_headers( + headers, + [Path(__file__).parent.parent.parent / "include"], + OrderedSet(), # type: ignore[arg-type] + ) + code.writeline(header_contents) + + if self.inside_reduction: + total_reduction_size = math.prod( + t.numel + for t in self.range_trees + # pyrefly: ignore [missing-argument] + if t.is_reduction + ) + # If using dynamic shapes, set the threadgroup size to be the + # max possible size + threadgroup_size = ( + min(total_reduction_size, self.max_threadgroup_size) + if isinstance(total_reduction_size, sympy.Integer) + else self.max_threadgroup_size + ) + code.writeline( + f"[[max_total_threads_per_threadgroup({threadgroup_size})]]" + ) + code.writeline("kernel void generated_kernel(") + with code.indent(): + for outer, inner in self.args.output_buffers.items(): + if outer in self.removed_buffers: + continue + dtype_str = self.dtype_to_str(V.graph.get_dtype(outer)) + code.writeline(f"device {dtype_str}* {inner},") + for outer, inner in self.args.input_buffers.items(): + dtype = V.graph.get_dtype(outer) + # MPS does not support float64, but scalar inputs are fine + if dtype == torch.float64: + outer_buf = V.graph.try_get_buffer(outer) + if outer_buf is None or outer_buf.get_size() != []: + raise RuntimeError("float64 is not supported by MPS") + dtype_str = "float" + else: + dtype_str = self.dtype_to_str(dtype) + code.writeline(f"constant {dtype_str}* {inner},") + for inner in self.args.sizevars.values(): + code.writeline(f"constant long& {inner},") + + # Write dynamic values as inputs + for idx_var in idx_vars: + if isinstance(idx_var.numel, sympy.Integer): + pass + else: + code.writeline(f"constant long& {idx_var.prefix}numel,") + + assert len(idx_vars) < 4, "Up to 3 index variables are supported" + thread_pos_dtype = ( + f"uint{len(idx_vars)}" if len(idx_vars) > 1 else "uint" + ) + thread_pos_var_name = ( + idx_vars[0].name if len(idx_vars) == 1 else "thread_pos" + ) + thread_pos_suffix = "," if self.inside_reduction else "" + code.writeline( + f"{thread_pos_dtype} {thread_pos_var_name} [[thread_position_in_grid]]{thread_pos_suffix}" + ) + if self.inside_reduction: + code.writeline( + f"{thread_pos_dtype} group_pos [[thread_position_in_threadgroup]]" + ) + code.writeline(") {") + with code.indent(): + if len(idx_vars) > 1: + for idx, var in enumerate(idx_vars): + code.writeline( + f"auto {var.name} = thread_pos.{chr(120 + idx)};" + ) + code.splice(self.indexing_code) + code.splice(self.body) + code.writeline("}") + + if V.graph.cpp_wrapper: + code.writeline(')MTL");') + else: + code.writeline("''')") + + return code.getvalue() + + def call_kernel( + self, name: str, node: Any = None, deallocate_ws: bool = True + ) -> None: + """ + Codegens a call to this kernel + """ + wrapper = V.graph.wrapper_code + # Make sure sizevars has been computed + for v in self.args.sizevars: + wrapper.ensure_size_computed(v) + + _, call_args, _, arg_types = self.args.python_argdefs() + arg_name_to_type = { + str(call_arg): arg_type for call_arg, arg_type in zip(call_args, arg_types) + } + + args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] + args = [arg for arg in args if arg not in self.removed_buffers] + args += [str(v) for v in self.args.sizevars] + arg_types = [arg_name_to_type[arg] for arg in args] + + # Add any dynamic ints as inputs + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, int)): + # Don't need to pass in integers as inputs + continue + elif isinstance(tree.numel, sympy.Symbol): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree).inner + + # pyrefly: ignore [missing-argument] + if not tree.is_reduction or self.inside_reduction: + args.append(str(expr)) + arg_types.append(int) + + expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr + + def format_threads(threads: list[str], kwarg: str) -> str: + if V.graph.cpp_wrapper: + threads = [f"static_cast({t})" for t in threads] + return f"{{{', '.join(threads)}}}" + else: + return f"{kwarg}=[{', '.join(threads)}]" + + # For reduction kernels, limit the maximum size over reduction dimensions to + # a maximum threadgroup size + if len(self.active_range_trees()) > 0: + threads = [ + expr_printer( + sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc] + # pyrefly: ignore [missing-argument] + if v.is_reduction + else v.numel + ) + for v in self.active_range_trees() + ] + + args.append(format_threads(threads, "threads")) + arg_types.append(list) + else: + if V.graph.cpp_wrapper: + raise RuntimeError("We should always have threads?") + + if self.inside_reduction: + threads = [ + expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc] + # pyrefly: ignore [missing-argument] + if v.is_reduction + else "1" + for v in self.active_range_trees() + ] + args.append(format_threads(threads, "group_size")) + arg_types.append(list) + else: + if V.graph.cpp_wrapper: + # Add a None so that we always have a group_size in the + # arguments. We won't use it if the value is None. + args += [None] # type: ignore[list-item] + arg_types.append(None) + + wrapper.generate_kernel_call( + name, + args, + device=torch.device("mps"), + triton=False, + arg_types=arg_types, + ) + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + if not (lower or upper): + return + # TODO(malfet): support asserts + # See https://github.com/pytorch/pytorch/issues/144634 + expr_str = self.index_to_str(expr) + lower_expr = f"{expr_str} < 0" if lower else "" + # TODO(malfet): Is upper bound inclusive or exclusive? + upper_expr = f"{expr_str} > {self.index_to_str(size)}" if upper else "" + if lower and upper: + line = f"if (({lower_expr}) && ({upper_expr})) return" + else: + line = f"if ({lower_expr}{upper_expr}) return" + self.cse.generate(self.compute, line, assignment=False) + + +class MetalScheduling(SIMDScheduling): + kernel_type = MetalKernel # type: ignore[assignment] + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + wrapper = V.graph.wrapper_code + if wrapper is not None: + if not V.graph.cpp_wrapper: + wrapper.header.splice( + "from torch._inductor.runtime.runtime_utils import compile_mps_shader" + ) + + def define_kernel( + self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel + ) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + # TODO: Merge multiple kernels into a single library + # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling + mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" + + kernel_name = f"{mps_lib_name}" + wrapper.src_to_kernel[src_code] = kernel_name + + if V.graph.cpp_wrapper: + # For shimified version, generate source constant instead of direct instantiation + src_code = f"const char* {mps_lib_name}_source = " + src_code + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False) + + return kernel_name diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4ddb163ef4f9957e1a64a4ab25ec865e8206b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class MPSDeviceOpOverrides(DeviceOpOverrides): + def device_guard(self, device_idx: int) -> str: + assert device_idx == 0 + return "torch._ops.contextlib.nullcontext()" + + def set_device(self, device_idx: int) -> str: + assert device_idx == 0 + return "pass # MPS set device" + + def kernel_driver(self) -> str: + return """ + #include + """ + + def cpp_kernel_type(self) -> str: + return "MTLFunction_t" + + +register_device_op_overrides("mps", MPSDeviceOpOverrides()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/device_op_overrides.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..135bee2b8fe9226d5b69077201c0b08bfc8460a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/device_op_overrides.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from ..common import DeviceOpOverrides, register_device_op_overrides + + +class MTIADeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _mtia_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.mtia.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.mtia.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.mtia.device({device_idx})" + + +register_device_op_overrides("mtia", MTIADeviceOpOverrides()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..094164a1f08ca8db973047a90d182fb943795b88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py @@ -0,0 +1,612 @@ +# mypy: allow-untyped-defs +import functools +import logging +import math +import os +import pathlib +from typing import Any, Optional, Union + +from torch._inductor.ir import MultiTemplateBuffer +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..codecache import code_hash, CodeCacheFuture, get_path, write_atomic +from ..runtime.benchmarking import benchmarker +from ..utils import cache_on_self, IndentedBuffer +from ..virtualized import V +from .common import TensorArg, WorkspaceArg + + +log = logging.getLogger(__name__) + + +class MultiKernelState: + """ + Maintain state of multi-kernel compilation so we don't define duplicated + multi-kernel for the same set of sub-kernels. + + V.graph.wrapper_code has a reference to MultiKernelState instance. + """ + + def __init__(self): + self.subkernel_to_kernel_name = {} + self.kernel_defs = IndentedBuffer() + + def define_kernel( + self, + kernels: list[Any], + kernel_shape_keys: Optional[ + list[Union[None, tuple[tuple[int, ...], ...]]] + ] = None, + ) -> str: + """ + Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}". + This has some minor issue. + + E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca , + there are 2 flavors of non-persistent reduction: + https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4 + and + https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd + + The only different is cache eviction policy. + + We should name the multi-kernel differently in these 2 cases. + + kernels: + A list of kernels + kernel_shape_keys: + Specified for size-hint multi-kernels. + Each list element is a shape key, corresponding to the concrete input & output size hints each kernel was tuned for. + """ + # Prevent circular import + from ..select_algorithm import TritonTemplateKernel + + kernel_names = tuple(k.kernel_name for k in kernels) + if kernel_names in self.subkernel_to_kernel_name: + return self.subkernel_to_kernel_name[kernel_names] + + # name the multi kernel based on the first kernel + multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" + self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name + + if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time: + # we should not generate any python code for multi-kernel during + # the second pass of cpp-wrapper. + return multi_kernel_name + + arg_index: dict[int, list[slice]] = {} + _, call_args, _, arg_types = kernels[0].args.python_argdefs() + if isinstance(kernels[0], TritonTemplateKernel) and isinstance( + kernels[0].output_node, MultiTemplateBuffer + ): + for i, kernel in enumerate(kernels): + additional_call_args, _ = kernel.additional_call_args_and_types() + if i not in arg_index: + arg_index[i] = [] + arg_index[i].append(slice(0, len(call_args))) + arg_index[i].append( + slice( + len(call_args) + i * len(additional_call_args), + len(call_args) + (i + 1) * len(additional_call_args), + ) + ) + else: + kernels[0].add_numel_to_call_args(multi_kernel_name, call_args, arg_types) + for i in range(len(kernels)): + arg_index[i] = [slice(0, len(call_args))] + + keyed_by_sizes = kernel_shape_keys is not None + buf = self.kernel_defs + buf.writeline("") + buf.writeline("arg_index = {") + for key, slice_list in arg_index.items(): + slice_reprs = ", ".join(repr(s) for s in slice_list) + buf.writeline(f" {key}: [{slice_reprs}],") + buf.writeline("}") + + if not keyed_by_sizes: # no size hint keys, just call with list of kernels + buf.writeline( + f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [" + ) + with buf.indent(): + for name in kernel_names: + buf.writeline(f"{name},") + buf.writeline("], arg_index=arg_index)") + else: # call with dict[size hint key, kernel] + assert isinstance(kernels[0], TritonTemplateKernel) + assert isinstance(kernel_shape_keys, list) + assert len(kernels) == len(kernel_shape_keys) + buf.writeline( + f"{multi_kernel_name} = async_compile.size_hint_multi_kernel({multi_kernel_name!r}, {{" + ) + with buf.indent(): + for shape_key, name in zip(kernel_shape_keys, kernel_names): + buf.writeline(f"{shape_key}: {name},") + buf.writeline("}, arg_index=arg_index)") + + if config.triton.autotune_at_compile_time: + V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = ( + multi_kernel_name + ) + + return multi_kernel_name + + +class MultiKernel: + """ + This class maintains the compile time state for multi kernels. + + Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. + The generated definition for the multi-kernel will looks like: + ``` + multi_kernel_kernel1 = MultiKernelCall( + [kernel1, kernel2], multi_kernel_definition_code + ) + ``` + + Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + """ + + def __init__(self, kernels): + assert len(kernels) >= 2 + + self.kernels = kernels + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + kernels + ) + + # need this since some code in inductor check if the kernel object has an args + # attribute to decide if it's a non-null kernel. + self.args = object() + + @staticmethod + def _merge_workspace_args(left: list[WorkspaceArg], right: list[WorkspaceArg]): + if left == right: + return left + result = {x.inner_name: x for x in left} + for arg in right: + if arg.inner_name in result: + result[arg.inner_name] = WorkspaceArg.maximum( + result[arg.inner_name], arg + ) + else: + result[arg.inner_name] = arg + return [*result.values()] + + @staticmethod + def merge_workspaces_inplace(kernels): + if len(kernels) < 2: + return + # All kernels must share the same workspace + workspace_args = functools.reduce( + MultiKernel._merge_workspace_args, + [kernel.args.workspace_args for kernel in kernels], + ) + for kernel in kernels: + kernel.args.workspace_args = workspace_args + return workspace_args + + def call_kernel(self, kernel_name): + """ + Collect the union of arguments from all subkernels as the arguments + for the multi-kernel. + """ + # Prevent circular import + from ..select_algorithm import TritonTemplateKernel + + assert kernel_name == self.kernel_name + V.graph.wrapper_code.write_triton_header_once() + _, call_args, _, arg_types = self.kernels[0].args.python_argdefs() + for kernel in self.kernels[1:]: + _, other_call_args, _, other_arg_types = kernel.args.python_argdefs() + assert call_args == other_call_args, (call_args, other_call_args) + assert arg_types == other_arg_types + + if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time: + # for the second pass of cpp-wrapper codegen, we should call + # the fast kernel directly + kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + if isinstance(self.kernels[0], TritonTemplateKernel) and isinstance( + self.kernels[0].output_node, MultiTemplateBuffer + ): + # For matmuls the grid arguments are passed in as additional arguments + # to the kernel run method. These grids change based on the various + # parameters of the matmul. So we need to pass each kernel's grid into + # the multi call kernel. + multi_call_args = call_args + multi_call_arg_types = arg_types + for kernel in self.kernels: + additional_call_args, additional_arg_types = ( + kernel.additional_call_args_and_types() + ) + multi_call_args.extend(list(additional_call_args)) + multi_call_arg_types.extend(list(additional_arg_types)) + else: + # numels for all subkernels should be the same. Use kernels[0] here + self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types) + multi_call_args = call_args + multi_call_arg_types = arg_types + + for ws in self.kernels[0].args.workspace_args: + V.graph.wrapper_code.generate_workspace_allocation(ws) + + if V.graph.cpp_wrapper: + # We have already selected the best kernel at compile time + # so we only have one set of call args. NB: this currently + # doesn't work with MultiTemplateBuffer kernels. @bobrenjc93 + # will add it in a subsequent PR. + V.graph.wrapper_code.generate_kernel_call( + kernel_name, call_args, arg_types=arg_types + ) + else: + V.graph.wrapper_code.generate_kernel_call( + kernel_name, multi_call_args, arg_types=multi_call_arg_types + ) + + for ws in reversed(self.kernels[0].args.workspace_args): + V.graph.wrapper_code.generate_workspace_deallocation(ws) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + seen: OrderedSet[str] = OrderedSet() + for k in self.kernels: + _, call_args, precompile_args, _ = k.args.python_argdefs() + for arg, precompile_arg in zip(call_args, precompile_args): + if arg in seen: + continue + seen.add(arg) + if isinstance(precompile_arg, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + @property + def removed_buffers(self): + return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels]) + + @property + def inplaced_to_remove(self): + return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels]) + + @property + @cache_on_self + def inplace_update_buffers(self): + """ + Make sure all kernels have the same inplace update mappings. + """ + for k in self.kernels[1:]: + assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers + return self.kernels[0].inplace_update_buffers + + def warn_mix_layout(self, kernel_name: str): + pass + + +class MultiKernelCall: + """ + This class is called at run time to actually run the kernel + """ + + def __init__(self, multi_kernel_name, kernels, arg_index): + assert len(kernels) >= 1 + self._kernels = kernels + self.multi_kernel_name = multi_kernel_name + + self.disable_cache = os.environ.get( + "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE" + ) == "1" or is_metric_table_enabled("persistent_red_perf") + + self.picked_kernel = None + self.arg_index = arg_index + if config.triton.multi_kernel > 1: + # manually force a subkernel to ease perf testing + picked_by_config = config.triton.multi_kernel - 2 + assert picked_by_config < len(self._kernels) + # pyrefly: ignore [bad-assignment] + self.picked_kernel = picked_by_config + elif not self.disable_cache: + self.load_cache() + + self._recorded = False + + def cache_file_path(self): + key = code_hash( + ",".join( + [ + f"{k.fn.cache_key}{k.size_hints!r}{k.triton_meta!r}" + for k in self.kernels + ] + ) + ) + _, _, path = get_path(key, "picked_kernel") + return pathlib.Path(path) + + def load_cache(self): + assert self.picked_kernel is None + path = self.cache_file_path() + if path.exists(): + with path.open() as fd: + # pyrefly: ignore [bad-assignment] + self.picked_kernel = int(fd.read()) + # pyrefly: ignore [unsupported-operation] + assert self.picked_kernel >= 0 and self.picked_kernel < len( + self._kernels + ) + log.debug( + "Load picked kernel %d from cache file %s", self.picked_kernel, path + ) + + def store_cache(self): + assert self.picked_kernel is not None + path = self.cache_file_path() + path.parent.mkdir(parents=True, exist_ok=True) + + write_atomic(path, str(self.picked_kernel)) + log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path) + + @property + def kernels(self): + """ + Read results from future. + + This should be called after parallel compilation is done. + In case you call this before compilation is done, + it may slow down the parallel compilation. + """ + for i, kernel in enumerate(self._kernels): + if isinstance(kernel, CodeCacheFuture): + self._kernels[i] = kernel.result() + + return self._kernels + + def benchmark_sub_kernels(self, *args, **kwargs): + """ + Benchmark all the sub kernels and return the execution time + (in milliseconds) for each of time. + + Unit test may mock this method to force a specific kernel to + be picked. + """ + + def wrap_fn(kernel, index): + def inner(): + filtered_args = self._get_filtered_args(args, index) + args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs) + return kernel.run(*args_clone, **kwargs_clone) + + return inner + + return [ + benchmarker.benchmark( + wrap_fn(kernel, index), + # Currently the kernel type must be a CachingAutotuner + device=kernel.device_props.type, + rep=40, + ) + for index, kernel in enumerate(self.kernels) + ] + + def _get_filtered_args(self, args, index): + """ + We pass in all arguments to all kernels into the MultiKernelCall + so when invoking a particular kernel we need to filter to only the + arguments for that specific kernel. + """ + + # This is sometimes invoked at runtime where V.graph is + # a NullHandler + if hasattr(V.graph, "cpp_wrapper") and V.graph.cpp_wrapper: + # for cpp-wrapper, we should not filter the args since + # we already have chosen a single kernel and arg set. + return args + return [item for s in self.arg_index[index] for item in args[s]] + + # record_choice and lookup_choice are helper functions for cpp-wrapper + # codegen. The first pass use record_choice to keep the choice and + # the second pass do lookup by calling lookup_choice. + # + # An alternative that reused the multi-kernel cache does not work well + # since during codegen of the second pass, it's very hard to know the + # path for the cache file. Also reading the cache file need do some IO + # which can be slower. + @staticmethod + def record_choice(multi_kernel_name: str, picked_kernel_name: str): + """ + Record the multi-kernel choice for cpp-wrapper after autotuning + + We should do nothing if this function is not called during codegen. + """ + from torch._inductor.graph import GraphLowering + + if not isinstance(V.graph, GraphLowering): + return + + if not V.graph.record_multi_kernel_choice: + return + + V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name + + @staticmethod + def lookup_choice(multi_kernel_name: str) -> str: + # this should always been done during cpp-wrapper codegen + assert ( + V.graph.record_multi_kernel_choice + and multi_kernel_name in V.graph.multi_kernel_to_choice + ) + # there should be no miss + return V.graph.multi_kernel_to_choice[multi_kernel_name] + + def run(self, *args, **kwargs): + if self.picked_kernel is None: + timings = self.benchmark_sub_kernels(*args, **kwargs) + self.picked_kernel = timings.index(min(timings)) + k0 = self.kernels[0] + log.debug( + "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s", + self.picked_kernel, + [k.inductor_meta.get("kernel_name") for k in self.kernels], + k0.size_hints, + k0.inductor_meta.get("reduction_hint"), + timings, + ) + get_metric_table("persistent_red_perf").add_row( + functools.partial(self._metrics_table_row, timings) + ) + + if not self.disable_cache: + self.store_cache() + + if not self._recorded: + self._recorded = True + picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get( + "kernel_name" + ) + assert picked_kernel_name is not None + self.record_choice(self.multi_kernel_name, picked_kernel_name) + + run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] + filtered_args = self._get_filtered_args(args, self.picked_kernel) + run(*filtered_args, **kwargs) + + def _metrics_table_row(self, timings): + def get_kernel_path(k): + return k.fn.fn.__code__.co_filename + + k0 = self.kernels[0] + row = { + "size_hints": k0.size_hints, + "reduction_hint": k0.inductor_meta.get("reduction_hint"), + } + max_kernels = 4 + assert len(timings) <= max_kernels + for i in range(max_kernels): + if i < len(self.kernels): + row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i]) + row[f"kernel{i}_latency"] = timings[i] + else: + row[f"kernel{i}_path"] = "" + row[f"kernel{i}_latency"] = "" + return row + + +class SizeHintMultiKernel(MultiKernel): + """ + Version of multi-kernel that generates kernels based on specified size hints. + Currently only performs 1-d search over hints; doesn't perform combinatorial n-d search + if n > 1 dynamic dimensions are specified. + + e.g. matmul([s0, s1], [s1, s2]) with size-hints [64, 256] only generates 2 kernels, + based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256]) + """ + + def __init__(self, kernels): + assert isinstance(kernels, dict) and len(kernels) >= 1 + + self.kernels, self.kernel_shape_keys = [], [] + for shape_key, kernel in kernels.items(): + self.kernels.append(kernel) + self.kernel_shape_keys.append(shape_key) + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + self.kernels, self.kernel_shape_keys + ) + + # need this since some code in inductor check if the kernel object has an args + # attribute to decide if it's a non-null kernel. + self.args = object() + + +class SizeHintMultiKernelCall(MultiKernelCall): + """ + Runtime class for size-hint multi-kernels. + Instead of having a plain list of kernels to benchmark over, keys them by input & output shapes, + and optionally perform shape-based selection. The pre-generated kernel is chosen based on the shape keys, + with the heuristic being log2 l1 distance between the pre-generated / runtime input & output shapes. + """ + + def __init__(self, multi_kernel_name, kernels, arg_index): + super().__init__(multi_kernel_name, list(kernels.values()), arg_index) + self._kernel_hints = list(kernels.keys()) + + # Caches results for unique shapes. + self._shape_cache = {} + + def _get_shape_cache_key(self, *args, **kwargs): + """ + Generate a cache key based on tensor shapes for shape-specialized dispatch. + """ + shapes = [] + for arg in args: + if hasattr(arg, "shape"): + shapes.append(tuple(arg.shape)) + return tuple(shapes) + + def _get_cached_shape_choice(self, cache_key): + """ + Get cached kernel choice for a specific shape. + """ + return self._shape_cache.get(cache_key) + + def _cache_shape_choice(self, cache_key, kernel_idx): + """ + Cache kernel choice for a specific shape. + """ + self._shape_cache[cache_key] = kernel_idx + + def _dist_heuristic(self, k1, k2): + """ + log2 L1 distance heuristic for kernel selection. + """ + + def dist(x, y): + lx = math.log2(x) if x > 0 else -1 + ly = math.log2(y) if y > 0 else -1 + return abs(lx - ly) + + out = 0 + for s1, s2 in zip(k1, k2): + out += sum(dist(x, y) for x, y in zip(s1, s2)) + return out + + def run(self, *args, **kwargs): + cache_key = self._get_shape_cache_key(*args, **kwargs) + cached_choice = self._get_cached_shape_choice(cache_key) + if cached_choice is not None: + self.picked_kernel = cached_choice + log.debug( + "using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s", + self.picked_kernel, + [k.inductor_meta.get("kernel_name") for k in self.kernels], + cache_key, + ) + else: + self._select_kernel_by_shape(*args, **kwargs) + + if not self._recorded: + self._recorded = True + picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get( + "kernel_name" + ) + assert picked_kernel_name is not None + self.record_choice(self.multi_kernel_name, picked_kernel_name) + + run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] + filtered_args = self._get_filtered_args(args, self.picked_kernel) + run(*filtered_args, **kwargs) + + def _select_kernel_by_shape(self, *args, **kwargs): + """ + Benchmark kernels for a particular shape and return the + best kernel for this shape. + """ + shape_key = self._get_shape_cache_key(*args, **kwargs) + dists = [ + self._dist_heuristic(shape_key, key) if key is not None else 2**62 + for key in self._kernel_hints + ] + # pyrefly: ignore [bad-assignment] + self.picked_kernel = dists.index(min(dists)) + self._cache_shape_choice(shape_key, self.picked_kernel) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/pallas.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/pallas.py new file mode 100644 index 0000000000000000000000000000000000000000..ca955ba5f351839209b9be5d3d1312dc69efb48f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/pallas.py @@ -0,0 +1,1840 @@ +from __future__ import annotations + +import hashlib +import math +from typing import Any, Optional, TYPE_CHECKING, Union + +import sympy # noqa: TC002 + +import torch # noqa: TC001 +from torch.utils._ordered_set import OrderedSet +from torch.utils._pallas import has_tpu_pallas + +from .. import config +from ..runtime.runtime_utils import torch_dtype_to_jax +from ..utils import get_fused_kernel_name, get_kernel_metadata +from ..virtualized import V +from .block_analysis import BlockPatternMatcher +from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides +from .simd import pexpr, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ..ir import IRNode + from ..ops_handler import ReductionType + from ..scheduler import BaseSchedulerNode + + +# Main function suffix used in generated Pallas code +MAIN_SUFFIX = "main" + +# Logger for Pallas kernel code +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class PallasKernelWrapper: + """Wrapper to provide .run() interface for Pallas kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("Pallas kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the Pallas kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +class Unsupported(RuntimeError): + """Exception raised when an operation is not supported by the Pallas backend.""" + + +class PallasKernelOverrides(OpOverrides): + """ + Map element-wise ops to JAX/Pallas operations. + + For now, we use the default Python operators which are compatible + with JAX numpy broadcasting semantics. + """ + + @staticmethod + def sin(x: str) -> str: + return f"jnp.sin({x})" + + @staticmethod + def cos(x: str) -> str: + return f"jnp.cos({x})" + + @staticmethod + def tan(x: str) -> str: + return f"jnp.tan({x})" + + @staticmethod + def sinh(x: str) -> str: + return f"jnp.sinh({x})" + + @staticmethod + def cosh(x: str) -> str: + return f"jnp.cosh({x})" + + @staticmethod + def tanh(x: str) -> str: + return f"jnp.tanh({x})" + + @staticmethod + def asin(x: str) -> str: + return f"jnp.arcsin({x})" + + @staticmethod + def acos(x: str) -> str: + return f"jnp.arccos({x})" + + @staticmethod + def atan(x: str) -> str: + return f"jnp.arctan({x})" + + @staticmethod + def exp(x: str) -> str: + return f"jnp.exp({x})" + + @staticmethod + def exp2(x: str) -> str: + return f"jnp.exp2({x})" + + @staticmethod + def expm1(x: str) -> str: + return f"jnp.expm1({x})" + + @staticmethod + def log(x: str) -> str: + return f"jnp.log({x})" + + @staticmethod + def log10(x: str) -> str: + return f"jnp.log10({x})" + + @staticmethod + def log2(x: str) -> str: + return f"jnp.log2({x})" + + @staticmethod + def log1p(x: str) -> str: + return f"jnp.log1p({x})" + + @staticmethod + def sqrt(x: str) -> str: + return f"jnp.sqrt({x})" + + @staticmethod + def rsqrt(x: str) -> str: + return f"(1.0 / jnp.sqrt({x}))" + + @staticmethod + def abs(x: str) -> str: + return f"jnp.abs({x})" + + @staticmethod + def neg(x: str) -> str: + return f"(-{x})" + + @staticmethod + def floor(x: str) -> str: + return f"jnp.floor({x})" + + @staticmethod + def ceil(x: str) -> str: + return f"jnp.ceil({x})" + + @staticmethod + def trunc(x: str) -> str: + return f"jnp.trunc({x})" + + @staticmethod + def round(x: str) -> str: + return f"jnp.round({x})" + + @staticmethod + def sigmoid(x: str) -> str: + return f"(1.0 / (1.0 + jnp.exp(-{x})))" + + @staticmethod + def relu(x: str) -> str: + return f"jnp.maximum({x}, 0)" + + @staticmethod + def pow(a: str, b: str) -> str: + return f"jnp.power({a}, {b})" + + @staticmethod + def maximum(a: str, b: str) -> str: + return f"jnp.maximum({a}, {b})" + + @staticmethod + def minimum(a: str, b: str) -> str: + return f"jnp.minimum({a}, {b})" + + @staticmethod + def where(cond: str, a: str, b: str) -> str: + return f"jnp.where({cond}, {a}, {b})" + + @staticmethod + def to_dtype( + x: str, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> str: + jax_dtype = torch_dtype_to_jax(dtype) + # Wrap in jnp.asarray to handle scalars from integer indexing + return f"jnp.asarray({x}).astype({jax_dtype})" + + @staticmethod + def to_dtype_bitcast(x: str, dtype: torch.dtype, src_dtype: torch.dtype) -> str: + """Bitcast a value from one dtype to another with the same size.""" + jax_dtype = torch_dtype_to_jax(dtype) + jax_src_dtype = torch_dtype_to_jax(src_dtype) + # First ensure the value is the correct source dtype, then bitcast + return f"jax.lax.bitcast_convert_type(jnp.asarray({x}).astype({jax_src_dtype}), {jax_dtype})" + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str: + """Convert a sympy expression to a JAX array indexing expression.""" + from ..utils import get_bounds_index_expr + + idx_str = V.kernel.kexpr(V.kernel.prepare_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return PallasKernelOverrides.to_dtype(var, dtype) + + @staticmethod + def constant(val, dtype: torch.dtype) -> str: + """Convert a constant value to JAX representation.""" + jax_dtype = torch_dtype_to_jax(dtype) + if dtype == torch.bool: + return "True" if val else "False" + # Handle special float values + if isinstance(val, float): + if math.isnan(val): + return "jnp.nan" + if math.isinf(val): + return "jnp.inf" if val > 0 else "-jnp.inf" + return f"jnp.array({val}, dtype={jax_dtype})" + + @staticmethod + def real(x: str) -> str: + return f"jnp.real({x})" + + @staticmethod + def imag(x: str) -> str: + return f"jnp.imag({x})" + + @staticmethod + def conj(x: str) -> str: + return f"jnp.conj({x})" + + @staticmethod + def angle(x: str) -> str: + return f"jnp.angle({x})" + + @staticmethod + def view_as_real(x: str) -> str: + """View complex tensor as real tensor with extra dimension.""" + return f"jnp.stack([jnp.real({x}), jnp.imag({x})], axis=-1)" + + @staticmethod + def view_as_complex(x: str) -> str: + """View real tensor as complex tensor.""" + return f"({x}[..., 0] + 1j * {x}[..., 1])" + + # Comparison operations + @staticmethod + def eq(a: str, b: str) -> str: + return f"({a} == {b})" + + @staticmethod + def ne(a: str, b: str) -> str: + return f"({a} != {b})" + + @staticmethod + def lt(a: str, b: str) -> str: + return f"({a} < {b})" + + @staticmethod + def le(a: str, b: str) -> str: + return f"({a} <= {b})" + + @staticmethod + def gt(a: str, b: str) -> str: + return f"({a} > {b})" + + @staticmethod + def isnan(x: str) -> str: + return f"jnp.isnan({x})" + + @staticmethod + def isinf(x: str) -> str: + return f"jnp.isinf({x})" + + @staticmethod + def isfinite(x: str) -> str: + return f"jnp.isfinite({x})" + + @staticmethod + def ge(a: str, b: str) -> str: + return f"({a} >= {b})" + + # Logical operations + @staticmethod + def logical_and(a: str, b: str) -> str: + return f"jnp.logical_and({a}, {b})" + + @staticmethod + def logical_or(a: str, b: str) -> str: + return f"jnp.logical_or({a}, {b})" + + @staticmethod + def logical_not(x: str) -> str: + return f"jnp.logical_not({x})" + + @staticmethod + def logical_xor(a: str, b: str) -> str: + return f"jnp.logical_xor({a}, {b})" + + # Math operations + @staticmethod + def atan2(a: str, b: str) -> str: + return f"jnp.arctan2({a}, {b})" + + @staticmethod + def hypot(a: str, b: str) -> str: + return f"jnp.hypot({a}, {b})" + + @staticmethod + def fmod(a: str, b: str) -> str: + return f"jnp.fmod({a}, {b})" + + @staticmethod + def remainder(a: str, b: str) -> str: + return f"jnp.remainder({a}, {b})" + + @staticmethod + def truncdiv(a: str, b: str) -> str: + # Truncated division (rounds toward zero) + # For integers: sign(a)*sign(b) * (abs(a) // abs(b)) + return f"(jnp.sign({a}) * jnp.sign({b}) * (jnp.abs({a}) // jnp.abs({b}))).astype({a}.dtype)" + + @staticmethod + def floordiv(a: str, b: str) -> str: + return f"({a} // {b})" + + @staticmethod + def clamp(x: str, min_val: str, max_val: str) -> str: + return f"jnp.clip({x}, {min_val}, {max_val})" + + @staticmethod + def clip(x: str, min_val: str, max_val: str) -> str: + return f"jnp.clip({x}, {min_val}, {max_val})" + + # Sign operations + @staticmethod + def sign(x: str) -> str: + return f"jnp.sign({x})" + + @staticmethod + def signbit(x: str) -> str: + return f"jnp.signbit({x})" + + # Special math functions + @staticmethod + def erf(x: str) -> str: + return f"jax.scipy.special.erf({x})" + + @staticmethod + def erfc(x: str) -> str: + return f"jax.scipy.special.erfc({x})" + + @staticmethod + def erfinv(x: str) -> str: + return f"jax.scipy.special.erfinv({x})" + + @staticmethod + def lgamma(x: str) -> str: + return f"jax.scipy.special.gammaln({x})" + + @staticmethod + def digamma(x: str) -> str: + return f"jax.scipy.special.digamma({x})" + + @staticmethod + def bessel_j0(x: str) -> str: + # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN) + # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n + # Handle by: convert to float64, compute, handle x=0, convert back + # J0(0) = 1.0 + return ( + f"jnp.where({x}.astype(jnp.float64) == 0.0, 1.0, " + f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=0)[0])" + f".astype({x}.dtype)" + ) + + @staticmethod + def bessel_j1(x: str) -> str: + # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN) + # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n + # Handle by: convert to float64, compute, handle x=0, convert back + # J1(0) = 0.0 + return ( + f"jnp.where({x}.astype(jnp.float64) == 0.0, 0.0, " + f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=1)[1])" + f".astype({x}.dtype)" + ) + + @staticmethod + def modified_bessel_i0(x: str) -> str: + # Modified Bessel function of the first kind I_0(x) + # I_0(x) = bessel_i0e(x) * exp(|x|) where bessel_i0e is the scaled version + return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def modified_bessel_i1(x: str) -> str: + # Modified Bessel function of the first kind I_1(x) + # I_1(x) = bessel_i1e(x) * exp(|x|) where bessel_i1e is the scaled version + return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def spherical_bessel_j0(x: str) -> str: + # Spherical Bessel function of the first kind j_0(x) = sin(x) / x + # Handle x=0: j_0(0) = 1 + return f"jnp.where({x} == 0.0, 1.0, jnp.sin({x}) / {x})" + + @staticmethod + def i0(x: str) -> str: + # Modified Bessel function I_0 (same as modified_bessel_i0) + return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def i0e(x: str) -> str: + # Exponentially scaled modified Bessel function I_0 + return f"jax.lax.bessel_i0e({x})" + + @staticmethod + def i1(x: str) -> str: + # Modified Bessel function I_1 (same as modified_bessel_i1) + return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def i1e(x: str) -> str: + # Exponentially scaled modified Bessel function I_1 + return f"jax.lax.bessel_i1e({x})" + + @staticmethod + def gammainc(x: str, y: str) -> str: + # Regularized lower incomplete gamma function P(a, x) + # Note: PyTorch uses gammainc(input, other) where input is a (shape param) + return f"jax.scipy.special.gammainc({x}, {y})" + + @staticmethod + def gammaincc(x: str, y: str) -> str: + # Regularized upper incomplete gamma function Q(a, x) + return f"jax.scipy.special.gammaincc({x}, {y})" + + @staticmethod + def igamma(x: str, y: str) -> str: + # Regularized lower incomplete gamma function (alias for gammainc) + return f"jax.scipy.special.gammainc({x}, {y})" + + @staticmethod + def igammac(x: str, y: str) -> str: + # Regularized upper incomplete gamma function (alias for gammaincc) + return f"jax.scipy.special.gammaincc({x}, {y})" + + @staticmethod + def polygamma(x: str, y: str) -> str: + # Polygamma function psi^(n)(x), x is order n, y is the value + # Note: JAX uses polygamma(n, x) where n is integer order + return f"jax.scipy.special.polygamma({x}.astype(jnp.int32), {y})" + + @staticmethod + def ndtri(x: str) -> str: + # Inverse of the standard normal CDF + return f"jax.scipy.special.ndtri({x})" + + @staticmethod + def zeta(x: str, y: str) -> str: + # Hurwitz zeta function zeta(x, q) = sum_{k=0}^inf 1/(k+q)^x + return f"jax.scipy.special.zeta({x}, {y})" + + @staticmethod + def xlogy(x: str, y: str) -> str: + # x * log(y), with proper handling of x=0 + return f"jax.scipy.special.xlogy({x}, {y})" + + @staticmethod + def xlog1py(x: str, y: str) -> str: + # x * log1p(y), with proper handling of x=0 + return f"jax.scipy.special.xlog1py({x}, {y})" + + @staticmethod + def chebyshev_polynomial_t(x: str, n: str) -> str: + # Chebyshev polynomial of the first kind T_n(x) + # For |x| <= 1: T_n(x) = cos(n * arccos(x)) + # For x > 1: T_n(x) = cosh(n * arccosh(x)) + # For x < -1: T_n(x) = (-1)^n * cosh(n * arccosh(-x)) + return ( + f"jnp.where(jnp.abs({x}) <= 1, " + f"jnp.cos({n} * jnp.arccos(jnp.clip({x}, -1, 1))), " + f"jnp.where({x} > 1, " + f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({x}, 1.0))), " + f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{x}, 1.0)))))" + ) + + @staticmethod + def chebyshev_polynomial_u(x: str, n: str) -> str: + # Chebyshev polynomial of the second kind U_n(x) + # For |x| < 1: U_n(x) = sin((n+1) * arccos(x)) / sqrt(1 - x^2) + # For x = 1: U_n(1) = n+1 + # For x = -1: U_n(-1) = (-1)^n * (n+1) + # For x > 1: U_n(x) = sinh((n+1) * arccosh(x)) / sqrt(x^2 - 1) + # For x < -1: U_n(x) = (-1)^n * U_n(-x) (symmetry) + return ( + f"jnp.where(jnp.abs({x}) < 1, " + f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({x}, -1, 1))) / " + f"jnp.sqrt(jnp.maximum(1 - {x}**2, 1e-10)), " + f"jnp.where({x} >= 1, " + f"jnp.where({x} == 1, {n} + 1.0, " + f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({x}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10))), " + f"jnp.where({x} == -1, ((-1.0) ** {n}) * ({n} + 1.0), " + f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{x}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10)))))" + ) + + @staticmethod + def chebyshev_polynomial_v(x: str, n: str) -> str: + # Chebyshev polynomial of the third kind V_n(x) + # V_n(x) = (T_n(x) - T_{n+1}(x)) / (1 - x) for x != 1 + # V_n(1) = 1, recurrence: V_0 = 1, V_1 = 2x - 1, V_n = 2x*V_{n-1} - V_{n-2} + # Explicit: V_0 = 1, V_1 = 2x-1, V_2 = 4x^2-2x-1, V_3 = 8x^3-4x^2-4x+1 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{x} - 1, " + f"jnp.where({n} == 2, 4*{x}**2 - 2*{x} - 1, " + f"jnp.where({n} == 3, 8*{x}**3 - 4*{x}**2 - 4*{x} + 1, " + f"jnp.where({n} == 4, 16*{x}**4 - 8*{x}**3 - 12*{x}**2 + 4*{x} + 1, " + f"jnp.where({n} == 5, 32*{x}**5 - 16*{x}**4 - 32*{x}**3 + 12*{x}**2 + 6*{x} - 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def chebyshev_polynomial_w(x: str, n: str) -> str: + # Chebyshev polynomial of the fourth kind W_n(x) + # W_n(x) = (T_n(x) + T_{n+1}(x)) / (1 + x) for x != -1 + # W_n(-1) = (-1)^n, recurrence: W_0 = 1, W_1 = 2x + 1, W_n = 2x*W_{n-1} - W_{n-2} + # Explicit: W_0 = 1, W_1 = 2x+1, W_2 = 4x^2+2x-1, W_3 = 8x^3+4x^2-4x-1 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{x} + 1, " + f"jnp.where({n} == 2, 4*{x}**2 + 2*{x} - 1, " + f"jnp.where({n} == 3, 8*{x}**3 + 4*{x}**2 - 4*{x} - 1, " + f"jnp.where({n} == 4, 16*{x}**4 + 8*{x}**3 - 12*{x}**2 - 4*{x} + 1, " + f"jnp.where({n} == 5, 32*{x}**5 + 16*{x}**4 - 32*{x}**3 - 12*{x}**2 + 6*{x} + 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_t(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the first kind T*_n(x) = T_n(2x - 1) + # T_n(y) where y = 2x - 1 + # Use same formula as chebyshev_polynomial_t + y = f"(2 * {x} - 1)" + return ( + f"jnp.where(jnp.abs({y}) <= 1, " + f"jnp.cos({n} * jnp.arccos(jnp.clip({y}, -1, 1))), " + f"jnp.where({y} > 1, " + f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({y}, 1.0))), " + f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{y}, 1.0)))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_u(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the second kind U*_n(x) = U_n(2x - 1) + # Use same formula as chebyshev_polynomial_u + y = f"(2 * {x} - 1)" + return ( + f"jnp.where(jnp.abs({y}) < 1, " + f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({y}, -1, 1))) / " + f"jnp.sqrt(jnp.maximum(1 - ({y})**2, 1e-10)), " + f"jnp.where({y} >= 1, " + f"jnp.where({y} == 1, {n} + 1.0, " + f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({y}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({y}**2 - 1, 1e-10))), " + f"jnp.where({y} == -1, ((-1.0) ** {n}) * ({n} + 1.0), " + f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{y}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({y}**2 - 1, 1e-10)))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_v(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the third kind V*_n(x) = V_n(2x - 1) + y = f"(2 * {x} - 1)" # shifted variable + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{y} - 1, " + f"jnp.where({n} == 2, 4*{y}**2 - 2*{y} - 1, " + f"jnp.where({n} == 3, 8*{y}**3 - 4*{y}**2 - 4*{y} + 1, " + f"jnp.where({n} == 4, 16*{y}**4 - 8*{y}**3 - 12*{y}**2 + 4*{y} + 1, " + f"jnp.where({n} == 5, 32*{y}**5 - 16*{y}**4 - 32*{y}**3 + 12*{y}**2 + 6*{y} - 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_w(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the fourth kind W*_n(x) = W_n(2x - 1) + y = f"(2 * {x} - 1)" # shifted variable + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{y} + 1, " + f"jnp.where({n} == 2, 4*{y}**2 + 2*{y} - 1, " + f"jnp.where({n} == 3, 8*{y}**3 + 4*{y}**2 - 4*{y} - 1, " + f"jnp.where({n} == 4, 16*{y}**4 + 8*{y}**3 - 12*{y}**2 - 4*{y} + 1, " + f"jnp.where({n} == 5, 32*{y}**5 + 16*{y}**4 - 32*{y}**3 - 12*{y}**2 + 6*{y} + 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def hermite_polynomial_h(x: str, n: str) -> str: + # Physicist's Hermite polynomial H_n(x) + # H_n(x) = 2^n * x^n - n*(n-1)/2 * 2^(n-2) * x^(n-2) + ... + # Use explicit formula: H_n(x) = n! * sum_{m=0}^{n//2} (-1)^m / (m! * (n-2m)!) * (2x)^(n-2m) + # For simplicity, use the relation: H_n(x) = 2^(n/2) * He_n(x * sqrt(2)) where He is probabilist's + # Actually simpler: use recurrence or closed form + # H_0 = 1, H_1 = 2x, H_2 = 4x^2 - 2, H_3 = 8x^3 - 12x + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2 * {x}, " + f"jnp.where({n} == 2, 4 * {x}**2 - 2, " + f"jnp.where({n} == 3, 8 * {x}**3 - 12 * {x}, " + f"jnp.where({n} == 4, 16 * {x}**4 - 48 * {x}**2 + 12, " + f"jnp.where({n} == 5, 32 * {x}**5 - 160 * {x}**3 + 120 * {x}, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + @staticmethod + def hermite_polynomial_he(x: str, n: str) -> str: + # Probabilist's Hermite polynomial He_n(x) + # He_0 = 1, He_1 = x, He_2 = x^2 - 1, He_3 = x^3 - 3x + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, {x}, " + f"jnp.where({n} == 2, {x}**2 - 1, " + f"jnp.where({n} == 3, {x}**3 - 3 * {x}, " + f"jnp.where({n} == 4, {x}**4 - 6 * {x}**2 + 3, " + f"jnp.where({n} == 5, {x}**5 - 10 * {x}**3 + 15 * {x}, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + @staticmethod + def laguerre_polynomial_l(x: str, n: str) -> str: + # Laguerre polynomial L_n(x) + # L_0 = 1, L_1 = 1 - x, L_2 = (x^2 - 4x + 2)/2, L_3 = (-x^3 + 9x^2 - 18x + 6)/6 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 1 - {x}, " + f"jnp.where({n} == 2, ({x}**2 - 4*{x} + 2) / 2, " + f"jnp.where({n} == 3, (-{x}**3 + 9*{x}**2 - 18*{x} + 6) / 6, " + f"jnp.where({n} == 4, ({x}**4 - 16*{x}**3 + 72*{x}**2 - 96*{x} + 24) / 24, " + f"jnp.where({n} == 5, (-{x}**5 + 25*{x}**4 - 200*{x}**3 + 600*{x}**2 - 600*{x} + 120) / 120, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + @staticmethod + def legendre_polynomial_p(x: str, n: str) -> str: + # Legendre polynomial P_n(x) + # P_0 = 1, P_1 = x, P_2 = (3x^2 - 1)/2, P_3 = (5x^3 - 3x)/2 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, {x}, " + f"jnp.where({n} == 2, (3 * {x}**2 - 1) / 2, " + f"jnp.where({n} == 3, (5 * {x}**3 - 3 * {x}) / 2, " + f"jnp.where({n} == 4, (35 * {x}**4 - 30 * {x}**2 + 3) / 8, " + f"jnp.where({n} == 5, (63 * {x}**5 - 70 * {x}**3 + 15 * {x}) / 8, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + # Reciprocal and square + @staticmethod + def reciprocal(x: str) -> str: + return f"jnp.reciprocal({x})" + + @staticmethod + def square(x: str) -> str: + return f"jnp.square({x})" + + # Additional operations + @staticmethod + def fma(a: str, b: str, c: str) -> str: + """Fused multiply-add: a * b + c""" + return f"jnp.fma({a}, {b}, {c})" + + @staticmethod + def copysign(a: str, b: str) -> str: + return f"jnp.copysign({a}, {b})" + + @staticmethod + def nextafter(a: str, b: str) -> str: + return f"jnp.nextafter({a}, {b})" + + @staticmethod + def ldexp(a: str, b: str) -> str: + return f"jnp.ldexp({a}, {b})" + + @staticmethod + def frexp(x: str) -> str: + return f"jnp.frexp({x})" + + @staticmethod + def modf(x: str) -> str: + return f"jnp.modf({x})" + + # Bitwise operations + @staticmethod + def bitwise_and(a: str, b: str) -> str: + return f"jnp.bitwise_and({a}, {b})" + + @staticmethod + def bitwise_or(a: str, b: str) -> str: + return f"jnp.bitwise_or({a}, {b})" + + @staticmethod + def bitwise_xor(a: str, b: str) -> str: + return f"jnp.bitwise_xor({a}, {b})" + + @staticmethod + def bitwise_not(x: str) -> str: + return f"jnp.bitwise_not({x})" + + @staticmethod + def left_shift(a: str, b: str) -> str: + return f"jnp.left_shift({a}, {b})" + + @staticmethod + def right_shift(a: str, b: str) -> str: + return f"jnp.right_shift({a}, {b})" + + +class PallasKernel(SIMDKernel): + """ + Pallas kernel for elementwise operations with support for strided/scatter access. + + Strategy: + - Convert index expressions to JAX-compatible array slicing + - Load/store using indexed access: "in_ptrX[slice]" or full-array "in_ptrX[...]" + - Compute expression with Python operators (compatible with jax.numpy broadcasting) + - Generate Python code that defines a Pallas kernel and a host entrypoint. + - Use async_compile.pallas path to compile and load Python code. + + For GPU (Triton backend): + - Use masked loads/stores with power-of-2 block sizes to handle non-power-of-2 shapes + """ + + overrides = PallasKernelOverrides # type: ignore[assignment] + kexpr: Callable[[sympy.Expr], str] = pexpr # Use Python expression printer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Determine device type once at initialization + device = V.graph.get_current_device_or_throw() + self.is_gpu = device.type == "cuda" + self.use_masked_ops: bool | None = None + self.tensor_masks = {} # Map tensor name to mask variable name + # Track which output param each store uses: list of (out_ptr_name, store_line) + self.store_with_output: list[tuple[str, str]] = [] + # Track load index expressions for argmax/argmin axis detection + self.load_index_exprs: dict[str, sympy.Expr] = {} + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + """Check array bounds for indirect indexing.""" + # For now, skip explicit bounds checking as JAX/Pallas handles this internally + # TODO: Implement explicit bounds checking with assertions if needed + + def _get_index_str(self, index: sympy.Expr) -> str: + """ + Convert an index expression to a string suitable for Pallas indexing. + + Pallas operates on full arrays, so we need to convert index expressions + to JAX array slicing. For example: + - x0 -> "..." (contiguous access, full array) + - 2*x0 -> "::2" (strided access with stride 2) + - 2*x0 + 1 -> "1::2" (strided access with offset 1, stride 2) + + Args: + index: The indexing expression to convert + + Returns: + The indexing string to use in generated code + """ + # Prepare and simplify the index + prepared_index = self.prepare_indexing(index) + + # For simple single-symbol access (contiguous case), we can use [...] + # which is more efficient as it operates on the entire array at once + if isinstance(prepared_index, sympy.Symbol): + return "..." + elif prepared_index.is_Integer: + # Scalar index + return str(prepared_index) + else: + # Complex expression (strided/scatter access) + # Try to extract stride and offset for common patterns + return self._convert_to_jax_slice(prepared_index) + + def _convert_to_jax_slice(self, index: sympy.Expr) -> str: + """ + Convert a sympy index expression to JAX slice notation. + + Handles common patterns like: + - stride*var -> ::stride + - stride*var + offset -> offset::stride + + For more complex patterns, falls back to explicit indexing. + Uses BlockPatternMatcher for robust pattern matching. + """ + # Get the iteration variables for this kernel + if not self.range_trees: + return "..." + + # Simplify the index + index = V.graph.sizevars.simplify(index) + free_symbols = index.free_symbols + + # Get iteration variables from range_tree_nodes + iter_vars = OrderedSet(self.range_tree_nodes.keys()) + + # Find which iteration variable(s) are used + used_vars = free_symbols & iter_vars + + if len(used_vars) == 0: + # No iteration variables, this is a constant index + return str(index) + elif len(used_vars) == 1: + # Single iteration variable - try to extract stride and offset using BlockPatternMatcher + var = next(iter(used_vars)) + + # Get the subexpression involving this variable + var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var) + + # Try to match affine pattern: stride * var + stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var) + + if stride is not None: + # Extract the constant offset (terms not involving var) + offset = index - var_expr + offset = V.graph.sizevars.simplify(offset) + + # Generate JAX slice notation + if stride == 1 and offset == 0: + # Contiguous access + return "..." + elif offset == 0: + # Pure stride: ::stride + stride_str = self.kexpr(stride) + return f"::{stride_str}" + else: + # Offset + stride: offset::stride + offset_str = self.kexpr(offset) + stride_str = self.kexpr(stride) + return f"{offset_str}::{stride_str}" + else: + # Couldn't match affine pattern, fall back to original logic + offset = index - var_expr + offset = V.graph.sizevars.simplify(offset) + if offset == 0 and var_expr == var: + # Just the variable itself, unit stride + return "..." + elif len(used_vars) > 1: + # Multi-dimensional indexing + # For contiguous multi-dim access, all terms should have unit stride + all_unit_stride = True + for var in used_vars: + var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var) + stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var) + if stride != 1: + all_unit_stride = False + break + + if all_unit_stride: + # Contiguous multi-dimensional access + return "..." + else: + # Strided multi-dimensional access - requires advanced indexing + # For now, use ellipsis which may work for many cases + # TODO: Implement proper multi-dimensional strided indexing + return "..." + + # For complex cases, raise an error + return self._generate_index_array(index) + + def _generate_index_array(self, index: sympy.Expr) -> str: + """ + Generate JAX code to compute an index array for complex indexing patterns. + + For very complex patterns that can't be expressed as simple slices, + we need to compute the indices explicitly. This is not yet fully implemented. + """ + # For now, raise an error for complex patterns + # TODO: Implement advanced indexing support + raise Unsupported( + f"Pallas backend does not yet support complex indexing pattern: {index}" + ) + + def _has_iteration_vars(self, index: sympy.Expr) -> bool: + """Check if index expression contains iteration variables (x0, x1, etc.).""" + free_symbols = index.free_symbols + iter_vars = OrderedSet(self.range_tree_nodes.keys()) + return bool(free_symbols & iter_vars) + + def _has_indirect_vars(self, index: sympy.Expr) -> bool: + """Check if index expression contains indirect variables (tmp0, tmp1, etc.).""" + free_symbols = index.free_symbols + for sym in free_symbols: + if str(sym).startswith("tmp"): + return True + return False + + def _get_index_expr(self, index: sympy.Expr) -> tuple[str, bool]: + """ + Get the index expression string and whether it needs flattening. + + Returns: + Tuple of (index_str, needs_flatten) where needs_flatten indicates + if the buffer should be flattened before indexing (for mixed indexing). + """ + has_indirect = self._has_indirect_vars(index) + has_iter_vars = self._has_iteration_vars(index) + + if has_indirect and has_iter_vars: + return self._handle_mixed_indexing(index), True + elif has_indirect: + return self.kexpr(index), False + else: + return self._get_index_str(index), False + + def _determine_masked_ops_for_kernel(self) -> bool: + """ + Determine if we should use masked ops for this entire kernel. + + Masked ops with pl.ds(block_size) flatten tensors to 1D, which works when: + 1. We're on GPU (CUDA backend uses Triton which requires power-of-2 sizes) + 2. All tensors are already 1D (so flattening doesn't change dimensionality) + 3. All tensors have the same size (so broadcasting works correctly) + + With per-tensor masks, each tensor gets its own mask based on its size. + + This should be called once in codegen_kernel() before generating the kernel body. + """ + if not self.is_gpu: + return False + + # Get all buffer sizes + # We need ALL buffers - inputs, outputs, and intermediates + all_buffer_names = OrderedSet() + + # Get input buffers from args + all_buffer_names.update(self.args.input_buffers.keys()) + # Get output buffers from args + all_buffer_names.update(self.args.output_buffers.keys()) + # Also get any intermediate buffers from the graph + all_buffer_names.update(V.graph.name_to_buffer.keys()) + + # Get shapes and sizes for all buffers + buf_info = [] + for buf_name in all_buffer_names: + try: + buf = V.graph.get_buffer(buf_name) + size = buf.get_size() + shape = tuple(int(s) if hasattr(s, "__int__") else s for s in size) + # Calculate flattened size + total_size = 1 + for s in size: + if hasattr(s, "__int__"): + total_size *= int(s) + else: + total_size *= s + buf_info.append((buf_name, shape, total_size)) + except Exception: + pass + + # Only use masked ops if: + # 1. All buffers are 1D (single-element shape tuples) + # 2. All buffers have the same size + # This ensures that pl.ds(block_size) flattening works correctly + # and masks can be properly applied without broadcasting issues. + if buf_info and len(buf_info) > 0: + # Check if all are 1D + all_1d = all(len(shape) == 1 for _, shape, _ in buf_info) + if not all_1d: + return False + + # Check if all have the same size + first_size = buf_info[0][2] + all_same_size = all(size == first_size for _, _, size in buf_info) + return all_same_size + + return False + + def _get_or_create_mask(self, buf_name: str) -> str: + """Get or create a unique mask variable for a buffer.""" + if buf_name not in self.tensor_masks: + mask_var = f"mask_{buf_name}" + self.tensor_masks[buf_name] = mask_var + return self.tensor_masks[buf_name] + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: # type: ignore[override] + buf = self.args.input(name) + dtype = V.graph.get_dtype(name) + + # Track the load index expression for argmax/argmin axis detection + self.load_index_exprs[name] = index + + # Determine masked ops strategy on first load/store if not yet determined + if self.use_masked_ops is None: + self.use_masked_ops = self._determine_masked_ops_for_kernel() + + index_str, needs_flatten = self._get_index_expr(index) + + # Build load expression using string concatenation + use_masked = index_str == "..." and not needs_flatten and self.use_masked_ops + + if use_masked: + # GPU masked load: flatten tensor and apply per-tensor mask + mask_var = self._get_or_create_mask(name) + load_expr = f"pltriton.load({buf}.at[pl.ds(block_size)], mask={mask_var})" + elif needs_flatten: + # Flatten then index for non-contiguous access + load_expr = f"{buf}[...].flatten()[{index_str}]" + else: + # Direct indexing for contiguous access + load_expr = f"{buf}[{index_str}]" + + return self.cse.generate( + self.compute, + load_expr, + dtype=dtype, + ) + + def _handle_mixed_indexing(self, index: sympy.Expr) -> str: + """ + Handle indexing with both indirect variables and iteration variables. + + For example, x[indices, :] generates index = i0 + stride * tmp0 + where tmp0 is loaded from indices and i0 is the iteration variable. + + We need to convert this to JAX advanced indexing with proper broadcasting. + When there are multiple iteration variables, they need different shapes + to form an outer product (grid) rather than broadcasting together. + """ + # Get iteration variables + iter_vars = OrderedSet(self.range_tree_nodes.keys()) + free_symbols = index.free_symbols + used_iter_vars_set = free_symbols & iter_vars + + if len(used_iter_vars_set) == 0: + return self.kexpr(index) + + # Sort iteration variables by their coefficient (stride) in the index expression. + # Variables with larger strides correspond to earlier output dimensions. + def get_coefficient(var): + """Extract the coefficient of a variable in the index expression.""" + coeff = index.coeff(var) + if coeff == 0: + # Variable appears in a more complex form, try differentiation + coeff = sympy.diff(index, var) + # Convert to int if possible for sorting + try: + return int(coeff) + except (TypeError, ValueError): + return 0 + + used_iter_vars = sorted(used_iter_vars_set, key=get_coefficient, reverse=True) + iter_coeffs = [get_coefficient(var) for var in used_iter_vars] + + index_str = self.kexpr(index) + indirect_var_syms = [s for s in free_symbols if str(s).startswith("tmp")] + indirect_vars = [str(sym) for sym in indirect_var_syms] + + # Get coefficients for indirect vars to determine output ordering + indirect_coeffs = {str(s): get_coefficient(s) for s in indirect_var_syms} + + # Build a sorted list of all components by coefficient (descending) + # Each component is (coeff, type, var) where type is 'iter' or 'indirect' + all_components = [] + for var in used_iter_vars: + all_components.append((get_coefficient(var), "iter", var)) + for sym in indirect_var_syms: + all_components.append((get_coefficient(sym), "indirect", sym)) + all_components.sort(key=lambda x: x[0], reverse=True) + + # Calculate trailing dims needed for each component + # Each component needs trailing dims for all subsequent iter vars + # plus trailing dims for all dimensions of subsequent indirect vars + # For simplicity, assume each indirect var contributes some dimensions + # that will be handled by the reshape at store time + + # For iter vars, we need to count how many dimensions come after in the output + for i, var in enumerate(used_iter_vars): + var_name = str(var) + if var in self.range_tree_nodes: + range_entry = self.range_tree_nodes[var] + range_size = range_entry.length + var_coeff = get_coefficient(var) + + arange_expr = f"jnp.arange({self.kexpr(range_size)})" + + # Count trailing dims needed: + # - One for each subsequent iter var (with smaller coeff) + # - One for each dimension of indirect vars with smaller coeff + # For indirect vars, assume each contributes 2 dims (common case) + # The actual reshape at store time will fix any shape mismatches + n_trailing_iter = sum(1 for c in iter_coeffs if c < var_coeff) + n_trailing_indirect = sum( + 2 for c in indirect_coeffs.values() if c < var_coeff + ) + n_trailing = n_trailing_iter + n_trailing_indirect + + if n_trailing > 0: + trailing_dims = ", None" * n_trailing + arange_expr = f"{arange_expr}[:{trailing_dims}]" + + index_str = index_str.replace(var_name, arange_expr) + + # Reshape indirect variables for proper broadcasting. + for indirect_var in indirect_vars: + indirect_coeff = indirect_coeffs[indirect_var] + + # Count dims needed before and after this indirect var + n_leading = sum(1 for c in iter_coeffs if c > indirect_coeff) + n_trailing = sum(1 for c in iter_coeffs if c < indirect_coeff) + + # Build the indexing expression with leading Nones, ellipsis, trailing Nones + if n_leading > 0 and n_trailing > 0: + leading_nones = "None, " * n_leading + trailing_nones = ", None" * n_trailing + reshape_expr = f"{indirect_var}[{leading_nones}...{trailing_nones}]" + elif n_leading > 0: + leading_nones = "None, " * n_leading + reshape_expr = f"{indirect_var}[{leading_nones}...]" + elif n_trailing > 0: + trailing_nones = ", None" * n_trailing + reshape_expr = f"{indirect_var}[...{trailing_nones}]" + else: + reshape_expr = indirect_var + + index_str = index_str.replace(indirect_var, reshape_expr) + + return index_str + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: Any = None + ) -> None: # type: ignore[override] + if mode is not None: + raise Unsupported("pallas store mode not supported") + out = self.args.output(name) + self.store_buffer_names.add(name) + + # Determine masked ops strategy on first load/store if not yet determined + if self.use_masked_ops is None: + self.use_masked_ops = self._determine_masked_ops_for_kernel() + + # Check if this is a scalar output (reduction to scalar) + # Only shape () is a true scalar, not (1,) which is a 1-element tensor + try: + buf = V.graph.get_buffer(name) + output_shape = buf.get_size() + is_scalar = len(output_shape) == 0 + except Exception: + output_shape = () + is_scalar = False + + if is_scalar: + # For scalar outputs, use [...] to assign the entire scalar + store_expr = f"{out}[...] = {value}" + else: + index_str, needs_flatten = self._get_index_expr(index) + + # Build store expression using string concatenation + use_masked = ( + index_str == "..." and not needs_flatten and self.use_masked_ops + ) + + if use_masked: + # GPU masked store: flatten tensor and apply per-tensor mask + mask_var = self._get_or_create_mask(name) + store_expr = f"pltriton.store({out}.at[pl.ds(block_size)], {value}, mask={mask_var})" + elif index_str == "...": + # When storing the full array, reshape to match the output shape. + # This handles: + # - Mixed indexing producing flat results needing reshape + # - Squeeze operations where value has more dims than output + # - If shapes already match, reshape is a no-op. + # Use the output array's shape at runtime to avoid issues with + # symbolic sizes not being defined in the kernel. + store_expr = f"{out}[...] = {value}.reshape({out}.shape)" + else: + # Direct indexed assignment + store_expr = f"{out}[{index_str}] = {value}" + + self.stores.writeline(store_expr) + # Track which output param this store uses for filtering in codegen_kernel + self.store_with_output.append((out, store_expr)) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: # type: ignore[override] + """ + Generate code for reduction operations in JAX/Pallas. + + Reductions in Pallas work by: + 1. Loading the input data into the kernel + 2. Applying JAX reduction operations (jnp.sum, jnp.max, etc.) + 3. Storing the reduced result + + The reduction happens over the loaded block of data. + """ + assert self.inside_reduction + + if isinstance(value, tuple): + raise Unsupported( + "Tuple reductions (e.g., welford_combine) not supported in Pallas backend" + ) + + # Check if this reduction is already cached + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + # Map reduction types to JAX functions + reduction_ops = { + "sum": "jnp.sum", + "prod": "jnp.prod", # CPU only - not supported in Pallas GPU (Triton) backend + "max": "jnp.max", + "min": "jnp.min", + "any": "jnp.any", + "argmax": "jnp.argmax", + "argmin": "jnp.argmin", + } + + # Determine if this is a partial reduction (has pointwise dimensions) + # or a full reduction to scalar + pointwise_prefixes = OrderedSet(["x", "y", "z"]) + has_pointwise = any(p in self.numels for p in pointwise_prefixes) + + # Get the individual pointwise dimension sizes from range_tree_nodes + pointwise_sizes = [] + for var, entry in sorted( + self.range_tree_nodes.items(), key=lambda x: str(x[0]) + ): + if not entry.prefix.startswith("r"): + try: + pointwise_sizes.append(int(entry.length)) + except (TypeError, ValueError): + pointwise_sizes = None + break + + # Get the pointwise and reduction numels + pointwise_numel = 1 + for p in pointwise_prefixes: + if p in self.numels: + numel = self.numels[p] + try: + pointwise_numel *= int(numel) + except (TypeError, ValueError): + pointwise_numel = None + break + + reduction_numel = 1 + for p in self.numels: + if p.startswith("r"): + numel = self.numels[p] + try: + reduction_numel *= int(numel) + except (TypeError, ValueError): + reduction_numel = None + break + + # Count the number of pointwise and reduction dimensions + n_reduction_dims = sum( + 1 + for var, entry in self.range_tree_nodes.items() + if entry.prefix.startswith("r") + ) + + if reduction_type == "xor_sum": + if has_pointwise and pointwise_numel and reduction_numel: + reduction_expr = f"jnp.bitwise_xor.reduce({value}.reshape({pointwise_numel}, -1), axis=-1)" + else: + reduction_expr = f"jnp.bitwise_xor.reduce({value})" + elif reduction_type in ("argmax", "argmin"): + # For argmax/argmin, we need to preserve the axis information + # because the result is indices, not values. + reduction_op = reduction_ops[reduction_type] + # Check if this is a true partial reduction (pointwise numel > 1) + # When pointwise_numel == 1, it's effectively a full reduction to scalar + is_partial_reduction = ( + has_pointwise and pointwise_numel and pointwise_numel > 1 + ) + if is_partial_reduction and n_reduction_dims > 0: + # Partial reduction: determine the reduction axis from load index + # The reduction variable's coefficient in the index expression tells us its stride + # Higher stride = outer axis (lower axis number in row-major order) + reduction_axis = 0 # Default to axis 0 + if self.load_index_exprs: + # Get the first load index expression + load_index = next(iter(self.load_index_exprs.values())) + # Find the reduction variable (starts with 'r') + reduction_vars = [ + var + for var, entry in self.range_tree_nodes.items() + if entry.prefix.startswith("r") + ] + if reduction_vars: + r_var = reduction_vars[0] + # Get the coefficient (stride) of the reduction variable + r_coeff = load_index.coeff(r_var) + try: + r_stride = int(r_coeff) if r_coeff != 0 else 1 + except (TypeError, ValueError): + r_stride = 1 + # Get pointwise variable + pw_vars = [ + var + for var, entry in self.range_tree_nodes.items() + if not entry.prefix.startswith("r") + ] + if pw_vars: + pw_var = pw_vars[0] + pw_coeff = load_index.coeff(pw_var) + try: + pw_stride = int(pw_coeff) if pw_coeff != 0 else 1 + except (TypeError, ValueError): + pw_stride = 1 + # Higher stride = earlier (outer) axis + # For 2D: axis 0 has stride = dim1_size, axis 1 has stride = 1 + reduction_axis = 0 if r_stride > pw_stride else 1 + if n_reduction_dims == 1: + reduction_expr = f"{reduction_op}({value}, axis={reduction_axis})" + else: + # Multiple reduction dims - reduce over all of them + axes = tuple(range(n_reduction_dims)) + reduction_expr = f"{reduction_op}({value}, axis={axes})" + else: + # Full reduction to scalar + reduction_expr = f"{reduction_op}({value})" + elif reduction_type in reduction_ops: + if ( + has_pointwise + and pointwise_numel + and reduction_numel + and pointwise_sizes + ): + # For partial reductions, we need to: + # 1. Move pointwise axes to the front and reduction axes to the back + # 2. Reshape to (pointwise_numel, reduction_numel) + # 3. Reduce over the last axis + # + # We use moveaxis to reorder: first move axes matching pointwise sizes + # to the front, then the remaining (reduction) axes go to the back. + # Finally reshape and reduce. + # + # Generate code to dynamically determine and reorder axes: + pw_sizes_str = str(pointwise_sizes) + reduction_op = reduction_ops[reduction_type] + reduction_expr = ( + f"(lambda v: (lambda pw_sizes: " + f"{reduction_op}(v.reshape(-1, {reduction_numel}), axis=-1) " + f"if v.ndim == 2 else " + f"(lambda input_shape, pw_axes: " + f"{reduction_op}(" + f"jnp.moveaxis(v, pw_axes, list(range(len(pw_axes)))).reshape({pointwise_numel}, -1), axis=-1)" + f")(" + f"v.shape, " + f"[i for i, s in enumerate(v.shape) if s in pw_sizes][:len(pw_sizes)]" + f")" + f")({pw_sizes_str}))({value})" + ) + else: + # Full reduction to scalar + reduction_expr = f"{reduction_ops[reduction_type]}({value})" + else: + raise Unsupported( + f"Reduction type '{reduction_type}' not yet supported in Pallas backend. " + f"Supported types: {list(reduction_ops.keys())}, xor_sum" + ) + + # Generate CSE variable for the reduction result + result = self.cse.generate( + self.compute, + reduction_expr, + dtype=dtype, + ) + + # Cache the result + self.cse.reduction_cache[cache_key] = result + return result + + @staticmethod + def _buffer_is_contiguous(buffer_name: str) -> bool: + buf = V.graph.get_buffer(buffer_name) + layout = buf.get_layout() + return layout.is_contiguous() + + def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[override] + """ + Generate the complete Pallas kernel code as a Python string. + + This includes: + - Import statements for JAX/Pallas + - The kernel function that operates on refs + - The main wrapper function that handles PyTorch<->JAX conversions via DLPack + + Args: + name: Optional kernel name (will use placeholder if not provided) + + Returns: + str: Complete Python source code for the Pallas kernel + """ + code = IndentedBuffer() + + # Define the Pallas kernel: accepts refs, uses broadcasted expressions + arg_defs, _, _, _ = self.args.python_argdefs() + kernel_params = [a.name for a in arg_defs] + pure_out_params = [p for p in kernel_params if p.startswith("out_ptr")] + output_params = [ + p for p in kernel_params if p.startswith(("out_ptr", "in_out_ptr")) + ] + if not output_params: + raise RuntimeError("Pallas backend requires at least one output buffer") + + output_buffer_lookup = { + inner: outer + for outer, inner in self.args.output_buffers.items() + if isinstance(inner, str) + } + + kernel_name = name or "" + interpret_is_cpu = V.graph.get_current_device_or_throw().type == "cpu" + is_tpu = torch._inductor.config._debug_cpu_to_tpu_pallas + if is_tpu: + if not torch._inductor.config.pallas_take_first_jax_device_only: + raise RuntimeError( + "Pallas backend currently only supports using the first JAX device." + ) + if not has_tpu_pallas(): + raise RuntimeError( + "PALLAS_TARGET_TPU is set, but no TPU device was found. " + "Please make sure that you have a TPU available and that JAX is configured correctly." + ) + interpret_literal = "True" if interpret_is_cpu else "False" + + # For GPU (Triton backend), import pltriton for masked loads/stores + # Import math at module level if we'll use it for masked ops + imports = ( + """ + import functools + """ + + ("import math\n " if self.use_masked_ops else "") + + """import torch + import jax + import jax.numpy as jnp + from jax.experimental import pallas as pl + from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime + """ + + ( + "\n from jax.experimental.pallas import triton as pltriton" + if not interpret_is_cpu + else "" + ) + + ( + "\n from torch._inductor.runtime.runtime_utils import next_power_of_2" + if self.use_masked_ops + else "" + ) + ) + code.splice(imports, strip=True) + + aliasable_flags: dict[str, bool] = {} + for param in pure_out_params: + buffer_name = output_buffer_lookup.get(param) + is_contiguous = buffer_name is not None and self._buffer_is_contiguous( + buffer_name + ) + aliasable_flags[param] = (not interpret_is_cpu) and is_contiguous + alias_params = [ + f"{param}_alias" for param in pure_out_params if aliasable_flags[param] + ] + pointer_tail = [ + p for p in kernel_params if p.startswith(("in_out_ptr", "in_ptr")) + ] + kernel_input_params = alias_params + pointer_tail + full_kernel_params = alias_params + kernel_params + non_alias_out_set = OrderedSet( + [name for name, flag in aliasable_flags.items() if not flag] + ) + copy_output_indices = [ + idx for idx, name in enumerate(output_params) if name in non_alias_out_set + ] + self.aliasable_out_ptrs = aliasable_flags + + # For GPU with masked ops, add block_size as keyword-only parameter + kernel_signature = ( + f"def {kernel_name}_kernel({', '.join(full_kernel_params)}" + + (", *, block_size" if self.use_masked_ops else "") + + "):" + ) + code.writeline(kernel_signature) + with code.indent(): + # For masked ops on GPU, generate per-tensor masks at the start + if self.use_masked_ops and self.tensor_masks: + # Create a mapping from buffer name to parameter name + buf_to_param = {} + for outer, inner in self.args.input_buffers.items(): + buf_to_param[outer] = inner if isinstance(inner, str) else outer + for outer, inner in self.args.output_buffers.items(): + buf_to_param[outer] = inner if isinstance(inner, str) else outer + + # Generate a mask for each tensor that was accessed + for buf_name, mask_var in sorted(self.tensor_masks.items()): + param_name = buf_to_param.get(buf_name, buf_name) + # Find the corresponding parameter in kernel_params + matching_param = None + for p in kernel_params: + # Check if this parameter corresponds to the buffer + if param_name == p or buf_name in str(p): + matching_param = p + break + + if matching_param: + # Calculate flattened size for this tensor + code.writeline(f"# Mask for {buf_name}") + code.writeline(f"{mask_var}_size = {matching_param}.size") + code.writeline( + f"{mask_var} = jnp.arange(block_size) < {mask_var}_size" + ) + + # Generate iteration variables as jnp.arange arrays + # These are used by index_expr operations like torch.arange + # Skip on GPU with masked ops - iteration vars would create non-power-of-2 arrays + # which are not supported by Pallas Triton backend + if self.range_tree_nodes and not self.use_masked_ops: + code.writeline("# Define iteration variables as JAX arrays") + # Get the first output buffer's shape for reshaping + first_output_shape = None + first_output_numel = None + if output_params: + first_out_param = output_params[0] + first_out_buf_name = output_buffer_lookup.get(first_out_param) + if first_out_buf_name: + try: + buf = V.graph.get_buffer(first_out_buf_name) + size = buf.get_size() + first_output_shape = tuple( + int(s) if hasattr(s, "__int__") else s for s in size + ) + first_output_numel = 1 + for s in first_output_shape: + first_output_numel *= s + except Exception: + pass + + for var_sym, entry in self.range_tree_nodes.items(): + var_name = str(var_sym) + length = entry.length + length_str = self.kexpr(length) + # If the iteration variable length matches the output numel, + # reshape it to match the output shape for proper broadcasting + try: + length_val = int(length) if hasattr(length, "__int__") else None + except (TypeError, ValueError): + length_val = None + + # Skip symbolic lengths - jnp.arange requires concrete values + # This happens with dynamic shapes + if length_val is None: + continue + + if ( + first_output_shape + and len(first_output_shape) > 1 + and length_val == first_output_numel + ): + shape_str = ", ".join(str(s) for s in first_output_shape) + code.writeline( + f"{var_name} = jnp.arange({length_str}).reshape({shape_str})" + ) + else: + code.writeline(f"{var_name} = jnp.arange({length_str})") + + # Emit compute (CSE) and store lines; they reference *_ptr[index] directly. + for line in self.compute._lines: + code.writeline(str(line)) + # Filter stores to only emit those for outputs that are in kernel params. + # This handles cases where an intermediate value was stored but the buffer + # was later optimized away (not passed to the kernel). + for out_ptr, store_line in self.store_with_output: + if out_ptr in full_kernel_params: + code.writeline(store_line) + + jit_wrapper_name = f"{kernel_name}_jit_wrapper" + donate_indices = [] + for idx, name in enumerate(kernel_input_params): + if (name in alias_params) or name.startswith("in_out_ptr"): + donate_indices.append(idx + 2) + if donate_indices: + donate_literal = "(" + ", ".join(str(x) for x in donate_indices) + ",)" + else: + donate_literal = "()" + code.writeline( + "@functools.partial(" + "jax.jit, static_argnums=(0, 1), donate_argnums=" + f"{donate_literal})" + ) + code.writeline( + f"def {jit_wrapper_name}(out_shapes, out_dtypes, {', '.join(kernel_input_params)}):" + ) + with code.indent(): + code.writeline("out_specs = tuple(") + code.writeline(" jax.ShapeDtypeStruct(shape, dtype)") + code.writeline(" for shape, dtype in zip(out_shapes, out_dtypes)") + code.writeline(")") + + # For masked ops, calculate block_size as next power of 2 of max flattened size + if self.use_masked_ops: + code.writeline( + "# Calculate block_size as next power of 2 for Triton backend" + ) + code.writeline("# Find maximum flattened size across all tensors") + code.writeline("max_size = 0") + # Calculate size for all input tensors + for param in kernel_input_params: + code.writeline(f"max_size = max(max_size, {param}.size)") + # Also consider output shapes + code.writeline("for shape in out_shapes:") + code.writeline( + " tensor_size = shape[0] if len(shape) == 1 else math.prod(shape)" + ) + code.writeline(" max_size = max(max_size, tensor_size)") + code.writeline("block_size = next_power_of_2(max_size)") + + alias_pairs: list[tuple[int, int]] = [] + for out_idx, name in enumerate(output_params): + if name.startswith("out_ptr"): + if aliasable_flags.get(name, False): + alias_name = f"{name}_alias" + input_idx = kernel_input_params.index(alias_name) + alias_pairs.append((input_idx, out_idx)) + else: + input_idx = kernel_input_params.index(name) + alias_pairs.append((input_idx, out_idx)) + alias_map_literal = ", ".join(f"{i}: {o}" for (i, o) in alias_pairs) + + # For masked ops, wrap kernel with functools.partial to pass block_size + kernel_arg = ( + f"functools.partial({kernel_name}_kernel, block_size=block_size)," + if self.use_masked_ops + else f"{kernel_name}_kernel," + ) + code.writeline("return pl.pallas_call(") + code.writeline(" " + kernel_arg) + + code.writeline(" out_shape=out_specs,") + code.writeline(f" interpret={interpret_literal},") + code.writeline(" grid=(1,),") + code.writeline( + f" input_output_aliases={{ {alias_map_literal} }}," + if alias_pairs + else " input_output_aliases={}," + ) + code.writeline(")(") + if kernel_input_params: + code.writeline(f" {', '.join(kernel_input_params)},") + code.writeline(")") + + main_name = f"{kernel_name}_main" + code.writeline( + f"def {main_name}({', '.join(full_kernel_params)}, stream=None):" + ) + with code.indent(): + code.writeline("# Enable JAX x64 mode for float64/int64 support") + code.writeline("jax.config.update('jax_enable_x64', True)") + if alias_params: + code.writeline("# Convert Torch -> JAX for donated outputs") + for alias_name in alias_params: + # TODO: The `jax.device_put` path is a temporary workaround for a Mosaic compiler bug + # that occurs with DLPack. Once TorchTPU provides a direct method for placing a + # `torch.Tensor` on a TPU device, this should be reverted to use the + # `jax.dlpack.from_dlpack` path. + if is_tpu: + code.writeline( + f"{alias_name}_jax = jax.device_put({alias_name}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{alias_name}_jax = jax.dlpack.from_dlpack({alias_name}.detach())" + ) + code.writeline("# Convert Torch -> JAX for in-place tensors") + for ptr in pointer_tail: + if ptr.startswith("in_out_ptr"): + if is_tpu: + code.writeline( + f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.detach())" + ) + code.writeline("# Convert Torch -> JAX for inputs") + for ptr in pointer_tail: + if ptr.startswith("in_ptr"): + if is_tpu: + code.writeline( + f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.detach().contiguous())" + ) + + code.writeline("# Prepare output metadata from PyTorch tensor") + code.writeline( + "out_shapes = (" + + ", ".join([f"tuple({name}.shape)" for name in output_params]) + + ",)" + ) + code.writeline( + "out_dtypes = (" + + ", ".join( + [ + f"torch_dtype_to_jax_runtime({name}.dtype)" + for name in output_params + ] + ) + + ",)" + ) + arg_name_map: dict[str, str] = {} + for alias_name in alias_params: + arg_name_map[alias_name] = f"{alias_name}_jax" + for ptr in pointer_tail: + arg_name_map[ptr] = f"{ptr}_jax" + + if kernel_input_params: + alias_args_str = ", ".join( + arg_name_map[name] for name in kernel_input_params + ) + code.writeline( + f"res = {jit_wrapper_name}(out_shapes, out_dtypes, {alias_args_str})" + ) + else: + code.writeline(f"res = {jit_wrapper_name}(out_shapes, out_dtypes)") + if copy_output_indices: + code.writeline( + "result_values = res if isinstance(res, tuple) else (res,)" + ) + for idx in copy_output_indices: + name = output_params[idx] + if is_tpu: + code.writeline( + f"res_cpu = jax.device_get(result_values[{idx}])" + ) + code.writeline(f"{name}.copy_(torch.from_dlpack(res_cpu))") + else: + code.writeline( + f"{name}.copy_(torch.from_dlpack(result_values[{idx}]))" + ) + + return code.getvalue() + + def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: # type: ignore[override] + """Generate the Python code that calls this Pallas kernel.""" + wrapper = V.graph.wrapper_code + arg_defs, call_args, _, _ = self.args.python_argdefs() + kernel_param_names = [a.name for a in arg_defs] + pure_out_params = [p for p in kernel_param_names if p.startswith("out_ptr")] + call_arg_strs = list(map(str, call_args)) + aliasable = getattr(self, "aliasable_out_ptrs", {}) + alias_call_args = [ + call_arg_strs[kernel_param_names.index(p)] + for p in pure_out_params + if aliasable.get(p, False) + ] + + # Generate kernel call: kernel_name.run(arg1, arg2, ...) + # Note: async_compile.pallas loads {name}_main function and wraps it in PallasKernelWrapper + # which exposes a run() method + kernel_call = f"{name}.run({', '.join(alias_call_args + call_arg_strs)})" + wrapper.writeline(kernel_call) + + +class PallasScheduling(SIMDScheduling): + kernel_type = PallasKernel # type: ignore[assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + # Pallas/JAX can handle reductions to single elements efficiently + # without requiring split reductions + return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT]) + + def define_kernel( + self, + src_code: str, + node_schedule: Sequence[BaseSchedulerNode], + kernel: PallasKernel, + ) -> str: # type: ignore[override] + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + return wrapper.src_to_kernel[src_code] + + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"pallas_{kernel_hash}" + else: + kernel_name = f"pallas_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + + # Replace placeholder if any + src_code = src_code.replace("", kernel_name) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.pallas({kernel_name!r}, r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment) + + return kernel_name diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/python_wrapper_mtia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/python_wrapper_mtia.py new file mode 100644 index 0000000000000000000000000000000000000000..00833e1de702ca9922b41c53defc88c92fa6d350 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/python_wrapper_mtia.py @@ -0,0 +1,34 @@ +from typing import Optional +from typing_extensions import override + +from torch._inductor import ir + +from .wrapper import PythonWrapperCodegen + + +class PythonWrapperMtia(PythonWrapperCodegen): + """ + A thin wrapper of PythonWrapperCodegen with MTIA specific logic + """ + + @override + def write_header(self) -> None: + super().write_header() + + # MITA specific imports + self.imports.splice("import mtia.host_runtime.torch_mtia.dynamic_library") + + @override + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ) -> PythonWrapperCodegen: + if is_subgraph: + # Delegate to the parent class to handle the case of subgraph + return PythonWrapperCodegen.create( + is_subgraph, subgraph_name, parent_wrapper, partition_signatures + ) + return PythonWrapperMtia() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py new file mode 100644 index 0000000000000000000000000000000000000000..277b6ed3749486074a583d61f6f2909886eb60c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -0,0 +1,627 @@ +# mypy: allow-untyped-defs +import copy +import logging +import random +from typing import Any +from typing_extensions import override + +from torch._inductor.virtualized import V + +from .rocm_template import ArgInfo + + +try: + import ck4inductor # type: ignore[import] +except ImportError: + ck4inductor = None + +if ck4inductor is not None: + from ck4inductor.grouped_conv_fwd.gen_instances import ( # type: ignore[import] + gen_conv_ops_library, + ) + from ck4inductor.grouped_conv_fwd.op import ( # type: ignore[import] # noqa: TCH002 + CKGroupedConvFwdOp, + ) +else: + + def gen_conv_ops_library(): + return [] + + +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def torch_layout_to_ck_layouts(torch_layout): + # logically, torch tensors are always NCHW, + # and channels-last memory layout is visible in the strides + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + # when input or output is NCHW + # NB: torch.conv2d result is always NCHW + return ["NGCHW", "GKCYX", "NGKHW"] + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + # when input or output or weight is channels-last + return ["NHWGC", "GKYXC", "NHWGK"] + else: + return None + + +def torch_layout_to_ck_input_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGCHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGC" + else: + return None + + +def torch_layout_to_ck_weight_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "GKCYX" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "GKYXC" + else: + return None + + +def torch_layout_to_ck_output_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGKHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGK" + else: + return None + + +class CKGroupedConvFwdTemplate(CKTemplate): + conv_template = r""" + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto conv = {{instance_type}} {}; + auto invoker = conv.MakeInvoker(); + + using ck::index_t; + + constexpr index_t NumDTensor = {{n_d_tensors}}; + constexpr index_t NDimSpatial = {{n_dim_spatial}}; + const std::vector FilterSize = { FilterSize_0, FilterSize_1 }; + const std::vector InputSize = { InputSize_0, InputSize_1 }; + const std::vector ConvolutionStrides = { ConvolutionStrides_0, ConvolutionStrides_1 }; + const std::vector Dilations = { Dilations_0, Dilations_1 }; + const std::vector LeftPads = { LeftPads_0, LeftPads_1 }; + const std::vector RightPads = { RightPads_0, RightPads_1 }; + + + auto conv_param = ck::utils::conv::ConvParam { + NDimSpatial, + GroupCount, + NBatch, + NOutChannels, + NInChannels, + FilterSize, + InputSize, + ConvolutionStrides, + Dilations, + LeftPads, + RightPads, + }; + + using InLayout = ck::tensor_layout::convolution::{{input_layout}}; + using WeiLayout = ck::tensor_layout::convolution::{{weight_layout}}; + using OutLayout = ck::tensor_layout::convolution::{{output_layout}}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const void* p_a = input; + const void* p_b = weight; + const std::array p_ds; + void* p_e = output; + std::array a_g_n_c_wis_lengths; + std::array a_g_n_c_wis_strides; + std::array b_g_k_c_xs_lengths; + std::array b_g_k_c_xs_strides; + std::array, NumDTensor> ds_g_n_k_wos_lengths; + std::array, NumDTensor> ds_g_n_k_wos_strides; + std::array e_g_n_k_wos_lengths; + std::array e_g_n_k_wos_strides; + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + const auto a_element_op = PassThrough {}; + const auto b_element_op = PassThrough {}; + const auto cde_element_op = PassThrough {}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto argument = conv.MakeArgument( + p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op + ); + if (!conv.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for conv instance " << conv.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = conv.GetWorkSpaceSize(&argument); + return 0; + } + + if (p_a == nullptr) { + std::cerr << "p_a is nullptr" << std::endl; + return -1; + } + if (p_b == nullptr) { + std::cerr << "p_b is nullptr" << std::endl; + return -1; + } + if (p_e == nullptr) { + std::cerr << "p_e is nullptr" << std::endl; + return -1; + } + + // when debugging, do time kernel to serialize launches + auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + + if (workspace != nullptr) { + conv.SetWorkSpacePointer(&argument, workspace, stream_config); + } + + // run the kernel + float elapsed_time = invoker.Run(argument, stream_config); + return 0; + } // kernel definition + } // extern C + + #ifdef GENERATE_CK_STANDALONE_RUNNER + int main(int argc, char** argv) { + (void) argc; + (void) argv; + return 0; + } + #endif // GENERATE_CK_STANDALONE_RUNNER +""" + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK conv globals + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GKCX = ck::tensor_layout::convolution::GKCX; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + using NGKW = ck::tensor_layout::convolution::NGKW; + using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKDHW = ck::tensor_layout::convolution::NGKDHW; + + using NWGC = ck::tensor_layout::convolution::NWGC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + + using KXGC = ck::tensor_layout::convolution::KXGC; + using KYXGC = ck::tensor_layout::convolution::KYXGC; + using KZYXGC = ck::tensor_layout::convolution::KZYXGC; + + using NWGK = ck::tensor_layout::convolution::NWGK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using NGCW = ck::tensor_layout::convolution::NGCW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; + + using G_K = ck::tensor_layout::convolution::G_K; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + using ConvolutionForwardSpecialization = ck::tensor_operation::device::ConvolutionForwardSpecialization; + + using OutElementOp = PassThrough; + + namespace ck { + namespace utils { + namespace conv { + + ConvParam::ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + } // namespace conv + } // namespace utils + } // namespace ck + + const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + void HostTensorDescriptor::CalculateStrides() { + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); + } + """ + ) + return res + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK conv headers + + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + + #include "ck/library/utility/convolution_parameter.hpp" + #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + """ + ) + return res + + @staticmethod + def add_ck_conv_choices( + choices, + layout, + input_nodes, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + template = CKGroupedConvFwdTemplate( + input_nodes, + layout, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=n_spatial_dimensions, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op, + ) + + def __init__( + self, + input_nodes, + layout, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + super().__init__( + "ck_conv_template", + input_nodes, + layout, + ) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.n_spatial_dimensions = n_spatial_dimensions + + def filter_op(self, op: "CKGroupedConvFwdOp"): # type: ignore[name-defined] + metas = [ + T.get_layout() + for T in [*self.input_nodes, self.output_node] + if T is not None + ] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.e_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_input_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_weight_layout(W_meta): + return None + if op.e_layout != torch_layout_to_ck_output_layout(Y_meta): + return None + # disable the instance if number of spatial dimensions doesn't match + if op.n_dim_spatial != self.n_spatial_dimensions: + return None + # disable 1x1 and odd-channels conv specializations for now + if "Default" not in op.conv_forward_specialization: + return None + return op + + def gen_ops(self): + unfiltered_instances = gen_conv_ops_library() + + filtered_instances = list( + filter(lambda op: self.filter_op(op), unfiltered_instances) + ) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + def emit_ck_instance(self, op: "CKGroupedConvFwdOp") -> tuple[str, str]: # type: ignore[name-defined] + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + # pyrefly: ignore [bad-argument-type] + template_params.append(arg) + else: + if field_value is not None: + # pyrefly: ignore [bad-argument-type] + template_params.append(f"/* {field_name} */ {field_value}") + return self._template_from_string(template_definition).render( + operation_name=op.name(), + template_params=(",\n" + 12 * " ").join(template_params), + ), self._template_from_string(template_type).render(operation_name=op.name()) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGroupedConvFwdOp", # type: ignore[name-defined] + **kwargs, + ) -> str: + template_buffer_node = kwargs.get("template_buffer_node") + if template_buffer_node is not None: + self.output_node = template_buffer_node + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None + + op = copy.deepcopy(op) + + instance_definition, instance_type = self.emit_ck_instance(op) + + size_arg_strs = [ + "GroupCount", + "NBatch", + "NOutChannels", + "NInChannels", + "FilterSize_0", + "FilterSize_1", + "InputSize_0", + "InputSize_1", + "ConvolutionStrides_0", + "ConvolutionStrides_1", + "Dilations_0", + "Dilations_1", + "LeftPads_0", + "LeftPads_1", + "RightPads_0", + "RightPads_1", + ] + + return self._template_from_string(self.conv_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + instance_type=instance_type, + kernel_definition=kernel.def_kernel( + inputs=[X, W, Bias] if Bias is not None else [X, W], + outputs=[Y], + names_str="input, weight, bias, output" + if Bias is not None + else "input, weight, output", + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + n_d_tensors=1 if Bias is not None else 0, + n_dim_spatial=self.n_spatial_dimensions, + input_layout=op.a_layout, + weight_layout=op.b_layout, + output_layout=op.e_layout, + ) + + def size_args(self): + x, w = self.input_nodes[0], self.input_nodes[1] + y = self.output_node + + group_count = self.groups + n_batch = x.shape[0] # type: ignore[index] + n_out_channels = y.shape[1] # type: ignore[index] + n_in_channels = x.shape[1] # type: ignore[index] + + filter_size_0, filter_size_1 = w.shape[2:4] # type: ignore[index] + input_size_0, input_size_1 = x.shape[2:4] # type: ignore[index] + convolution_strides_0, convolution_strides_1 = self.stride + dilations_0, dilations_1 = self.dilation + left_pads_0, left_pads_1 = self.padding + right_pads_0, right_pads_1 = self.padding + + return ( + group_count, + n_batch, + n_out_channels, + n_in_channels, + filter_size_0, + filter_size_1, + input_size_0, + input_size_1, + convolution_strides_0, + convolution_strides_1, + dilations_0, + dilations_1, + left_pads_0, + left_pads_1, + right_pads_0, + right_pads_1, + ) + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + @override + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py new file mode 100644 index 0000000000000000000000000000000000000000..b1eaf5c228eed80b5b9e40e3bbbd4e2de07b7c45 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py @@ -0,0 +1,110 @@ +from typing import Any +from typing_extensions import override + +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + +from .rocm_template import ArgInfo + + +class CKTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", # gfx94 + torch.float8_e4m3fn: "F8", # gfx95 + torch.float8_e5m2fnuz: "BF8", # gfx94 + torch.float8_e5m2: "BF8", # gfx95 + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + + #ifdef DEBUG_LOG + #define DEBUG_LOG_TMP DEBUG_LOG + #undef DEBUG_LOG + #else + #define DEBUG_LOG_TMP 0 + #endif + #include "ck/ck.hpp" + #undef DEBUG_LOG + #define DEBUG_LOG DEBUG_LOG_TMP + + #include "ck/utility/data_type.hpp" + #include "ck/library/utility/check_err.hpp" + #include "ck/library/utility/device_memory.hpp" + #include "ck/library/utility/fill.hpp" + #include "ck/library/utility/host_tensor.hpp" + #include "ck/library/utility/host_tensor_generator.hpp" + #include "ck/library/utility/literals.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK globals + + template + using S = ck::Sequence; + + template + using Tuple = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + using Scale = ck::tensor_operation::element_wise::Scale; + using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + + // see "composable_kernel/include/ck/utility/data_type.hpp" + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using F16 = ck::half_t; + using F32 = float; + // using F64 = double; + using BF16 = ck::bhalf_t; + // using I32 = int32_t; + // using I8 = int8_t; + // using I4 = ck::int4_t; + + #if DEBUG_LOG + static constexpr auto kDEBUG_LOG = 1; + #else + static constexpr auto kDEBUG_LOG = 0; + #endif + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + @override + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py new file mode 100644 index 0000000000000000000000000000000000000000..70d31d635cc36dca295b1d82066376a1185c4da9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py @@ -0,0 +1,58 @@ +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + + +class CKTileTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", # gfx94 + torch.float8_e4m3fn: "F8", # gfx95 + torch.float8_e5m2fnuz: "BF8", # gfx94 + torch.float8_e5m2: "BF8", # gfx95 + } + + ck_dtype_to_size = { + "FP16": 2, + "BF16": 2, + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + #include "ck_tile/core.hpp" + + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using F8 = ck_tile::fp8_t; + using BF8 = ck_tile::bf8_t; + using F16 = ck_tile::half_t; + using F32 = float; + using BF16 = ck_tile::bfloat16_t; + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..94a79297ef5e47e16f98a8968c815262a9d24d75 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -0,0 +1,979 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import functools +import logging +import random +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_tile_template import CKTileTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.codegen.rocm.rocm_template import ArgInfo +from torch._inductor.ir import Buffer, Layout +from torch.utils._ordered_set import OrderedSet + +from ...utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def is_static_int(number): + import sympy + + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +@dataclass +class CKTileGemmOperation: + layout_a: str + layout_b: str + layout_c: str + + datatype_a: str + datatype_b: str + datatype_c: str + + tile_m: int + tile_n: int + tile_k: int + + warp_m: int + warp_n: int + warp_k: int + + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + m_is_padded: str + n_is_padded: str + k_is_padded: str + + pipeline: str + scheduler: str + epilogue: str + + def layout_repr(self): + return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}" + + def dtype_repr(self): + return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}" + + def tile_sizes(self): + return "_".join( + [ + f"{self.tile_m}{self.tile_n}{self.tile_k}", + f"{self.warp_m}{self.warp_n}{self.warp_k}", + f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}", + ] + ) + + def name(self): + return "ck_tile_gemm_universal_" + "_".join( + [ + f"{self.layout_repr()}", + f"{self.dtype_repr()}", + f"{self.tile_sizes()}", + f"{self.pipeline}", + f"{self.scheduler}", + f"{self.epilogue}", + ] + ) + + def dict_items(self): + return asdict(self).items() + + +@functools.cache +def ops(): + """ + Generate the supported instance dataclasses + """ + import itertools + + compute_v3_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV3", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + compute_v4_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV4", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [ + (256, 256, 32) + ] # half the tile size since it has double buffering + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + mem_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="Mem", + scheduler=scheduler, + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for scheduler in ["Intrawave", "Interwave"] + for epilogue in ["Default", "CShuffle"] + ] + + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) + + +class CKTileGemmTemplate(CKTileTemplate): + """ + This class is used for rendering CK-Tile Universal GEMM kernels + """ + + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + + using {{instance_namespace}}::BaseGemmPipeline; + using {{instance_namespace}}::TilePartitioner; + + constexpr auto TileK = {{instance_namespace}}::TileK; + constexpr auto kPrefetchStages = BaseGemmPipeline::PrefetchStages; + + const auto BiasTerms = std::array (); + const auto BiasStrides = std::array (); + + auto kargs = ck_tile::UniversalGemmKernelArgs<> { + {X}, + {W}, + BiasTerms, + Y, + M, + N, + K, + {LDA}, + {LDB}, + BiasStrides, + LDC, + kBatch + }; + + if (workspace_size) { + *workspace_size = 0; + return 0; + } + + // run the kernel + const auto dispatch = [&](const auto has_hot_loop_, const auto tail_number_) constexpr { + using Kernel = {{instance_namespace}}::Kernel; + + if (!Kernel::IsSupportedArgument(kargs)) { + // we do our best to statically avoid this case in `filter_op` + throw std::runtime_error("invalid argument"); + } + auto stream_config = ck_tile::stream_config{stream}; + auto grid_size = Kernel::GridSize(M, N, kBatch); + constexpr auto block_size = Kernel::BlockSize(); + constexpr auto lds_bytes = 0; + constexpr auto kBlockPerCU = 1; + auto gemm = ck_tile::make_kernel(Kernel{}, grid_size, block_size, lds_bytes, kargs); + float elapsed_time = ck_tile::launch_kernel(stream_config, gemm); + }; + + const ck_tile::index_t k_grain = kBatch * TileK; + const ck_tile::index_t K_split = (K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + {{rendered_dispatch}} + + return 0; + } // kernel definition + } // extern C + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + ) -> None: + super().__init__( + "ck_tile_gemm_template", + input_nodes=input_nodes, + layout=layout, + ) + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK GEMM header(s) + + #include "ck_tile/ops/gemm.hpp" + #include "ck_tile/ops/epilogue.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + template + void dispatch_memory_pipeline_hot_loop(const ck_tile::TailNumber tail_num, Dispatcher dispatch) + { + if(tail_num == ck_tile::TailNumber::One) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + """ + ) + return res + + def check_dtypes(self, op: "CKTileGemmOperation"): + X_dtype, W_dtype, out_dtype = [ + T.get_layout().dtype for T in [*self.input_nodes, self.output_node] + ] + if op.datatype_a != self._TORCH_DTYPE_TO_CK[X_dtype]: + return False + if op.datatype_b != self._TORCH_DTYPE_TO_CK[W_dtype]: + return False + if op.datatype_c != self._TORCH_DTYPE_TO_CK[out_dtype]: + return False + return True + + def check_layouts(self, op: "CKTileGemmOperation"): + X_layout, W_layout, out_layout = [ + torch_layout_to_ck_layout(T.get_layout()) + for T in [*self.input_nodes, self.output_node] + ] + if op.layout_a != X_layout: + return False + if op.layout_b != W_layout: + return False + if op.layout_c != out_layout: + return False + return True + + def get_gemm_problem_size(self): + X_size, W_size = [T.get_layout().size for T in [*self.input_nodes]] + + M, K = X_size + _, N = W_size + + return M, N, K + + def check_block_tiles(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the block tile size + This helper function enforces it for the inputs and the output. + """ + M, N, K = self.get_gemm_problem_size() + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + if op.layout_a == "Row": + # handle in kBatch check + return True + elif op.layout_a == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_b == "Col": + # handle in kBatch check + return True + else: + raise AssertionError(f"Invalid {op.layout_b=}") + + if op.layout_c == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_c == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_c=}") + + return True + + def check_alignments(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the vector load size. + """ + M, N, K = self.get_gemm_problem_size() + + def max_alignment(contiguous_elements_per_tile, elements_per_thread, ck_dtype): + for vector_load_bytes in (16, 8, 4, 2, 1): + alignment = vector_load_bytes // self.ck_dtype_to_size[ck_dtype] + if ( + alignment > 0 + and contiguous_elements_per_tile % alignment == 0 + and elements_per_thread % alignment == 0 + ): + return alignment + + threads_per_block = ( + op.warp_m * op.warp_n * op.warp_k * self.gfx9_threads_per_warp + ) + a_elements_per_thread = op.tile_m * op.tile_k / threads_per_block + b_elements_per_thread = op.tile_n * op.tile_k / threads_per_block + + if op.layout_a == "Row": + # K is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_k, a_elements_per_thread, op.datatype_a + ) + if is_static_int(K) and K % a_max_vector_size != 0: + return False + elif op.layout_a == "Col": + # M is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_m, a_elements_per_thread, op.datatype_a + ) + if is_static_int(M) and M % a_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + # N is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_n, b_elements_per_thread, op.datatype_b + ) + if is_static_int(N) and N % b_max_vector_size != 0: + return False + elif op.layout_b == "Col": + # K is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_k, b_elements_per_thread, op.datatype_b + ) + if is_static_int(K) and K % b_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_b=}") + + # the `default` epilogue writes C to memory by 1 tensor element + # (divisibility check not necessary) + # the `cshuffle` epilogue writes C to memory by 16 bytes + # (so the contiguous C dimension size must be divisible by the number of tensor elements in 16 bytes) + if op.epilogue == "CShuffle": + if ( + op.layout_c == "Row" + and is_static_int(N) + and N % (16 / self.ck_dtype_to_size[op.datatype_c]) != 0 + ): + return False + + return True + + def check_warp_tiles(self, op: "CKTileGemmOperation"): + if op.tile_m % (op.warp_m * op.warp_tile_m) != 0: + return False + if op.tile_n % (op.warp_n * op.warp_tile_n) != 0: + return False + if op.tile_k % (op.warp_k * op.warp_tile_k) != 0: + return False + return True + + def check_block_tile_size(self, op: "CKTileGemmOperation"): + # assuming LDS size is 64KB + if op.pipeline == "CompV4": + max_block_tile_size = 2**15 + else: + max_block_tile_size = 2**16 + + block_tile_size = ( + self.ck_dtype_to_size[op.datatype_a] * op.tile_m * op.tile_k + + self.ck_dtype_to_size[op.datatype_b] * op.tile_n * op.tile_k + ) + if block_tile_size > max_block_tile_size: + return False + return True + + def filter_op(self, op: "CKTileGemmOperation"): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. + """ + if not self.check_dtypes(op): + return None + if not self.check_layouts(op): + return None + if not self.check_block_tiles(op): + return None + if not self.check_alignments(op): + return None + + return op + + def emit_ck_instance(self, op: "CKTileGemmOperation"): + """ + This method is used to generate code which defines the type alias for the generated kernel class + """ + template_definition = r""" + // Gemm operator {{operation_name}} + + namespace {{operation_name}} { + // block tile + constexpr int32_t TileM = {{tile_m}}; + constexpr int32_t TileN = {{tile_n}}; + constexpr int32_t TileK = {{tile_k}}; + // warps per block + constexpr int32_t WarpM = {{warp_m}}; + constexpr int32_t WarpN = {{warp_n}}; + constexpr int32_t WarpK = {{warp_k}}; + // xdl tile + constexpr int32_t WarpTileM = {{warp_tile_m}}; + constexpr int32_t WarpTileN = {{warp_tile_n}}; + constexpr int32_t WarpTileK = {{warp_tile_k}}; + + constexpr bool kPadM = {{m_is_padded}}; + constexpr bool kPadN = {{n_is_padded}}; + constexpr bool kPadK = {{k_is_padded}}; + + using ALayout = {{layout_a}}; + using BLayout = {{layout_b}}; + using CLayout = {{layout_c}}; + + using ADataType = {{datatype_a}}; + using BDataType = {{datatype_b}}; + using CDataType = {{datatype_c}}; + using AccDataType = F32; + + constexpr bool permuteA = false; + constexpr bool permuteB = false; + constexpr bool DoubleSmemBuffer = {{has_double_smem_buffer}}; + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + constexpr ck_tile::index_t TilePartitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + {{rendered_scheduler}} + + template + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + {{rendered_pipeline}} + + {{rendered_epilogue}} + + template + using Kernel = ck_tile::GemmKernel, GemmEpilogue>; + } + +""" + + def render_epilogue(epilogue_type): + if epilogue_type == "Default": + return r""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem; + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; + """ + elif epilogue_type == "CShuffle": + return r""" + constexpr auto kMemoryOperation = ck_tile::memory_operation_enum::set; + using DsDataType = ck_tile::tuple<>; // no bias terms for vanilla GEMM + using DsLayout = ck_tile::tuple<>; + constexpr auto ELayout = CLayout; + using CDEElementWise = ck_tile::element_wise::PassThrough; // no-op + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue; + """ + else: + raise AssertionError("Epilogue must be set") + + def render_pipeline(pipeline_type): + return rf""" + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCr{pipeline_type}; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCr{pipeline_type}>; + """ + + def render_scheduler(scheduler_type): + return rf""" + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::{scheduler_type}; + """ + + rendered_definition = self._template_from_string(template_definition).render( + operation_name=op.name(), + **asdict(op), + rendered_scheduler=render_scheduler(op.scheduler), + rendered_pipeline=render_pipeline(op.pipeline), + rendered_epilogue=render_epilogue(op.epilogue), + has_double_smem_buffer=("true" if op.pipeline == "CompV4" else "false"), + ) + return rendered_definition + + def render( # type: ignore[override] + self, kernel: ROCmTemplateKernel, op: "CKTileGemmOperation", **kwargs + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes") + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node") + if template_buffer_node is not None: + self.output_node = template_buffer_node + assert 2 == len(self.input_nodes) + X, W = self.input_nodes + Y = self.output_node + + instance_definition = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + + def render_dispatch(pipeline_type, op_name): + switch_tailnum_template = r""" + switch (tail_num) { + {% for tail_num in valid_tailnums %} + case ck_tile::TailNumber::{{tail_num}}: + dispatch({{has_hot_loop}}, + ck_tile::integral_constant{}); + break; + {% endfor %} + default: + std::ostringstream err; + err << "Unsupported dispatch: " + << "Pipeline: " << "{{pipeline}}" + << "Prefetch stages: " << kPrefetchStages + << "Tail num: " << tail_num; + throw std::runtime_error(err.str()); + } // switch tail_num + """ + dispatch_template = r""" + if (has_hot_loop) { + {{rendered_with_hot_loop}} + } + else { // has_hot_loop == false + {{rendered_without_hot_loop}} + } // if has_hot_loop + """ + if pipeline_type == "CompV3": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "Mem": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop="dispatch_memory_pipeline_hot_loop(tail_num, dispatch);", + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "CompV4": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Two", "Three"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + else: + raise AssertionError(f"Pipeline {pipeline_type} is not supported") + + return self._template_from_string(self.gemm_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, Y", + size_args=[ + f"int32_t {arg}" for arg in ["M", "N", "K", "LDA", "LDB", "LDC"] + ], + ), + instance_namespace=op.name(), + version_comment=version_comment, + rendered_dispatch=render_dispatch(op.pipeline, op.name()), + ) + + def gen_ops(self): + """ + Creates a list of `CKTileGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + instances = ops() + if not instances: + raise AssertionError( + "No Composable Kernel Universal GEMM instances found. " + "Please check if the library is installed." + ) + filtered_instances = list(filter(self.filter_op, instances)) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_tile_max_profiling_configs), + ) + if config.rocm.ck_tile_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after sample: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_choices( + choices, + layout, + input_nodes, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKTileGemmTemplate( + input_nodes, + layout, + ) + ops = template.gen_ops() + for op in ops: + for k_batch in template.k_batch_choices(op): + template.maybe_append_choice( + choices, + op=op, + kBatch=k_batch, + ) + + def k_batch_choices(self, op: "CKTileGemmOperation") -> tuple[int, ...]: + """ + Returns a list of k_batch choices for the template. + """ + default_choices = (1, 2, 4, 8, 16, 32) + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + _, _, K, _, _, _ = self.size_args() + if op.layout_a == "Row" or op.layout_b == "Col": + choices = tuple( + filter( + lambda k_batch: check(K, op.tile_k * k_batch, op.k_is_padded), + default_choices, + ) + ) + else: + choices = default_choices + + if op.epilogue == "Default": + choices = (1,) + + return choices + + def size_args(self): + """ + Sizes and strides to be used for the kernel call + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + Y = self.output_node + + M = X.get_size()[0] + K = X.get_size()[1] + N = W.get_size()[1] + LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1] + LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1] + LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1] + + return M, N, K, LDA, LDB, LDC + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + # maybe_append_choice kwarg for k_batch must match the name of the argument + arg_names = OrderedSet([arg.name for arg in self.get_runtime_arg_info()]) + if not arg_names.issubset(kwargs): + raise ValueError( + "Missing runtime arguments: " + ", ".join(arg_names - kwargs.keys()) + ) + return [kwargs[k] for k in arg_names] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f8ff54f9f45bc46fb3d4be5b74d36990fc69cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -0,0 +1,1019 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import copy +import logging +import math +import random +from collections import namedtuple +from typing import Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.cpp_utils import DTYPE_TO_CPP +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.compile_command import rocm_compile_command +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.ir import Buffer, Layout +from torch._inductor.runtime.runtime_utils import next_power_of_2 + +from ...utils import IndentedBuffer, is_dynamic, try_import_ck_lib + + +_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib() + + +log = logging.getLogger(__name__) + +# lightweight collection of information about a single op +InductorROCmOp = namedtuple("InductorROCmOp", ["op", "kBatch"]) + +padding_lookup = { + "M": { + "GemmSpecialization::MPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "N": { + "GemmSpecialization::NPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "K": { + "GemmSpecialization::KPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, +} + + +def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +class CKGemmTemplate(CKTemplate): + # the JINJA template for rendering CK Universal GEMMs + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto gemm = {{instance_type}} {}; + auto invoker = gemm.MakeInvoker(); + {% if is_batched %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + B, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + M * K, // batch_stride_A + N * K, // batch_stride_B + std::array{ {{ds_batch_strides}} }, + M * N, // batch_stride_C + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% else %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + kBatch, // kBatch + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% endif %} + if (!gemm.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = gemm.GetWorkSpaceSize(&argument); + return 0; + } + // run the kernel + #ifdef GENERATE_CK_STANDALONE_RUNNER + const auto stream_config = StreamConfig{ + stream, + /* time kernel */ 1, + /* log level */ 1, + /* n_cold_iter */ 100, + /* n_hot_iter */ 100, + /* flush_l2_cache */ 1, + /* rotate_count */ 5}; + #else + const auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + #endif + + const float elapsed_time = invoker.Run(argument, stream_config); + + #ifdef GENERATE_CK_STANDALONE_RUNNER + std::cout << "elapsed time: " << elapsed_time << " ms" << std::endl; + #else + (void)elapsed_time; + #endif + return 0; + } // kernel definition + } // extern C + """ + + standalone_runner_template = r""" + #ifdef GENERATE_CK_STANDALONE_RUNNER + // standalone runner for the generated CK GEMM kernel + + {{inline_utils}} + + extern "C" { + int run_main(int argc, char** argv) { + {% if is_batched %} + const int32_t B = {{B}}; + {% endif %} + const int32_t M = {{M}}; + const int32_t N = {{N}}; + const int32_t K = {{K}}; + const int32_t LDA = {{LDA}}; + const int32_t LDB = {{LDB}}; + const int32_t LDC = {{LDC}}; + const int32_t LDD = {{LDD}}; + const int32_t kBatch = {{kBatch}}; + + using AElementType = {{a_ck_dtype}}; + using BElementType = {{b_ck_dtype}}; + using CElementType = {{c_ck_dtype}}; + {% if has_bias %} + using BiasElementType = {{bias_ck_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAElementType = {{scale_a_ck_dtype}}; + using ScaleBElementType = {{scale_b_ck_dtype}}; + {% endif %} + + using AArgType = {{a_torch_dtype}}; + using BArgType = {{b_torch_dtype}}; + using CArgType = {{c_torch_dtype}}; + {% if has_bias %} + using BiasArgType = {{bias_torch_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAArgType = {{scale_a_torch_dtype}}; + using ScaleBArgType = {{scale_b_torch_dtype}}; + {% endif %} + + using ALayout = {{a_layout}}; + using BLayout = {{b_layout}}; + using CLayout = {{c_layout}}; + {% if has_bias %} + using BiasLayout = {{bias_layout}}; + {% endif %} + + {% if is_batched %} + using strides_t = std::array; + auto get_strides = [](int32_t batch_stride, int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {batch_stride, leading_dimension, 1}; + } + return {batch_stride, 1, leading_dimension}; + }; + auto a_size = strides_t{B, M, K}; + auto a_stride = get_strides(M * K, LDA, ALayout{}); + auto b_size = strides_t{B, N, K}; + auto b_stride = get_strides(N * K, LDB, BLayout{}); + auto c_size = strides_t{B, M, N}; + auto c_stride = get_strides(M * N, LDC, CLayout{}); + {% else %} + using strides_t = std::array; + auto get_strides = [](int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {leading_dimension, 1}; + } + return {1, leading_dimension}; + }; + auto a_size = strides_t{M, K}; + auto a_stride = get_strides(LDA, ALayout{}); + auto b_size = strides_t{N, K}; + auto b_stride = get_strides(LDB, BLayout{}); + auto c_size = strides_t{M, N}; + auto c_stride = get_strides(LDC, CLayout{}); + {% endif %} + + Tensor a_m_k ( HostTensorDescriptor ( a_size, a_stride ) ); + Tensor b_k_n ( HostTensorDescriptor ( b_size, b_stride ) ); + {% if has_bias %} + Tensor d_m_n ( HostTensorDescriptor ( c_size, get_strides(LDD, BiasLayout{}) ) ); + {% endif %} + {% if has_scale %} + // NB: these are hardcoded + Tensor s_a_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Row{}) )); + Tensor s_b_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Col{}) )); + {% endif %} + + Tensor c_m_n_host ( HostTensorDescriptor ( c_size, c_stride ) ); + Tensor c_m_n_device ( HostTensorDescriptor ( c_size, c_stride ) ); + + a_m_k.GenerateTensorValue(GeneratorTensor_2()); + b_k_n.GenerateTensorValue(GeneratorTensor_2()); + {% if has_bias %} + d_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + {% if has_scale %} + s_a_m_n.GenerateTensorValue(GeneratorTensor_2()); + s_b_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + DeviceMem a_m_k_device_buf(sizeof(AElementType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BElementType) * b_k_n.mDesc.GetElementSpaceSize()); + {% if has_bias %} + DeviceMem d_m_n_device_buf(sizeof(BiasElementType) * d_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + {% if has_scale %} + DeviceMem s_a_m_n_device_buf(sizeof(ScaleAElementType) * s_a_m_n.mDesc.GetElementSpaceSize()); + DeviceMem s_b_m_n_device_buf(sizeof(ScaleBElementType) * s_b_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + DeviceMem c_m_n_device_buf(sizeof(CElementType) * c_m_n_device.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + {% if has_bias %} + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + {% endif %} + {% if has_scale %} + s_a_m_n_device_buf.ToDevice(s_a_m_n.mData.data()); + s_b_m_n_device_buf.ToDevice(s_b_m_n.mData.data()); + {% endif %} + + {{kernel_name}}( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + {% if has_scale %} + static_cast(s_a_m_n_device_buf.GetDeviceBuffer()), + static_cast(s_b_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + {% if has_bias %} + static_cast(d_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + {% if is_batched %} + B, + {% endif %} + M, + N, + K, + LDA, + LDB, + LDC, + LDD, + nullptr, // workspace_size + nullptr, // workspace + nullptr); // stream + + hip_check_error(hipDeviceSynchronize()); + + return 0; + } // run_main + } // extern C + + int main(int argc, char** argv) { + return run_main(argc, argv); + } + // compile with: {{compile_cmd}} + #endif // GENERATE_CK_STANDALONE_RUNNER + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ) -> None: + is_batched = len(layout.size) == 3 + name = "ck_batched_gemm_template" if is_batched else "ck_gemm_template" + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + input_reorder=input_reorder, + ) + self.alpha = alpha + self.beta = beta + self.is_batched = is_batched + + def header(self) -> IndentedBuffer: + res = super().header() + if self.is_batched: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + else: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + struct MultiplyMultiplyAdd { + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const { + e = ck::type_convert( + ck::type_convert(c) + * ck::type_convert(d0) + * ck::type_convert(d1) + + ck::type_convert(d2) + ); + } + }; + """ + ) + return res + + def inline_utils(self): + res = IndentedBuffer() + res.splice( + """ + #include "host_tensor.cpp" + #include "device_memory.cpp" + """ + ) + return res + + def _has_padding(self, dimension, gemm_specialization): + # Get the relevant padding map for the given dimension + dimension_padding = padding_lookup.get(dimension, {}) + + # Check if the specialization is in the dimension's padding map + return dimension_padding.get(gemm_specialization, False) + + def filter_op(self, op_info: InductorROCmOp): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. + """ + op, kBatch = op_info.op, op_info.kBatch + metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_layout(W_meta): + return None + if op.c_layout != torch_layout_to_ck_layout(Y_meta): + return None + # try to avoid launching the instance with invalid problem size + # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity + + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + + if is_static_int(M): + if not self._has_padding("M", op.gemm_specialization): + if M % op.m_per_block != 0: + return None + if is_static_int(N): + if not self._has_padding("N", op.gemm_specialization): + if N % op.n_per_block != 0: + return None + if is_static_int(K): + if not self._has_padding("K", op.gemm_specialization): + if K % op.k_per_block != 0: + return None + K_t = kBatch * op.k_per_block + if K % K_t != 0: + return None + else: + # need another kBatch check here + lcm = abs(op.a_k1 * op.b_k1) // math.gcd(op.a_k1, op.b_k1) + K_t = kBatch * lcm + k_read_pad_splited = math.ceil(K / K_t) * lcm + if (k_read_pad_splited * (kBatch - 1)) >= K: + return None + + a_contig_size = ( + K if op.a_layout == "Row" else M if op.a_layout == "Col" else None + ) + if ( + is_static_int(a_contig_size) + and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0 + ): + return None + b_contig_size = ( + N if op.b_layout == "Row" else K if op.b_layout == "Col" else None + ) + if ( + is_static_int(b_contig_size) + and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0 + ): + return None + c_contig_size = ( + N if op.c_layout == "Row" else M if op.c_layout == "Col" else None + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block[0] + if isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ) + else op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + if ( + is_static_int(c_contig_size) + and c_contig_size % c_shuffle_block_transfer_scalar_per_vector_n_per_block + != 0 + ): + return None + if not self._check_num_k_loops(op, kBatch): + return None + # TBD disable instances with invalid number of pipeline prefetch stages + # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check + + return op + + def _check_num_k_loops(self, op, kBatch): + # Additional splitK scenario check + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + K = X_meta.size[-1] + if kBatch > 1: + if op.block_gemm_pipeline_version != "BlockGemmPipelineVersion::v1": + try: + prefetch_stages = self._prefetch_stages( + op, + torch.empty((), dtype=X_meta.dtype).element_size(), + torch.empty((), dtype=W_meta.dtype).element_size(), + torch.cuda.get_device_properties(X_meta.device).warp_size, + ) + except Exception as e: + log.debug( # noqa: G200 + "Failed to prefetch_stages for %s with exception %s", op.name, e + ) + # be conservative here and disable the op + return False + + K_t = op.k_per_block * kBatch + ak0 = (K + K_t - 1) // K_t * (op.k_per_block // op.a_k1) + num_k_loop = ak0 // (op.k_per_block // op.a_k1) + if num_k_loop <= prefetch_stages: + log.debug( + "Op %s is not compatible due to invalid number of pipeline prefetch stages. " + "Parameters: kBatch=%s, block_gemm_pipeline_version=%s, prefetch_stages=%s, num_k_loop=%s", + op.name(), + kBatch, + op.block_gemm_pipeline_version, + prefetch_stages, + num_k_loop, + ) + return False + + return True + + # small helper to figure out the prefetch stages on AMD + def _prefetch_stages(self, op, a_dtype_size, b_dtype_size, warp_size: int = 64): + version_str = op.block_gemm_pipeline_version.split("::")[-1] + try: + version = int(version_str[1:]) # Assuming the format is always 'vX' + except ValueError as e: + raise ValueError(f"Invalid version string: {version_str}") from e + if version not in [1, 2, 3, 4, 5]: + raise ValueError( + f"unknown prefetch stages for {op.block_gemm_pipeline_version}" + ) + # Define the mapping of versions to stages + version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3} + # Get the stages for the given version + stages = version_to_stages.get(version) + if stages is None: + # This means we're at stage 2, and this requires computation + # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950 + wgp_per_cu = max(4 * warp_size // op.block_size, 1) + full_mem_band_prefetch_stages = math.ceil( + 32768 + / wgp_per_cu + / ( + (op.m_per_block * a_dtype_size + op.n_per_block * b_dtype_size) + * op.k_per_block + ) + ) + stages = min(max(full_mem_band_prefetch_stages, 2), 8) + + return stages + + def emit_ck_instance(self, op: "CKGemmOperation"): + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + struct_name = ( + "DeviceBatchedGemmMultiD_Xdl_CShuffle_V3" + if self.is_batched + else "DeviceGemmMultiD_Xdl_CShuffle_V3" + ) + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::{{struct_name}}< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + # pyrefly: ignore [bad-argument-type] + template_params.append(arg) + else: + if field_value is not None: + # pyrefly: ignore [bad-argument-type] + template_params.append(f"/* {field_name} */ {field_value}") + operation_name = op.name().replace("(", "").replace(",", "").replace(")", "") + return self._template_from_string(template_definition).render( + operation_name=operation_name, + template_params=(",\n" + 12 * " ").join(template_params), + struct_name=struct_name, + ), self._template_from_string(template_type).render( + operation_name=operation_name + ) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGemmOperation", + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes") + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node") + if template_buffer_node is not None: + self.output_node = template_buffer_node + # input nodes: + # * X, W for matmul + # * X, W, Bias for addmm + # * X, W, inv_scale_x, inv_scale_w for scaled_mm + # * X, W, inv_scale_x, inv_scale_w, Bias for scaled_mm with bias + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = ( + self.input_nodes[2] + if 3 == len(self.input_nodes) + else self.input_nodes[4] + if 5 == len(self.input_nodes) + else None + ) + has_bias = Bias is not None + has_scale = len(self.input_nodes) in (4, 5) + op = copy.deepcopy(op) + + # This parameter is converted into tuple because of change + # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3. + # The first tuple element corresponds to matmul result... + if not isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ): + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, + ) + + if has_scale: + scale_x = self.input_nodes[2] + scale_w = self.input_nodes[3] + if 1 == scale_x.get_numel() and 1 == scale_w.get_numel(): + # tensorwise scale for both X, W + if has_bias: + op.c_elementwise_op = "ScaleAdd" + else: + op.c_elementwise_op = "Scale" + else: + # rowwise scale for both X, W + if has_bias: + op.c_elementwise_op = "MultiplyMultiplyAdd" + else: + op.c_elementwise_op = "MultiplyMultiply" + op.c_shuffle_dtype = "F32" + op.ds_layouts = ( + torch_layout_to_ck_layout(scale_x.get_layout()), + torch_layout_to_ck_layout(scale_w.get_layout()), + ) + op.ds_element_dtypes = ( + self._TORCH_DTYPE_TO_CK[scale_x.get_layout().dtype], + self._TORCH_DTYPE_TO_CK[scale_w.get_layout().dtype], + ) + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += (1, 1) + else: + scale_x = None + scale_w = None + + bias_dtype = "" + if Bias is not None: + bias_layout = torch_layout_to_ck_layout(Bias.get_layout()) + bias_dtype = self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype] + op.ds_layouts += (bias_layout,) + op.ds_element_dtypes += (bias_dtype,) + if not has_scale: + op.c_elementwise_op = "Bilinear" + # c_shuffle_dtype is also used for adding bias to matmul result + # before converting down to the result dtype + op.c_shuffle_dtype = op.acc_dtype + # this parameter needs to be set accordingly to bias stride for correct accumulation + if bias_layout == "Row": + # bias has (N, ) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + elif bias_layout == "Col": + # bias has (M, 1) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,) + else: + raise AssertionError( + "Bias layout is neither row-major nor column-major" + ) + # ...and the second tuple element corresponds to the bias + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += ( + bias_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + + instance_definition, instance_type = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + epilogue = None + + if op.c_elementwise_op == "Bilinear" and scale_w is None: + epilogue = f"Bilinear {{ {self.alpha}, {self.beta} }}" + + elif op.c_elementwise_op == "Scale": + epilogue = "Scale { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "ScaleAdd": + epilogue = "ScaleAdd { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "MultiplyMultiply": + epilogue = "MultiplyMultiply {}" + + elif op.c_elementwise_op == "MultiplyMultiplyAdd": + epilogue = "MultiplyMultiplyAdd {}" + + elif op.c_elementwise_op == "PassThrough": + epilogue = "PassThrough {}" + + assert epilogue is not None, "CK GEMM epilogue is not set" + + size_arg_strs = ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"] + if self.is_batched: + size_arg_strs.insert(0, "B") + + res = self._template_from_string(self.gemm_template).render( + inline_utils=self.inline_utils(), + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W, scale_x, scale_w, Bias], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, inv_scale_x, inv_scale_w, Bias, Y", + input_reorder=self.input_reorder, + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + instance_type=instance_type, + a_element_dtype=op.a_element_dtype, + b_element_dtype=op.b_element_dtype, + c_element_dtype=op.c_element_dtype, + bias_element_dtype=bias_dtype, + alpha=self.alpha, + beta=self.beta, + a_elementwise_op="PassThrough {}", + b_elementwise_op="PassThrough {}", + epilogue=epilogue, + has_bias=has_bias, + ds_size=1 + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else 2 + if op.c_elementwise_op == "MultiplyMultiply" + else 3 + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else 0, + ds_names=", ".join( + ["Bias"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["inv_scale_x", "inv_scale_w"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["inv_scale_x", "inv_scale_w", "Bias"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + ds_strides=", ".join( + ["LDD"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["0", "0"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["0", "0", "LDD"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + version_comment=version_comment, + is_batched=self.is_batched, + ds_batch_strides=", ".join([]), # FIXME when supporting baddbmm + ) + + if config.rocm.generate_test_runner: + is_static_problem = all(is_static_int(arg) for arg in self.size_args()) + # NOTE: size_arg_strs is defined above + size_arg_vals = ( + self.size_args() + if is_static_problem + else ( + f"std::stoi(argv[{k}])" for k, _ in enumerate(self.size_args(), 1) + ) + ) + size_args = dict(zip(size_arg_strs, size_arg_vals, strict=True)) + runtime_args = dict( + zip( + [a.name for a in self.get_runtime_arg_info()], + self.get_runtime_arg_values(), + ) + ) + runner_code = self._template_from_string( + self.standalone_runner_template + ).render( + inline_utils=self.inline_utils().getvalue(), + kernel_name=kernel.kernel_name, + has_bias=has_bias, + has_scale=has_scale, + is_batched=self.is_batched, + a_ck_dtype=op.a_element_dtype, + b_ck_dtype=op.b_element_dtype, + c_ck_dtype=op.c_element_dtype, + bias_ck_dtype=op.ds_element_dtypes[0] if has_bias else "", + scale_a_ck_dtype=op.ds_element_dtypes[0] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + scale_b_ck_dtype=op.ds_element_dtypes[1] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + a_torch_dtype=DTYPE_TO_CPP[X.get_layout().dtype], + b_torch_dtype=DTYPE_TO_CPP[W.get_layout().dtype], + c_torch_dtype=DTYPE_TO_CPP[Y.get_layout().dtype], + bias_torch_dtype=DTYPE_TO_CPP[Bias.get_layout().dtype] + if Bias is not None + else "", + scale_a_torch_dtype=DTYPE_TO_CPP[scale_x.get_layout().dtype] + if scale_x is not None + else "", + scale_b_torch_dtype=DTYPE_TO_CPP[scale_w.get_layout().dtype] + if scale_w is not None + else "", + a_layout=torch_layout_to_ck_layout(X.get_layout()), + b_layout=torch_layout_to_ck_layout(W.get_layout()), + c_layout=torch_layout_to_ck_layout(Y.get_layout()), + bias_layout=torch_layout_to_ck_layout(Bias.get_layout()) + if Bias is not None + else "", + compile_cmd=rocm_compile_command( + [""], "", "exe" + ), + **size_args, + **runtime_args, + ) + res += runner_code + + return res + + def _is_rcr_f16(self): + X_meta, W_meta, Y_meta = ( + T.get_layout() for T in [*self.input_nodes, self.output_node] + ) + X_dtype, W_dtype, Y_dtype = ( + self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta) + ) + X_layout, W_layout, Y_layout = ( + torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta) + ) + + return ( + X_dtype == "F16" + and W_dtype == "F16" + and Y_dtype == "F16" + and X_layout == "Row" + and W_layout == "Col" + and Y_layout == "Row" + ) + + # helper to calculate a potentially optimal kBatch(es) for a problem + def _get_kBatch(self, op): + # we only set a higher kBatch if K > 16 * the larger of M and N + # this is a hand-tuned heuristic to start + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + if is_dynamic(*self.input_nodes): + return [1] + if K // max(M, N) < config.rocm.split_k_threshold: + return [1] + # if the user is telling us which kBatches to sweep, just use those + if config.rocm.kBatch_sweep is not None: + return config.rocm.kBatch_sweep + # Calculate the number of blocks needed for each dimension + total_k_blocks = math.ceil(K / op.k_per_block) + # we want to calculate how many blocks we need to fit per CU + cus = torch.cuda.get_device_properties(X_meta.device).multi_processor_count + # again, manual heuristics as much larger kBatch are significantly worse in + # initial testing + kBatch = min(max(next_power_of_2(total_k_blocks // cus), 1), 128) + return [kBatch] + + def gen_ops(self) -> list[InductorROCmOp]: + """ + Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + try: + from ck4inductor.batched_universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_batched_gemm_ops_library, + ) + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_gemm_ops_library, + gen_ops_preselected as gen_gemm_ops_preselected, + ) + except ImportError: + return [] + + generator = None + if self.is_batched: + generator = gen_batched_gemm_ops_library + else: + generator = gen_gemm_ops_library + if config.rocm.use_preselected_instances and self._is_rcr_f16(): + generator = gen_gemm_ops_preselected + + assert generator is not None + + rops = generator() + ops = [] + for o in rops: + kBatches = self._get_kBatch(o) + for kBatch in kBatches: + # pyrefly: ignore [bad-argument-type] + ops.append(InductorROCmOp(op=o, kBatch=kBatch)) + + filtered_instances = list(filter(lambda op: self.filter_op(op), ops)) + + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_ck_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op.op, + kBatch=op.kBatch, + ) + + def size_args(self): + X = self.input_nodes[0] + W = self.input_nodes[1] + Bias = ( + self.input_nodes[2] + if len(self.input_nodes) == 3 + else self.input_nodes[4] + if len(self.input_nodes) == 5 + else None + ) + Y = self.output_node + + M = X.get_size()[-2] + K = X.get_size()[-1] + N = W.get_size()[-1] + LDA = X.get_stride()[-2 if X.get_stride()[-1] == 1 else -1] + LDB = W.get_stride()[-2 if W.get_stride()[-1] == 1 else -1] + LDC = Y.get_stride()[-2 if Y.get_stride()[-1] == 1 else -1] + LDD = ( + 0 + if (Bias is None or len(Bias.get_size()) == 1) + else Bias.get_stride()[-2 if Bias.get_stride()[-1] == 1 else -1] + ) + if self.is_batched: + B = X.get_size()[0] + return B, M, N, K, LDA, LDB, LDC, LDD + else: + return M, N, K, LDA, LDB, LDC, LDD diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py new file mode 100644 index 0000000000000000000000000000000000000000..aa935b14af23c2efd667871df5e05798a4434fa8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py @@ -0,0 +1,153 @@ +# mypy: allow-untyped-defs +import logging +import os +from typing import Optional + +from torch._inductor import config +from torch._inductor.utils import is_linux, try_import_ck_lib + + +log = logging.getLogger(__name__) + + +def _rocm_include_paths(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_include = ( + os.path.join(config.rocm.rocm_home, "include") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("include") + ) + + if config.is_fbcode(): + from libfb.py import parutil + + ck_path = parutil.get_dir_path("composable-kernel-headers") + else: + if not config.rocm.ck_dir: + ck_dir, _, _, _ = try_import_ck_lib() + if not ck_dir: + log.warning("Unspecified Composable Kernel directory") + config.rocm.ck_dir = ck_dir + ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home( + "composable_kernel" + ) + + log.debug("Using ck path %s", ck_path) + + ck_include = os.path.join(ck_path, "include") + ck_library_include = os.path.join(ck_path, "library", "include") + + # CK has to take priority over ROCm include paths + # Since CK is potentially more up-to-date + paths = [ + os.path.realpath(p) for p in (ck_include, ck_library_include, rocm_include) + ] + if dst_file_ext == "exe": + ck_utility_include = os.path.join(ck_path, "library", "src", "utility") + paths.append(os.path.realpath(ck_utility_include)) + return paths + + +def _rocm_lib_options(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_lib_dir = ( + os.path.join(config.rocm.rocm_home, "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("lib") + ) + hip_lib_dir = ( + os.path.join(config.rocm.rocm_home, "hip", "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("hip", "lib") + ) + + opts = [ + "-include __clang_hip_runtime_wrapper.h", + f"-L{os.path.realpath(rocm_lib_dir)}", + f"-L{os.path.realpath(hip_lib_dir)}", + "-lamdhip64", + ] + if dst_file_ext == "exe": + opts += ["-lpthread", "-lstdc++"] + return opts + + +def _rocm_compiler_options() -> list[str]: + arch_list = config.rocm.arch or ["native"] + gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] + opts = [ + config.rocm.compile_opt_level, + "-x", + "hip", + "-std=c++17", + *gpu_arch_flags, + "-fno-gpu-rdc", + "-fPIC", + "-fvisibility=hidden", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + "-mllvm", + "-enable-post-misched=0", + ] + if config.rocm.is_debug: + opts += ["-DDEBUG_LOG=1", "-g"] + if config.rocm.save_temps: + opts += ["--save-temps=obj"] + if config.rocm.print_kernel_resource_usage: + opts += ["-Rpass-analysis=kernel-resource-usage"] + if config.rocm.flush_denormals: + opts += ["-fgpu-flush-denormals-to-zero"] + if config.rocm.use_fast_math: + opts += ["-ffast-math"] + return opts + + +def rocm_compiler() -> Optional[str]: + if is_linux(): + if config.rocm.rocm_home: + return os.path.realpath( + os.path.join(config.rocm.rocm_home, "llvm", "bin", "clang") + ) + try: + from torch.utils import cpp_extension + + return os.path.realpath( + cpp_extension._join_rocm_home("llvm", "bin", "clang") + ) + except OSError: + # neither config.rocm.rocm_home nor env variable ROCM_HOME are set + return "clang" + return None + + +def rocm_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[list[str]] = None, +) -> str: + include_paths = _rocm_include_paths(dst_file_ext) + lib_options = _rocm_lib_options(dst_file_ext) + compiler_options = _rocm_compiler_options() + compiler = rocm_compiler() + options = ( + compiler_options + + (extra_args or []) + + [f"-I{path}" for path in include_paths] + + lib_options + ) + src_file = " ".join(src_files) + # supported extensions: .o, .so, .exe + if dst_file_ext == "o": + options.append("-c") + elif dst_file_ext == "so": + options.append("-shared") + elif dst_file_ext == "exe": + options.append("-DGENERATE_CK_STANDALONE_RUNNER") + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + return f"{compiler} {' '.join(options)} -o {dst_file} {src_file}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a87bef820dfc76037b5294b00a5f25f26be223 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from ctypes import byref, c_int, c_size_t, c_void_p +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.autotune_process import ( + BenchmarkRequest, + GPUDeviceBenchmarkMixin, + TensorMeta, +) +from torch._inductor.codecache import DLLWrapper, ROCmCodeCache + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + +log = logging.getLogger(__name__) + + +class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = ROCmCodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate code cache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + ROCmCodeCache.compile(self.source_code, "so") + if config.rocm.generate_test_runner: + ROCmCodeCache.compile(self.source_code, "exe") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]] + size_args = [c_int(arg) for arg in self.extra_args] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=out.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *size_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len( + dict.fromkeys(meta.name for meta in self.input_tensor_meta) + ) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. + c_workspace_size = c_size_t() + size_args = [c_int(arg) for arg in self.extra_args] + run_method( + *args, # input ptrs and output ptrs + *size_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = ROCmCodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..ec58e458df6b110fab0c452ec261861d0c2d7cef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import cast + +from ... import config +from ...codecache import code_hash, get_path +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer +from .rocm_template_buffer import ROCmTemplateBuffer + + +log = logging.getLogger(__name__) + + +class ROCmCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for ROCm C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and ROCm C++ specific template code generation. + """ + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ROCmTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["rocm", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.rocm(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a ROCm template, possibly with fused epilogues + """ + assert self.is_rocm_cpp_template(template_node), ( + "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..c5981e129cb192d1f6e0ce1f445401e8c7e51b5e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Callable, Sequence +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch._inductor.config as config +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.utils import do_bench_using_profiling + +from ...ir import ( + Buffer, + ChoiceCaller, + IRNode, + Layout, + PrimitiveInfoType, + ShapeAsConstantBuffer, + TensorBox, +) +from ...virtualized import V +from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode +from ..cpp_utils import CppPrinter +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_template_buffer import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +if TYPE_CHECKING: + from torch._inductor.codegen.rocm.rocm_template import ArgInfo, ROCmTemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class ROCmKernel(Kernel): + """ + Baseclass for ROCm based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class ROCmTemplateKernel(ROCmKernel): + """ + Template kernels defined by ROCm in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, hipStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the ROCmTemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: dict[str, IRNode] = {} + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def get_signature(self): + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + size_args: list[str], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder == [2, 0, 1]: + input_reorder = [4, 0, 1, 2, 3] + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + + runtime_arg_defs = [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args + runtime_arg_defs)},{self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "ROCmTemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The ROCmTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # Kinda hacky because we always originally initialize name with "KERNEL_NAME" + # So, we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace("KERNEL_NAME", name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + kernel_args = [] + for arg in call_args: + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + if V.graph.is_unspec_arg(arg): + arg = arg + ".item()" + else: + if not V.graph.cpp_wrapper: + arg = f"c_void_p({arg}.data_ptr())" + kernel_args.append(arg) + + # add size args + size_args = [ + f"{V.graph.sizevars.simplify(sarg)}" for sarg in node.template.size_args() + ] + + if V.graph.cpp_wrapper: + kernel_args.extend(size_args) + else: + kernel_args.extend(f"c_int({sarg})" for sarg in size_args) + + if V.graph.cpp_wrapper: + arg_types.extend(["int"] * len(node.template.size_args())) + + # the runtime args come right after the size args + kernel_args.extend(self.runtime_arg_values) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + data_ptr = f"{ws.outer_name}.data_ptr()" + kernel_args.append( + data_ptr if V.graph.cpp_wrapper else f"c_void_p({data_ptr})" + ) + else: + ws = None + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + wrapper.generate_kernel_call( + name, + kernel_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + +class ROCmTemplateCaller(ChoiceCaller): + """ + ROCmTemplateCaller + + This class represents a caller for ROCm template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (ROCmBenchmarkRequest): The benchmark request for the caller. + template_buffer (ROCmTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [ROCmTemplateBuffer, Optional[Sequence[IRNode]]], str + ], + bmreq: ROCmBenchmarkRequest, + template: "ROCmTemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + ) -> None: + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"ROCmTemplateCaller(source_file={self.bmreq.source_file}, {self.info_dict()})" + + def call_name(self) -> str: + return f"rocm_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "ROCm", + "name": self.name, + **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] + } + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + self.bmreq.update_workspace_size() + return TensorBox.create( + ROCmTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bfeb03eabc72d7cf9bce701f535e612644a806c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional +from unittest.mock import patch + +from ...autotune_process import TensorMeta +from ...ir import Buffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel +from .rocm_template_buffer import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +log = logging.getLogger(__name__) + + +# FIXME: unify with the CUDA version +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class ROCmTemplate(KernelTemplate): + index_counter = itertools.count() + gfx9_threads_per_warp = 64 + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for ROCm C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the ROCmTemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> ROCmTemplateCaller: + """ + Generates the ROCm template caller object for the given GEMM template and operation. This ROCmTemplateCaller + may be used to call and benchmark the generated ROCm kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A ROCmTemplateCaller object representing the generated ROCm template caller. + """ + kernel_name = f"rocm_{self.name}" + kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}" + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + ROCmTemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + + size_args = ( + self.size_args() if hasattr(self, "size_args") else () + ) # subclass should define def size_args() + size_args_ints = [ + V.graph.sizevars.size_hint(arg) for arg in size_args + ] # resolve to ints for benchmarking + # The runtime args come right after the size args + runtime_args = self.get_runtime_arg_values(**kwargs) + extra_args = size_args_ints + runtime_args + bmreq = ROCmBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ROCmTemplateBuffer, + epilogue_nodes: Optional[Sequence[IRNode]] = None, + ): + kernel = ROCmTemplateKernel( + kernel_name="KERNEL_NAME", + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return ROCmTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c90d71c19980279f13cad86d7385825ba7212c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py @@ -0,0 +1,27 @@ +from collections.abc import Callable, Sequence +from typing import TypeVar +from typing_extensions import ParamSpec + +from ...ir import Buffer, Layout, TemplateBuffer + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class ROCmTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[Buffer], + make_kernel_render: Callable[_P, _T], + workspace_size: int, + template: "ROCmTemplate", # type: ignore[name-defined] # noqa: F821 + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self) -> int: + return self.workspace_size if self.workspace_size is not None else 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36871ac5c7f8fcf0a8b91a143168ab1b90530b0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs + + +import torch + +from ..cpp_utils import DTYPE_TO_CPP + + +DTYPE_TO_ROCM_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "uint16_t", + torch.float8_e4m3fnuz: "uint8_t", + torch.float8_e5m2fnuz: "uint8_t", + torch.float8_e4m3fn: "uint8_t", + torch.float8_e5m2: "uint8_t", + torch.bfloat16: "uint16_t", +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/segmented_tree.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/segmented_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e5c86d18109d36ee8b9595d0bce48685845f54 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/segmented_tree.py @@ -0,0 +1,242 @@ +from collections.abc import Callable +from typing import Generic, Optional, TypeVar + + +T = TypeVar("T") + + +def _value_or(opt: Optional[T], default: T) -> T: + return opt if opt is not None else default + + +class SegmentedTree(Generic[T]): + def __init__( + self, + values: list[T], + update_op: Callable[[T, T], T], + summary_op: Callable[[T, T], T], + identity_element: T, + ): + """ + Initialize a segment tree with the given values and operations. + + Args: + values: list of initial values + update_op: Function to apply when updating a value (e.g., addition) + summary_op: Function to summarize two values (e.g., min, max, sum) + identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min) + + Raises: + ValueError: If the input values list is empty + """ + if not values: + raise ValueError("Cannot create a segment tree with empty values list") + + self.n = len(values) + self.update_op = update_op + self.summary_op = summary_op + self.identity = identity_element + + # Size of segment tree array (next power of 2 * 2) + # The tree follows a standard heap layout where + # node `n`'s children are at `2*n` and `2*n+1`. + # Index 0 is unused. + self.size = 1 + while self.size < self.n: + self.size *= 2 + self.size *= 2 + + # Initialize tree and lazy arrays + self.tree = [identity_element] * self.size + # The lazy array contains updates to the given node + # Upon update, we only push updates to the top-most + # nodes that fully receive the update. We then + # propagate the update down as required (i.e., when + # we receive an interval query that neither fully + # contains the node nor fully doesn't contain the + # node + self.lazy: list[Optional[T]] = [None] * self.size + + # Build the tree + self._build(values, 1, 0, self.n - 1) + + def _build(self, values: list[T], node: int, start: int, end: int) -> None: + """ + Build the segment tree recursively. + + Args: + values: Original array of values + node: Current node index in the segment tree + start: Start index of the segment + end: End index of the segment + """ + if start == end: + # Leaf node + if start < len(values): + self.tree[node] = values[start] + return + + mid = (start + end) // 2 + left_child = 2 * node + right_child = 2 * node + 1 + + # Recursively build left and right subtrees + self._build(values, left_child, start, mid) + self._build(values, right_child, mid + 1, end) + + # Update current node with summary of children + self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child]) + + def _children(self, node: int) -> list[int]: + return [2 * node, 2 * node + 1] + + def _push_lazy(self, node: int, start: int, end: int) -> None: + """ + Push lazy updates down to children. + + Args: + node: Current node index + start: Start index of the segment + end: End index of the segment + """ + lazy_node = self.lazy[node] + if lazy_node is None: + return + + # Apply lazy update to current node + self.tree[node] = self.update_op(self.tree[node], lazy_node) + + if start != end: # Not a leaf node + # Propagate to children + for child in self._children(node): + self.lazy[child] = self.update_op( + _value_or(self.lazy[child], self.identity), lazy_node + ) + + # Clear the lazy value + self.lazy[node] = None + + def _update_range_helper( + self, node: int, start: int, end: int, left: int, right: int, value: T + ) -> None: + """ + Helper method to update a range of values in the segment tree. + + Args: + node: Current node index + start: Start index of the current segment + end: End index of the current segment + left: Start index of the range to update + right: End index of the range to update + value: Value to apply to the range + """ + # Push lazy updates before processing this node + self._push_lazy(node, start, end) + + # No overlap + if start > right or end < left: + return + + # Complete overlap + if start >= left and end <= right: + # Apply update to current node + self.lazy[node] = value + self._push_lazy(node, start, end) + return + + # Partial overlap, recurse to children + mid = (start + end) // 2 + left_child = 2 * node + right_child = 2 * node + 1 + + self._update_range_helper(left_child, start, mid, left, right, value) + self._update_range_helper(right_child, mid + 1, end, left, right, value) + + # Update current node based on children + self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child]) + + def _query_range_helper( + self, node: int, start: int, end: int, left: int, right: int + ) -> T: + """ + Helper method to query a range of values in the segment tree. + + Args: + node: Current node index + start: Start index of the current segment + end: End index of the current segment + left: Start index of the range to query + right: End index of the range to query + + Returns: + Summary value for the range + """ + # No overlap + if start > right or end < left: + return self.identity + + # Push lazy updates before processing this node + self._push_lazy(node, start, end) + + # Complete overlap + if start >= left and end <= right: + return self.tree[node] + + # Partial overlap, recurse to children + mid = (start + end) // 2 + left_child = 2 * node + right_child = 2 * node + 1 + + left_result = self._query_range_helper(left_child, start, mid, left, right) + right_result = self._query_range_helper(right_child, mid + 1, end, left, right) + + # Combine results from children + return self.summary_op(left_result, right_result) + + def update_range(self, start: int, end: int, value: T) -> None: + """ + Update a range of values in the segment tree. + + Args: + start: Start index of the range to update (inclusive) + end: End index of the range to update (inclusive) + value: Value to apply to the range + + Raises: + ValueError: If start > end or indices are out of bounds + """ + if start > end: + raise ValueError("Start index must be less than or equal to end index") + + if start < 0 or start >= self.n: + raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]") + + if end < 0 or end >= self.n: + raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]") + + self._update_range_helper(1, 0, self.n - 1, start, end, value) + + def summarize_range(self, start: int, end: int) -> T: + """ + Query a range of values in the segment tree. + + Args: + start: Start index of the range to query (inclusive) + end: End index of the range to query (inclusive) + + Returns: + Summary value for the range according to the summary operation + + Raises: + ValueError: If start > end or indices are out of bounds + """ + if start > end: + raise ValueError("Start index must be less than or equal to end index") + + if start < 0 or start >= self.n: + raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]") + + if end < 0 or end >= self.n: + raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]") + + return self._query_range_helper(1, 0, self.n - 1, start, end) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py new file mode 100644 index 0000000000000000000000000000000000000000..aff8966e5af7167eb0e2b6f7133d3397149c7436 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py @@ -0,0 +1,3127 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import textwrap +from collections import Counter +from typing import Any, Generic, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeVar + +import sympy + +import torch +import torch._logging +from torch._inductor import metrics +from torch._inductor.ir import MultiTemplateBuffer +from torch._inductor.tiling_utils import analyze_memory_coalescing +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.immutable_collections import immutable_dict +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch.utils._sympy.symbol import ( + free_symbol_is_type, + prefix_str, + symbol_is_type, + SymT, +) + +from ..._dynamo.utils import counters +from .. import config, ir, scheduler +from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask +from ..codecache import code_hash, PyCodeCache +from ..dependencies import MemoryDep, StarDep, WeakDep + + +if TYPE_CHECKING: + from collections.abc import Callable + + from ..ir import IRNode + +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..runtime.coordinate_descent_tuner import CoordescTuner +from ..runtime.hints import DeviceProperties +from ..runtime.runtime_utils import green_text, last_power_of_2, yellow_text +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..utils import ( + cache_property_on_self, + expr_fits_within_32bit, + get_dtype_size, + IndentedBuffer, + Placeholder, + prefix_is_reduction, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsWrapper, V +from .block_analysis import BlockPatternMatcher +from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter +from .multi_kernel import MultiKernel, SizeHintMultiKernel +from .simd_kernel_features import ( + DisableReduction, + EnableReduction, + NodeScheduleEntry, + NodeScheduleMarker, + SIMDKernelFeatures, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + from torch._inductor.tiling_utils import CoalesceVarAnalysis + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +pexpr = PythonPrinter().doprint + +all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"]) + + +def get_max_tiles(default: int = 2) -> int: + max_tiles = torch._inductor.config.triton.max_tiles + return max_tiles if max_tiles is not None else default + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. + """ + + def __init__( + self, + name: str, + var_list: list[sympy.Symbol], + var_ranges: dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: SIMDKernel, + divisor=sympy.S.One, + length=sympy.S.One, + root: IterationRangesRoot, + ) -> None: + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + @property + @cache_property_on_self + def is_reduction(self) -> bool: + return prefix_is_reduction(self.prefix) + + def symbol(self) -> sympy.Symbol: + return sympy_index_symbol(self.name) + + @property + @cache_property_on_self + def symt(self) -> SymT: + prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()} + return prefix_to_symt[self.prefix] + + +class IterationRangesRoot(IterationRanges): + """ + Root of a iteration range tree that represents a single + tiled dimension in the output kernel. It contains multiple + sets of iteration represented with IterationRangesEntry. + """ + + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: SIMDKernel, + pid_cache: Optional[dict[str, str]] = None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + has_zdim: bool, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + # pyrefly: ignore [missing-argument] + assert not is_loop or (self.is_reduction and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + self.has_zdim = has_zdim + + def __repr__(self) -> str: + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self) -> None: + for node in self.nodes.values(): + node.cache_clear() + + def index_sym(self) -> sympy.Symbol: + return sympy_index_symbol(f"{self.prefix}index") + + def lookup(self, divisor: sympy.Expr, length: sympy.Expr) -> IterationRangesEntry: + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(self.index_sym(), divisor) + else: + expr = ModularIndexing(self.index_sym(), divisor, length) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries( + self, lengths: list[sympy.Expr] + ) -> list[IterationRangesEntry]: + divisor = sympy.S.One + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return [*reversed(itervars)] + + def construct(self, lengths: list[sympy.Expr]) -> list[sympy.Symbol]: + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes( + self, index: sympy.Expr + ) -> tuple[list[sympy.Symbol], list[sympy.Expr]]: + """Figure out vars from this tree used in index""" + + def get_sort_key(x: IterationRangesEntry) -> tuple[int, bool]: + """ + Gets the key for sorting nodes. When two nodes have the + same divisor, the node with length as 1 should be handled + first so the current divisor is not changed after multiplied + node.length. Returns `not length_is_one_hint` for ascending + sort. + """ + divisor_hint = V.graph.sizevars.size_hint( + x.divisor, fallback=config.unbacked_symint_fallback + ) + length_is_one_hint = ( + V.graph.sizevars.size_hint( + x.length, fallback=config.unbacked_symint_fallback + ) + == 1 + ) + return (divisor_hint, not length_is_one_hint) + + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort(key=lambda x: get_sort_key(x)) + divisor = sympy.S.One + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return [*reversed(index_vars)], [*reversed(sizes)] + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ) -> None: + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self) -> str: + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name: str) -> None: + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self) -> None: + self.codegen.cache_clear() + + def _codegen(self) -> str: + V.kernel.codegen_iteration_ranges_entry(self) + return self.name + + def precomputed_args(self) -> list[sympy.Expr]: + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: list[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, IterationRangesEntry) + return self.name == other.name + + +def constant_repr(value: Union[int, float]) -> str: + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable) + + +@dataclasses.dataclass +class PartialAccumulate: + buffer_name: str + reduction_type: str + value: Any + + +class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): + """ + Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. + """ + + sexpr: Callable[[sympy.Expr], str] = pexpr + kexpr: Callable[[sympy.Expr], str] + allow_block_ptr: bool = False + # pyrefly: ignore [bad-override] + kernel_name: str + + def __init__( + self, + tiling: dict[str, sympy.Expr], + features: SIMDKernelFeatures, + pid_cache: Optional[dict[str, str]] = None, + override_persistent_reduction: Optional[bool] = None, + override_cooperative_reduction: Optional[bool] = None, + tiling_scores: Optional[dict[str, sympy.Expr]] = None, + mix_order_reduction: bool = False, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__() + self.features = features + self.mutations = features.get_mutations() + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.numels = { + prefix: V.graph.sizevars.simplify(val) for prefix, val in tiling.items() + } + self.range_trees: list[IterationRangesRoot] = [] + self.range_tree_nodes: dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = features.is_reduction() + self.cooperative_reduction: bool = ( + override_cooperative_reduction + if override_cooperative_reduction is not None + else self.should_use_cooperative_reduction() + ) + self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores + self.tiling: dict[str, sympy.Expr] = tiling + self.persistent_reduction: bool = ( + override_persistent_reduction + if override_persistent_reduction is not None + else self.should_use_persistent_reduction() + ) + self.mix_order_reduction: bool = mix_order_reduction + self.no_x_dim = self.want_no_x_dim() + self.code_hash: Optional[str] = None + # Info to enable multiple store_output calls for epilogue subtiling + self.store_output_ctr = itertools.count() + self.is_native_matmul = False + if config.triton.native_matmul: + for node in self.features.node_schedule: + if ( + isinstance(node, scheduler.SchedulerNode) + and isinstance(node.node, ir.ComputedBuffer) + and node.node.get_reduction_type() == "dot" + ): + self.is_native_matmul = True + break + + # define this in a closure to make cache local to object + @functools.cache + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + + return self.combine_modular_indexing_pairs(index) + + self.simplify_indexing = simplify_indexing + self.initialize_range_tree(pid_cache) + + self.rsplit_size = 0 + self.saved_partial_accumulate: list[PartialAccumulate] = [] + + def _get_store_output_subgraph_name(self, i: int) -> str: + return f"" + + def get_store_output_count(self): + total = next(self.store_output_ctr) + self.store_output_ctr = itertools.count(start=total - 1, step=1) + return total + + @property + @cache_property_on_self + def num_reduction_dims(self) -> int: + return sum(prefix_is_reduction(prefix) for prefix in self.numels) + + def dtype_to_str(self, dtype: torch.dtype) -> str: + raise NotImplementedError + + def get_index_dtype_as_torch_dtype(self) -> torch.dtype: + return self.features.select_index_dtype() + + @property + def index_dtype(self) -> str: + return self.dtype_to_str(self.get_index_dtype_as_torch_dtype()) + + def want_no_x_dim(self) -> bool: + return False + + def construct_range_trees( + self, + pid_cache: Optional[dict[str, str]], + inside_reduction: bool, + is_reduction: bool, + numels: dict[str, sympy.Expr], + no_x_dim: bool, + ) -> list[IterationRangesRoot]: + active_prefixes = OrderedSet( + prefix for prefix in all_prefixes if prefix in numels + ) + no_r_dim = not inside_reduction or not is_reduction + + def filtered_index_map(seq, mask) -> dict[Any, int]: + return { + val: idx for idx, val in enumerate(val for val in seq if val in mask) + } + + grid_dims = ["x", "y", "z"] + pointwise_tensor_dims = list(reversed(grid_dims)) + reduction_dims = ["r0_", "r1_"] + if no_x_dim: + tensor_dims = reduction_dims + elif no_r_dim: + tensor_dims = pointwise_tensor_dims + else: + tensor_dims = pointwise_tensor_dims + reduction_dims + + # Filter out unused tensor dims. + # Convert to dicts for O(1) index lookup. + tensor_dim_map = filtered_index_map(tensor_dims, active_prefixes) + grid_dim_map = filtered_index_map(grid_dims, all_prefixes) + + range_trees = [] + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix_is_reduction(prefix) + tensor_dim = tensor_dim_map.get(prefix) + grid_dim = grid_dim_map.get(prefix) + index = i if grid_dim is None else grid_dim + range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numels[prefix], + prefix, + index, + self, # type: ignore[arg-type] + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim="z" in numels, + ) + ) + return range_trees + + def initialize_range_tree(self, pid_cache: dict[str, str]) -> None: + range_trees = self.construct_range_trees( + pid_cache, + self.inside_reduction, + self.features.is_reduction(), + self.numels, + self.no_x_dim, + ) + self.range_trees.extend(range_trees) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]) -> None: + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + """ + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + prior = self.inside_reduction + self.inside_reduction = False + try: + return self.store(name, index, value) + finally: + self.inside_reduction = prior + + def should_use_cooperative_reduction(self) -> bool: + return False # defined in subclass + + def should_use_persistent_reduction(self) -> bool: + return False # defined in subclass + + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def triton_tensor_ndim(self) -> int: + return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i: int) -> str: + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> list[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + # pyrefly: ignore [missing-argument] + if not tree.is_reduction or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def create_constant_mask(self, entry) -> str: + x = entry.prefix + if entry.tensor_dim is None: + sizestr = self.dense_size_str() + return f"{x}mask = tl.full({sizestr}, True, tl.int1)" + sizes = ["None"] * self.triton_tensor_ndim() + sizes[entry.tensor_dim] = ":" + suffix = ", ".join(sizes) + out = f"{x}mask = tl.full([{x.upper()}BLOCK], True, tl.int1)[{suffix}]" + return out + + def dense_size_str(self) -> str: + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index) + # the index now contains xindex/etc, which is nonstandard, fix it up + return sympy_subs( + new_index, + { + tree_node.root.index_sym(): tree_node.root.lookup( + sympy.S.One, tree_node.root.numel + ).symbol() + }, + ) + + def combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def disable_reduction(self) -> contextlib.AbstractContextManager[None]: + should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction + + @contextlib.contextmanager + def ctx(): + if not self.features.is_reduction(): + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths: sympy.Expr) -> list[sympy.Symbol]: + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ) -> tuple[ + list[list[sympy.Expr]], list[list[Callable[[list[sympy.Expr]], sympy.Expr]]] + ]: + # Special case: if a node's sizes are ([], []), there's nothing to split. + if all(len(length) == 0 for length in lengths): + return [[] for group in groups], [] + + sv = V.graph.sizevars + new_ranges: list[list[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i: int, expr: sympy.Expr) -> int: + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined( + sizes: list[sympy.Expr], idxs: list[int] + ) -> Callable[[list[sympy.Expr]], sympy.Expr]: + """ + Builds the nested expression: + ((...((s1*v[i1] + v[i2]) * s2 + v[i3]) ... ) * sk + v[i(k+1)]) + """ + assert len(idxs) == len(sizes) + 1 + + def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: + expr = flat_vars[idxs[0]] + for s, idx in zip(sizes, idxs[1:]): + expr = s * expr + flat_vars[idx] + return expr + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.S.Zero) + continue + + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], + 1, # type: ignore[arg-type] + ): + # scroll to next group with remaining elements + current_group += 1 + + # During native matmul on bmm, we enforce tiling order (z, y, x, r). + # When fusing a bmm node with loop (z, y, x, r) with a pw node + # of shape (z*y*x, 1), we need to split the pw iteration range + # into three dimensions. + # The group becomes [z, y, x, 1], with lengths ([z*y*x], []). + # In this case, we decompose the combined size z*y*x into three + # consecutive groups. Previously, _split_iteration_ranges supported + # splitting into at most two dimensions, but we now extend it to do + # three splits when the total size is divisible by all three. + + # is group having (z,y,x,r=1) form? + is_bmm_then_pw = len(remaining) == 4 and remaining[-1] == 1 + if ( + current_group + 2 < len(remaining) + and sv.statically_known_gt( + size, remaining[current_group] * remaining[current_group + 1] + ) + and is_bmm_then_pw + ): + # need to break size in three + if not sv.statically_known_multiple_of( + size, remaining[current_group] * remaining[current_group + 1] + ): + raise CantSplit + + size1 = remaining[current_group] + size2 = remaining[current_group + 1] + size3 = FloorDiv(size, size1 * size2) + return_getters.append( + make_combined( + [size2, size3], + [ + add_range(current_group, size1), + add_range(current_group + 1, size2), + add_range(current_group + 2, size3), + ], + ) + ) + + # Two-dimensional tiling: split size across current_group and next group. + elif current_group + 1 < len(remaining) and ( + sv.statically_known_gt(size, remaining[current_group]) + or + # statically_known_gt(size, remaining) may return False for symbolic + # expressions like 64*u0 vs u0, because both could be 0. Similarly for + # backed expressions like s25*(((s70 - 5)//4)) - s25 and + # (s25*(((s70 - 5)//4)) - s25)*64. + # We want to assume tensor sizes are not 0 and pass the gt + # using the following logic. + # + # if A//B = C and C >= 1 + # then A = B * C + R + # and assuming A!=0 + # A must be > B . + # + sv.statically_known_gt(FloorDiv(size, remaining[current_group]), 1) + ): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit + + size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + [size2], + [ + add_range(current_group, size1), + add_range(current_group + 1, size2), + ], + ) + ) + else: + if current_group < len(remaining): + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), ( + f"failed to set ranges {remaining} {lengths}" + ) + return new_ranges, return_getters_groups + + @classmethod + def prepare_split_iteration_lengths( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> Sequence[Sequence[sympy.Expr]]: + "Fill in the reduction numel of lengths if missing" + sizevars = V.graph.sizevars + if len(lengths[1]) == 0 and ( + not sizevars.statically_known_equals(reduction_numel, sympy.S.One) + and sizevars.statically_known_equals( + sympy_product(groups), + sympy_product(lengths[0]) * reduction_numel, + ) + ): + return (lengths[0], [reduction_numel]) + + return lengths + + @classmethod + def is_compatible( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> bool: + lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel) + + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges( + self, lengths: Sequence[Sequence[sympy.Expr]] + ) -> list[list[sympy.Expr]]: + """ + Split and set iteration ranges for the kernel based on the provided lengths. + + This method maps the kernel's tiling structure to the node's iteration space, + handling both pointwise and reduction dimensions appropriately. + + Args: + lengths: A sequence of sequences of symbolic expressions representing + the sizes of different dimensions for each node. + + Returns: + A list of lists of symbolic expressions representing the mapped + iteration variables for each dimension. + """ + # Create a dictionary mapping each range tree prefix to its total number of elements + tiling = {rt.prefix: rt.numel for rt in self.range_trees} + + # If we're not inside a reduction loop, set all reduction dimensions to 1 + # This effectively disables reduction dimensions when not needed + if not self.inside_reduction: + for prefix in tiling: + if prefix_is_reduction(prefix): + tiling[prefix] = sympy.S.One + + # Extract the values from the tiling dictionary to create groups + groups = [*tiling.values()] + + # Map the kernel's group structure to the node's sizes and set the ranges + # using the set_ranges method, returning the resulting iteration variables + return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) + + @classmethod + def map_kernel_groups_to_node_sizes( + cls, + groups: Sequence[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + set_ranges, + ) -> list[list[sympy.Expr]]: + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + if len(lengths) == len(groups) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return set_ranges(*lengths) + + new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths) + itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))] + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr) -> bool: + # tmpX means indirect indexing + return free_symbol_is_type(index, SymT.TMP) + + def is_broadcasted(self, index: sympy.Expr) -> bool: + # Note. This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels.values()) + ) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in output code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel. + + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return self.kexpr(self.rename_indexing(index)) # type: ignore[call-arg] + + def prepare_indexing( + self, + index: sympy.Expr, + ) -> sympy.Expr: + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + simp_index = self.simplify_indexing(index) + + # Now that we are done simplifying we can unwrap Identity so that downstream handling + # for its contained expression will work. previously, tl.full wrapping of sympy.Integer + # would not occur + simp_index = ( + simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] + ) + + return self.codegen_indexing(simp_index) + + def active_range_trees(self) -> list[IterationRangesRoot]: + return [ + t + for t in self.range_trees + # pyrefly: ignore [missing-argument] + if not t.is_reduction or self.inside_reduction + ] + + def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr: + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, + replacements, # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + def codegen_nan_check(self) -> None: + raise NotImplementedError("NYI: codegen_nan_check") + + def deallocate_workspaces(self): + wrapper = V.graph.wrapper_code + for ws in reversed(self.args.workspace_args): + wrapper.generate_workspace_deallocation(ws) + + def call_kernel( + self, name: str, node: Optional[IRNode] = None, deallocate_ws: bool = True + ) -> None: + raise NotImplementedError("NYI: call_kernel") + + @contextlib.contextmanager + def mask_loads( + self, mask: Union[str, OpsWrapper], value: Union[int, float] + ) -> Iterator[str]: + """Context manager to add an additional mask to tl.load/store""" + prior = self._load_mask + prior_val = self._load_other + if prior: + mask = ops.logical_and(mask, prior) + + mask = OpsWrapper._unwrap(mask) + self._load_mask = mask + self._load_other = value + try: + # TODO(jansel): do we need a reshape here? + yield mask + finally: + self._load_mask = prior + self._load_other = prior_val + + def get_strides_of_load(self, index: sympy.Expr) -> dict[sympy.Symbol, sympy.Expr]: + """ + This gets the stride of the index for each of the tiling variables + (technically, it does it at index 0) + + For example, if + xindex = x0 + 512*x1 + 1024*r0 + x0 = (xindex//512) + x1 = (xindex % 512) + r0 = rindex // 1024 + + this function would return + {xindex: 512, rindex: 1024} + """ + index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] + strides = {} + for range_tree in self.range_trees: + s = sympy_index_symbol(range_tree.name) + strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( + index_in_tile_vars, {s: 0} + ) + return strides + + @staticmethod + def _map_tuple_or_scalar(fn, value): + if isinstance(value, tuple): + return tuple(map(fn, value)) + return fn(value) + + def estimate_flops(self) -> Optional[int]: + flops = [ + node.estimate_flops() + for node in NodeScheduleMarker.only_nodes(self.features.node_schedule) + ] + return sum(filter(None, flops)) + + def estimate_kernel_num_bytes(self): + """ + Try the best to estimate the total size (in bytes) of the + kernel's inputs and outputs, which is used for estimating the memory + throughput of this kernel. This information is used for checking how + far we are from the peak memory bandwidth. It's important that + we want to avoid overestimating the sizes of the inputs and outputs, + because it can wrongfully give us a very large memory traffic value, + which may be even larger than the theoretical bandwidth and thus + become very misleading. This is particularly problematic for cases + where we slice some inputs. In those cases, we should only count + the size of the "slices" instead of the original inputs, because + only the slices contribute to the real memory traffic. + """ + nbytes = [] + ninplace_args = len(unique(self.args.inplace_buffers.values())) + _, call_args, _, _ = self.args.python_argdefs() + buf_accesses = self.features.buf_accesses() + + # For pointwise and reduction kernels, this is the upper-bound numels + # for the output buffer. + # FIXME: This is not exactly right for cases like below: + # def foo(tensor0, tensor1): + # x0 = narrow(tensor0) + # return cat(x0, tensor1) + # For this example, we will end up overestimate the size for the + # slice s0. Potentially, we could have precise inputs information + # if we maintained the original inputs of the Pointwise kernel created + # for the "cat". However, I think it might be a bit overwhelming that + # we add such complexity only for handling some particular cases for + # benchmarking. + out_numel = V.graph.sizevars.size_hint( + sympy_product(self.numels.values()), + fallback=config.unbacked_symint_fallback, + ) + for i, arg in enumerate(call_args): + # "buf" may be narrowed. In this case, the number of memory accesses + # should be estimated based on the reinterpreted layout. + # On the other hand, buf may be broadcasted. In this case, + # counting the size of the underline storage would give us + # a better estimation in terms of memory accesses. + if arg not in buf_accesses: + nbytes.append(0) + continue + arg_numel = V.graph.get_numel(arg) + buf_size = V.graph.sizevars.size_hint( + arg_numel, fallback=config.unbacked_symint_fallback + ) + if buf_size > out_numel: + # This arg points to a buf that has been sliced. + # We need to count each individual slice to have + # a better estimation. + indices = OrderedSet[Any]() + no_index_dep_count = 0 + for dep in buf_accesses[arg]: + if isinstance(dep, (StarDep, WeakDep)): + indices.add(f"no_index_dep_{no_index_dep_count}") + no_index_dep_count += 1 + else: + indices.add(dep.index) + numel = len(indices) * out_numel + else: + numel = buf_size + dtype = V.graph.get_dtype(arg) + dtype_size = get_dtype_size(dtype) + # pyrefly: ignore [bad-argument-type] + nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(nbytes) + + def warn_mix_layout(self, kernel_name): + """ + Print message if the kernel have mixed layout inputs. + Only care about 4D tensor for now. + """ + if ( + len(self.args.input_buffers) == 1 + and len(self.args.output_buffers) == 1 + and len(self.args.inplace_buffers) == 0 + ): + # even if input buffer and output buffer have different layout, + # this can be a layout conversion kernel. No need to warn for + # the mix layouts. + return + + argdefs, call_args, _signature, _ = self.args.python_argdefs() + uniform_stride_order = None + # pyrefly: ignore [bad-assignment] + for arg_name in call_args: + buf = V.graph.try_get_buffer(arg_name) + if not buf: + continue + layout = buf.get_layout() + if len(layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order( + V.graph.get_buffer(name).get_layout().stride + ) + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).get_layout().size + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + argdef_names = [x.name for x in argdefs] + msg = yellow_text( + f" param names {argdef_names}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def welford_reduce_fallback(self, dtype, value): + sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.features.reduction_numel, dtype) + mean = ops.truediv(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + return OpsWrapper._unwrap((mean, m2, rnumel)) + + def prepare_softmax_twopass_fallback(self, dtype, value): + vmax = ops.reduction(dtype, dtype, "max", value) + sub = ops.sub(value, vmax) + exp = ops.exp(sub) + vsum = ops.reduction(dtype, dtype, "sum", exp) + return OpsWrapper._unwrap((vmax, vsum)) + + def codegen_kernel(self): + raise NotImplementedError + + def codegen_body(self): + pass + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + pass + + +class SIMDScheduling(BaseScheduling): + """ + Single Instruction Multiple Data parent class used for fusion across + multiple different backends. + """ + + kernel_type: type[Any] = SIMDKernel # override in subclass + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. + """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + from torch._inductor.scheduler import MixOrderReduction + + reduction_can_fuse = MixOrderReduction.can_fuse(node1, node2) + + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + + if reduction_can_fuse and ( + node1.is_native_matmul() or node2.is_native_matmul() + ): + # Ensure node1 is always the native matmul side + if not node1.is_native_matmul(): + node1, node2 = node2, node1 + + # 1. A native matmul node keeps its original loop order. + # For example: C[z,y,x] = torch.bmm(A[z,y,r], B[z,r,x]) keeps (z,y,x) order. + # (see simplify_and_reorder in ir.py) + # + # 2. Triton kernels with native matmul always tile loops as (z,y,x) + # (see get_tiling_and_scores in this file) + # + # 3. If a candidate node (node2) uses a different loop order (e.g., (z,x,y,r)), + # its tiling is incompatible with native matmul tiling (z,y,x,r). + # This means _split_iteration_ranges will fail, so these nodes should not be fused. + tiling = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + if not all( + SIMDKernel.is_compatible( + tiling.values(), n2.get_ranges(), reduction_numel=rnumel1 + ) + for n2 in node2.get_nodes() + ): + why("invalid loop order and tiling for native matmul") + return False + + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + if not node2.is_template(): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + else: + # prologue fusion input sizes differ from output group + # fuse so long as this node matches the group of existing prologue nodes + for node in node2.get_nodes(): + # dont need to check epilogue nodes for prologue fusion, break after template + if node.is_template(): + break + # we would have already restricted prologue from fusing if it had multiple + # uses, so it must be fusing into this node + if not node.used_buffer_names() & node1.get_buffer_names(): + continue + _, (pro_numel, pro_rnumel) = node.group + if not (numel1 == pro_numel and rnumel1 == pro_rnumel): + why( + "numel/rnumel mismatch prologue mismatch (%s, %s), (%s, %s)", + numel1, + pro_numel, + rnumel1, + pro_rnumel, + ) + return False + + for n in (node1, node2): + if n.is_template(): + return True + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = tuple( + self.select_tiling(node1.get_nodes(), numel1).values() + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid + return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: list[Any] = [] + done = OrderedSet[scheduler.BaseSchedulerNode]() + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + current_loop_buffer_usage: OrderedSet[str] = OrderedSet() + maybe_split_index: Optional[int] = None + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = n.group + return node_numel == numel and node_rnumel == 1 and rnumel != 1 + + def expect_improved_memory_usage(n): + for read in n.read_writes.reads: + if read.name in current_loop_buffer_usage: + return True + return False + + def schedule_node_in_loop(n): + done.add(n) + node_schedule.append(n) + current_loop_buffer_usage.update([x.name for x in n.read_writes.reads]) + + # A scan is modelled as a reduction in the scheduler but has a + # full sized output that can be used inside the loop body + if ( + n.is_reduction() + and isinstance(n, scheduler.SchedulerNode) + and isinstance(n.node, ir.ComputedBuffer) + and not isinstance(n.node.data, ir.Scan) + ): + not_ready_yet_nodes.add(n.get_name()) + else: # this node is available within the loop + current_loop_buffer_usage.update([x.name for x in n.read_writes.writes]) + + @contextlib.contextmanager + def end_current_reduction_loop(): + nonlocal maybe_split_index + if node_schedule and node_schedule[-1] is EnableReduction: + node_schedule.pop() + else: + node_schedule.append(DisableReduction) + if maybe_split_index: + node_schedule.insert(maybe_split_index, DisableReduction) + node_schedule.insert(maybe_split_index + 1, EnableReduction) + maybe_split_index = None + yield + node_schedule.append(EnableReduction) + not_ready_yet_nodes.clear() + current_loop_buffer_usage.clear() + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not not_ready_yet_nodes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(not_ready_yet_nodes) + + for node in nodes: + if node in done: + continue + done.add(node) + + if fits_in_main_body(node): + if requires_closing_previous_reduction(node, node_schedule): + with end_current_reduction_loop(): + pass # need to start a new reduction loop + + if current_loop_buffer_usage and not expect_improved_memory_usage(node): + # If we don't improve memory usage, then it is better to split into two loops + maybe_split_index = maybe_split_index or len(node_schedule) + else: + # Memory usage got improved, cancel the loop split + maybe_split_index = None + + schedule_node_in_loop(node) + elif fits_outside_reduction(node): + with end_current_reduction_loop(): + node_schedule.append(node) + else: + raise NotImplementedError( + f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" + ) + + return node_schedule + + def codegen_mix_order_reduction(self, node): + node1, node2 = node.node1, node.node2 + + # Make sure there are no producer/consumer relationship + assert not (node1.ancestors & node2.get_operation_names()) and not ( + node2.ancestors & node1.get_operation_names() + ) + + self._codegen_mix_order_reduction(node1, node2) + + def _split_mix_order_reduction_epilogue(self, node): + # TODO: do more validation here + nodes = node.get_nodes() + reductions = [] + epilogues = [] + for node in nodes: + if node.is_reduction(): + reductions.append(node) + else: + epilogues.append(node) + return reductions, epilogues + + def _generate_kernel_code_for_mix_order_reduction( + self, kernel_features, split_size, for_benchmark + ): + """ + for_benchmark: + True if the generated code is for benchmarking. We need make + sure benchmark harness code is generated. + """ + numel, rnumel = kernel_features.numel, kernel_features.reduction_numel + node_schedule = kernel_features.node_schedule + + kernel = self.create_kernel_choices( + kernel_features, + [{"x": numel, "r0_": rnumel}], + { + "features": kernel_features, + "tiling_scores": None, + "mix_order_reduction": True, + "override_persistent_reduction": True, + }, + )[0] + assert kernel.persistent_reduction + assert kernel.mix_order_reduction + kernel.rsplit_size = split_size + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + # allocate workspace for this kernel + _, ws_name, ws_off = kernel.args.workspace( + len(kernel.saved_partial_accumulate) + * kernel.numels["r0_"] + * ((kernel.numels["x"] + kernel.rsplit_size - 1) // kernel.rsplit_size), + False, + dtype=torch.float, + ) + assert ws_off == 0, f"{ws_off=}" + with kernel: + kernel.codegen_body() + + stack = contextlib.ExitStack() + with V.set_kernel_handler(kernel), stack: + if for_benchmark: + stack.enter_context(config.patch(benchmark_kernel=True)) + src_code = kernel.codegen_kernel() + + if for_benchmark: + # only do this if we are doing benchmarking. + # When we are generating final code, the kernel name + # should be decided differently with node type, fx node name + # etc. + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return kernel, ws_name, src_code + + def benchmark_codegened_module( + self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None + ) -> tuple[float, str]: + raise NotImplementedError + + def _codegen_mix_order_reduction(self, node1, node2): + numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1) + + if not V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)): + return self._codegen_mix_order_reduction(node2, node1) + + def _pick_split_size(): + # the overridden has highest priority + if config.triton.mix_order_reduction_split_size is not None: + return config.triton.mix_order_reduction_split_size + + # heuristics based on number of SMs + device_prop = DeviceProperties.create(node1.get_device()) + num_sm = device_prop.multi_processor_count + estimated_num_splits = num_sm * 8 + + # split_size is decided based on hint + numel_hint = V.graph.sizevars.size_hint(numel) + split_size = max(last_power_of_2(numel_hint // estimated_num_splits), 16) + split_size = min(split_size, 128) + return split_size + + split_size = _pick_split_size() + + # pyrefly: ignore [bad-assignment] + metrics.codegen_mix_order_reduction += 1 + + assert V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)) + + # split epilogue out of node2 + node2_reductions, node2_epilogue = self._split_mix_order_reduction_epilogue( + node2 + ) + + converted_nodes = [] + for subnode in node2_reductions: + subnode.cancel_reduction_split() + converted = subnode.extract_pw_from_reduction() + converted.swap_pw_red_dimension() + converted_nodes.append(converted) + node_schedule = self.generate_node_schedule( + node1.get_nodes() + converted_nodes, numel, rnumel + ) + kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel) + + # The autotuning is skipped in deterministic mode + if ( + not torch._inductor.config.deterministic + and config.triton.mix_order_reduction_split_size is None + and ( + config.triton.mix_order_reduction_autotune_split_size + or config.max_autotune + or config.coordinate_descent_tuning + ) + ): + + def _bench(candidate_split_size): + _, _, src_code = self._generate_kernel_code_for_mix_order_reduction( + kernel_features, + split_size=candidate_split_size, + for_benchmark=True, + ) + mod = PyCodeCache.load(src_code) + ms, _ = self.benchmark_codegened_module(mod) + return ms + + split_size = CoordescTuner.autotune_single_field( + _bench, + split_size, + 8, + ) + + kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction( + kernel_features, + split_size=split_size, + for_benchmark=False, + ) + + # rename intermediate reduction output to final reduction + # output + is_split_reduction = bool(node2_reductions[0].node._split_size) + rename = {} + if is_split_reduction: + for subnode in node2_reductions: + bufname = subnode.get_outputs()[0].node.get_name() + username = ( + subnode.get_outputs()[0] + .users[0] + .node.get_outputs()[0] + .node.get_name() + ) + rename[bufname] = username + assert self.scheduler + self.scheduler.removed_ops.add( + subnode.get_outputs()[0].users[0].node.get_name() + ) + V.graph.removed_buffers.add(bufname) + + for partial_accum in kernel.saved_partial_accumulate: + partial_accum.buffer_name = rename.get( + partial_accum.buffer_name, partial_accum.buffer_name + ) + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + + with V.set_kernel_handler(kernel): + for node in kernel_features.scheduler_nodes(): + # No need to allocate buffer for split reduction + # since we are gonna to allocate workspace to store the + # intermediate reduction reduction + if node.get_outputs()[0].node.get_name() not in rename: + node.mark_run() + + V.graph.wrapper_code.make_comment("# Call mix order reduction kernel") + self.codegen_comment(node_schedule, None) + # workspace args is still needed after the call + kernel.call_kernel(kernel.kernel_name, deallocate_ws=False) + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + + # a extra round of reduction + assert len(converted_nodes) == len(kernel.saved_partial_accumulate) + nsplit = V.graph.wrapper_code.codegen_python_sizevar( + (numel + split_size - 1) // split_size + ) + for idx, partial_accum in enumerate(kernel.saved_partial_accumulate): + buffer_name = partial_accum.buffer_name + + stride_str = f"{nsplit} * {rnumel}" + start = f"{idx} * {stride_str}" + end = f"({idx} + 1) * {stride_str}" + reduction_type2op = { + "min": "amin", + "max": "amax", + } + opname = reduction_type2op.get( + partial_accum.reduction_type, partial_accum.reduction_type + ) + + V.graph.wrapper_code.writeline( + f"{buffer_name} = {ws_name}[{start} : {end}].view({nsplit}, {rnumel}).{opname}(dim=0)", + ) + # mark the buffer as allocated, so we don't try to allocate + # it again when it's later used + V.graph.wrapper_code.allocated.add(buffer_name) + + kernel.deallocate_workspaces() + + if node2_epilogue: + self._codegen_nodes(node2_epilogue) + + self.free_buffers_in_scheduler() + + def _codegen_nodes( + self, + nodes: Sequence[scheduler.SchedulerNode], + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ): + assert self.scheduler + nodes = [ + node for node in nodes if node.get_name() not in self.scheduler.removed_ops + ] + if not nodes: + return + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule( + SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis) + ) + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + assert self.scheduler + nodes = [ + node + for node in node.get_nodes() + if node.get_name() not in self.scheduler.removed_ops + ] + if len(nodes) == 0: + return + + if torch._inductor.config.triton.coalesce_tiling_analysis: + if len(nodes) != len(node.get_nodes()): + assert self.scheduler + node = scheduler.FusedSchedulerNode(self.scheduler, nodes) + coalesce_analysis = analyze_memory_coalescing(node) + else: + coalesce_analysis = None + + return self._codegen_nodes(nodes, coalesce_analysis) # type: ignore[arg-type] + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, + buffers: Iterable[ + Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject, ir.IRNode] + ], + ) -> bool: + int_max = torch.iinfo(torch.int32).max + + if not expr_fits_within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if buf.has_tensor_output() + ] + + for buf in buffers: + if not buf.has_tensor_output() and isinstance(buf, ir.MutationOutput): + mutated_bufs = buf.get_mutation_buffers() + buf_sizes += [ + buf.get_layout().storage_size() + for buf in mutated_bufs + if buf.has_tensor_output() + ] + + if not all(expr_fits_within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.check_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.check_leq(size, int_max) # type: ignore[arg-type] + return True + + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): + """ + Generate code for nodes in kernel_features + """ + node_schedule = kernel_features.node_schedule + + tiling, tiling_score = self.get_tiling_and_scores( + node_schedule, + kernel_features.numel, + kernel_features.reduction_numel, + kernel_features.coalesce_analysis, + ) + kernels = self.create_kernel_choices( + kernel_features, + [tiling], + {"features": kernel_features, "tiling_scores": tiling_score}, + ) + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + final_kernel = MultiKernel(kernels) + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + # filter out NodeScheduleMarker + base_scheduler_nodes = [ + node for node in node_schedule if isinstance(node, BaseSchedulerNode) + ] + self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_begin() + V.graph.wrapper_code.write_kernel_context_guard( + final_kernel.kernel_name, + base_scheduler_nodes, # type: ignore[arg-type] + ) + final_kernel.call_kernel(final_kernel.kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_end() + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks # type: ignore[has-type] + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + assert node.node is not None + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.free_buffers_in_scheduler() + + def create_kernel_choices( + self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs + ) -> list[SIMDKernel]: + return [ + self.kernel_type( + *kernel_args, + **kernel_kwargs, + ) + ] + + def codegen_node_schedule_with_kernel(self, node_schedule, kernel): + with kernel: + stack = contextlib.ExitStack() + all_indexing = {} + + # First pass to collect indexing and decide inplace updates + for node in node_schedule: + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + node.decide_inplace_update() + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + all_indexing.update( + dict.fromkeys( + node._body.indexing_from_args(index_vars).values() + ) + ) + + kernel.finalize_indexing(all_indexing.keys()) + + # Second pass to do codegen + for node in node_schedule: + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + # TODO - use split ranges ? + indexing_dtype_strength_reduction(node._body) + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node.codegen(index_vars) + + def _codegen_single_template( + self, + kernel, + render, + template_node, + epilogue_nodes, + prologue_nodes, + *, + only_gen_src_code=False, + ): + """ + Helper method to codegen a single template kernel variant + """ + buf_name_to_prologue_group = {} + template_reads = template_node.used_buffer_names() + prologue_group = [] + for prologue in prologue_nodes: + names = prologue.get_buffer_names() + prologue_group.append(prologue) + # this must be the end of a prologue group + if names & template_reads: + assert len(names) == 1 + buf_name_to_prologue_group[next(iter(names))] = prologue_group + kernel.prologue_fused_inputs.add(next(iter(names))) + prologue_group = [] + + # all prologue groups should have finalized with use in template + assert len(prologue_group) == 0 + + with kernel: + if not only_gen_src_code: + # prologue nodes can only be fused if their only use is in the template, + # so they are necessarily not allocated + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + partial_code = render() + + num_store_subgraphs = kernel.get_store_output_count() + for i in range(num_store_subgraphs): + subgraph_name = kernel._get_store_output_subgraph_name(i) + with kernel.set_subgraph_body(subgraph_name): + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + kernel.cse.invalidate(OrderedSet()) + + for input_name, buffer in kernel.named_input_nodes.items(): + subgraph_name = f"" + if prologue_group := buf_name_to_prologue_group.get( + buffer.get_name(), [] + ): + can_codegen_without_upcast = all( + p_n.can_codegen_without_upcasts() for p_n in prologue_group + ) + + # TODO - this doesn't work with libdevice calls, potentially other bugs + # upcasting to fp32 and downcasting gives large slowdown + with config.patch( + "triton.codegen_upcast_to_fp32", not can_codegen_without_upcast + ): + with kernel.set_subgraph_body(subgraph_name): + for prologue_node in prologue_group: + if ( + len(prologue_node.get_buffer_names()) == 1 + and len(prologue_group) == 1 + ): + if prologue_preserves_zero_mask(prologue_node): + kernel.prologue_fused_inputs_preserve_zero |= ( + prologue_node.get_buffer_names() + ) + + prologue_node.codegen( + kernel.split_and_set_ranges( + prologue_node.get_ranges() + ) + ) + kernel.cse.invalidate(OrderedSet()) + + # Template hooks must be finalised after kernel.remove_kernel_local_buffers + # is called (this is called when the kernel context is exited above), and when + # the kernel handler is set (as below). This is because the hooks may add + # DeferredLine type lines, which preclude lines involving buffers that have + # been removed + + # finalize must be called after adding epilogue above + with V.set_kernel_handler(kernel): + if not isinstance(partial_code, str): + # This is used to calculate flops in TritonTemplateKernels + with ir.IRNode.current_origins(template_node.node.origins): + partial_code.finalize_hook("") + partial_code.finalize_hook("", strict=False) + + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. + + for input_name in kernel.named_input_nodes: + subgraph_name = f"" + # pyrefly: ignore [missing-attribute] + partial_code.finalize_hook(subgraph_name, strict=False) + + num_store_subgraphs = kernel.get_store_output_count() + for i in range(num_store_subgraphs): + subgraph_name = kernel._get_store_output_subgraph_name(i) + # pyrefly: ignore [missing-attribute] + partial_code.finalize_hook(subgraph_name) + + if isinstance(partial_code, str): + src_code = partial_code + else: + # Ensure all hooks are finalized before the kernel is defined. + # Note: some of these hooks may have been registered by a kernel subclass + src_code = partial_code.finalize_remaining() + + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] + + if config.benchmark_kernel: + num_gb = kernel.estimate_kernel_num_bytes() / 1e9 + src_code = ( + f"{kernel.imports_for_benchmark_kernel()}\n" + f"{src_code}\n" + f"{kernel.codegen_kernel_benchmark(num_gb).getvalue()}" + ) + + if only_gen_src_code: + return src_code + + kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) + + return kernel + + def _get_multikernel_shapes( + self, node: MultiTemplateBuffer + ) -> tuple[tuple[int, ...], ...]: + from ..ir import IRNode + + def get_size(arg): + if not isinstance(arg, IRNode): + return None + if isinstance(arg, ir.BaseView): # triton templates want the base tensor. + arg = arg.unwrap_view() + if (size := arg.maybe_get_size()) is None: + return None + return tuple(s for s in size) + + out = [] + for arg in list(node.inputs) + [node]: + if isinstance(arg, (list, tuple)): + out.append(tuple(get_size(_arg) for _arg in arg)) + else: + out.append(get_size(arg)) + return tuple(out) + + def _kernel_has_dynamic_shapes(self, node: MultiTemplateBuffer) -> bool: + shapes = self._get_multikernel_shapes(node) + return any( + any( + isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer) + for s in shape + ) + for shape in shapes + ) + + def _make_shape_cache_key( + self, node: MultiTemplateBuffer, hint: int + ) -> tuple[tuple[int, ...], ...]: + """ + Returns cache key for hint-based multi-graph; key is tuple of shapes with hint filled in. + """ + shapes = self._get_multikernel_shapes(node) + return tuple( + tuple( + hint + if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer) + else s + for s in shape + ) + for shape in shapes + ) + + def codegen_template( + self, + template_node, + epilogue_nodes, + prologue_nodes, + *, + only_gen_src_code=False, + hint_override: Optional[int] = None, + ) -> Optional[str]: + """ + Codegen a triton template with multi-kernel dispatch support + + If `only_gen_src_code=True` the src code will be returned instead of being + codegenned into the wrapper + """ + + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + + if ( + isinstance(template_node.node, MultiTemplateBuffer) + and template_node.node._make_kernel_renders + and len(template_node.node._make_kernel_renders) > 1 + and self._kernel_has_dynamic_shapes(template_node.node) + ): + kernels = {} + src_codes = [] + + for ( + size_hint, + make_kernel_render, + ) in template_node.node._make_kernel_renders.items(): + kernel, render = make_kernel_render( + template_node.node, hint_override=hint_override + ) + + if only_gen_src_code: + src_code = self._codegen_single_template( + kernel, + render, + template_node, + epilogue_nodes, + prologue_nodes, + only_gen_src_code=True, + ) + assert isinstance(src_code, str) + # pyrefly: ignore [bad-argument-type] + src_codes.append(src_code) + else: + if size_hint is None: + continue # skip kernel generation based on real runtime value; only use hints + kernel = self._codegen_single_template( + kernel, + render, + template_node, + epilogue_nodes, + prologue_nodes, + only_gen_src_code=False, + ) + shape_cache_key = ( + None + if size_hint is None + else self._make_shape_cache_key(template_node.node, size_hint) + ) + kernels[shape_cache_key] = kernel + + if only_gen_src_code: + return "\n\n".join(src_codes) + + MultiKernel.merge_workspaces_inplace(list(kernels.values())) + multi_kernel = SizeHintMultiKernel(kernels) + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] + self.codegen_comment(node_schedule, multi_kernel.kernel_name) + multi_kernel.call_kernel(multi_kernel.kernel_name) + V.graph.removed_buffers |= multi_kernel.removed_buffers + V.graph.inplaced_to_remove |= multi_kernel.inplaced_to_remove + self.free_buffers_in_scheduler() + return None + else: + kernel, render = template_node.node.make_kernel_render( + template_node.node, hint_override=hint_override + ) + + if only_gen_src_code: + return self._codegen_single_template( + kernel, + render, + template_node, + epilogue_nodes, + prologue_nodes, + only_gen_src_code=True, + ) + else: + kernel = self._codegen_single_template( + kernel, + render, + template_node, + epilogue_nodes, + prologue_nodes, + only_gen_src_code=False, + ) + + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] + self.codegen_comment(node_schedule, kernel.kernel_name) + kernel.call_kernel(kernel.kernel_name, template_node.node) + + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.free_buffers_in_scheduler() + return None + + def codegen_sync(self): + V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) + + def generate_combo_kernel_code( + self, + subkernel_nodes: list[BaseSchedulerNode], + custom_part_algorithm: bool, + enable_autotune: bool, + mixed_sizes: bool, + only_gen_src_code: bool = False, + ) -> list[tuple[str, Any, Any]]: + from .triton_combo_kernel import ComboKernel + + fused_node_lists = [node.get_nodes() for node in subkernel_nodes] + subkernel_map, node_schedule_map = {}, {} + for pn, nodes in zip(subkernel_nodes, fused_node_lists): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiling = self.select_tiling(node_schedule, numel, rnumel) + node_schedule_map[pn] = node_schedule, tiling, numel, rnumel + subkernel_map[pn] = ComboKernel.create_triton_kernel( + tiling, + features=SIMDKernelFeatures(node_schedule, numel, rnumel), + optimize_mask=not mixed_sizes, + ) + + partitions = ComboKernel.horizontal_partition( + nodes=subkernel_nodes, + triton_scheduling=self, + custom_algorithm=custom_part_algorithm, + kernel_map=subkernel_map, + node_info_map=node_schedule_map, + ) + log.debug( + "ComboKernels: %d nodes partitioned into %s groups", + len(subkernel_nodes), + [len(p) for p in partitions], + ) + kernel_code_list = [] + for node_group in partitions: + if len(node_group) == 0: + continue + kernel = ComboKernel( + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + ) + + for pn in node_group: + self.codegen_node_schedule_with_kernel( + node_schedule_map[pn][0], + kernel.create_sub_kernel(subkernel_map[pn]), + ) + subkernel = subkernel_map[pn] + node_schedule = node_schedule_map[pn][0] + if not only_gen_src_code: + with V.set_kernel_handler(subkernel): # type: ignore[call-arg] + for node in NodeScheduleMarker.only_nodes(node_schedule): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_code_list.append((src_code, kernel, node_group)) + return kernel_code_list + + def codegen_combo_kernel(self, combo_kernel_node): + subkernel_nodes = combo_kernel_node.get_subkernel_nodes() + custom_part_algorithm = combo_kernel_node.use_custom_partition_algo + enable_autotune = combo_kernel_node.enable_autotune + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm + ) + + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes + ) + + for src_code, kernel, _ in kernel_code_list: + kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) + self.codegen_comment(combo_kernel_node.snodes, kernel_name) + log.debug("ComboKernels: generated kernel %s.", kernel_name) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.free_buffers_in_scheduler() + + @classmethod + @functools.lru_cache(32) + def candidate_tilings(cls, node, numel, reduction_numel) -> list[CandidateTiling]: + is_pointwise = reduction_numel == 1 + + def tile_ranges(is_pointwise: bool, ranges, rw) -> list[CandidateTiling]: + """ + Compute tiling candidates by dividing up the iteration ranges. + """ + assert len(rw.range_vars) == len(ranges), f"{rw.range_vars=} {ranges=}" + + # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads + # that need to access the entire tensor; they don't contribute read indexing + # information (and practically, they don't have dep.index so they can't be used + # for stride_hints below + dep_sources = [rw.reads, rw.writes] + assert all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + ) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers + and isinstance(dep, MemoryDep) + ] + write_names = OrderedSet([dep.name for dep in rw.writes]) + + def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: + return V.graph.sizevars.simplify(sympy_product(ranges)) + + # Default to no tiling. + tilings = [ + CandidateTiling( + tiling=cls.create_partial_tiling( + [collapse_ranges(ranges)], is_pointwise + ), + name="none", + score=0, + ) + ] + + # Find non-trivial tiling candidates. + for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert len(strides) == len(ranges) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + # if this is a broadcasted tensor and all dimensions after split are broadcast, + # this is not a real split + continue + + except ValueError: + continue + + tiled_groups = ( + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ) + + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append( + CandidateTiling( + tiling=cls.create_partial_tiling( + [ + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ], + reduction_numel, + ), + score=score, + name=dep.name, + ) + ) + + return tilings + + pointwise_ranges, reduction_ranges = node.get_ranges() + if ( + len(pointwise_ranges) <= 1 + and len(reduction_ranges) <= 1 + or free_unbacked_symbols(pointwise_ranges + reduction_ranges) + ): + return [] + + # Tile either pointwise or reduction dims. + pointwise_ranges, reduction_ranges = node.get_ranges() + partial_tilings = tile_ranges( + is_pointwise, + pointwise_ranges if is_pointwise else reduction_ranges, + node.pointwise_or_reduction_read_writes(is_pointwise), + ) + + # Fill in the missing ranges. + full_tilings = [ + CandidateTiling( + tiling=cls.complete_partial_tiling( + tiling.tiling, numel, reduction_numel + ), + score=tiling.score, + name=tiling.name, + ) + for tiling in partial_tilings + ] + + return full_tilings + + @classmethod + def create_tiling( + cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] + ) -> immutable_dict[str, sympy.Expr]: + """ + Create a tiling dict from pointwise and reduction splits. + """ + pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :] + reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)] + return immutable_dict( + [*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)] + ) + + @classmethod + def create_partial_tiling( + cls, + tiling: Sequence[sympy.Expr], + is_pointwise: bool, + ) -> immutable_dict[str, sympy.Expr]: + return cls.create_tiling( + tiling if is_pointwise else [], + tiling if not is_pointwise else [], + ) + + @classmethod + def complete_partial_tiling( + cls, + tiling: dict[str, sympy.Expr], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ) -> immutable_dict[str, sympy.Expr]: + """ + Given a tiling for only pointwise or reduction dimensions, adds the missing one. + """ + splits = list(tiling.values()) + is_pointwise = "x" in tiling + + total_numel = numel * reduction_numel + missing_tiling = [total_numel / sympy_product(splits)] + + tiling_args = ( + (splits, missing_tiling) if is_pointwise else (missing_tiling, splits) + ) + return cls.create_tiling(*tiling_args) + + @classmethod + def get_nd_tilings( + cls, + node_schedule, + pointwise_numel, + reduction_numel, + ) -> list[immutable_dict[str, sympy.Expr]]: + """ + Creates N-dimensional tiling candidates, attempting to simplify loads/stores + by tiling the kernel into higher dimensions. + + Returns a list of tilings ranked by dimensionality. + """ + is_pointwise = reduction_numel == 1 + tilings = OrderedSet[immutable_dict[str, sympy.Expr]]() + for node in EnableReduction.filter(node_schedule): + if not isinstance(node, scheduler.SchedulerNode): + continue + + # If this is a reduction schedule, skip nodes which are missing their + # reduction ranges. + node_ranges = node.get_ranges() + if not is_pointwise and len(node_ranges[1]) == 0: + continue + + # Use the node ranges as the default tiling candidate. + ranges_to_tile = node_ranges[0 if is_pointwise else 1] + node_tilings = [ranges_to_tile] + + # Search the indexing expressions for more candidates. + # If we see modular indexing, try to subdivide ranges into their implied + # block shape. + memory_deps = [ + dep + for dep in node.read_writes.reads_and_writes() + if isinstance(dep, MemoryDep) and len(dep.ranges) > 0 + ] + for dep in memory_deps: + # Attempt to partition variable ranges into pointwise and reduction groups. + # To achieve this, merge the leading ranges until we reach the pointwise numel. + all_var_ranges = [*dep.ranges.items()] + pointwise_vars_numel = sympy.S.One + sizevars = V.graph.sizevars + pointwise_end_idx = 0 + for idx, (_var, numel) in enumerate(all_var_ranges): + pointwise_vars_numel *= numel + pointwise_end_idx = idx + if sizevars.statically_known_geq( + pointwise_vars_numel, pointwise_numel + ): + break + + # Reject the split if it does not match the total pointwise numel. + if not sizevars.statically_known_equals( + pointwise_vars_numel, pointwise_numel + ): + continue + + # Partition var ranges into pointwise and reduction splits. + reduction_start_idx = pointwise_end_idx + 1 + var_ranges = ( + all_var_ranges[:reduction_start_idx] + if is_pointwise + else all_var_ranges[reduction_start_idx:] + ) + + # Pattern match the subexpression pertaining to each index variable. + index_tiling = [] + for var, numel in var_ranges: + index = BlockPatternMatcher.get_subexpr_involving_symbol( + dep.index, var + ) + + # Heuristic to bound the maximum dimensionality of the block. + num_dims = max( + 2, + index.count(FloorDiv) + index.count(ModularIndexing), + len(ranges_to_tile), + ) + + # Attempt to pattern match the index expr. + # Failed matches default to the full range. + match_result = BlockPatternMatcher.match_mod_div_block_expr( + index, var, numel, num_dims + ) + dims = match_result[0] if match_result is not None else [numel] + index_tiling.extend(dims) + + # Prune dimensions of size 1. + index_tiling = [ + dim + for dim in index_tiling + if not V.graph.sizevars.statically_known_equals(dim, sympy.S.One) + ] + + if len(index_tiling) > 0: + node_tilings.append(index_tiling) + + # Flatten leading dimensions, assigning labels to each dim. + for node_tiling in node_tilings: + num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2)) + first_trailing_dim = num_leading_dims + 1 + collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim]) + collapsed_splits = (collapsed_leading_dim,) + tuple( + node_tiling[first_trailing_dim:] + ) + tilings.add( + cls.complete_partial_tiling( + cls.create_partial_tiling(collapsed_splits, is_pointwise), + pointwise_numel, + reduction_numel, + ) + ) + + # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D. + # Since this is a stable sort, ties are broken by schedule order. + ranked_tilings = sorted( + tilings, + key=len, + reverse=True, + ) + + return ranked_tilings + + @classmethod + def compute_tiling_strategy( + cls, + node_schedule: list[NodeScheduleEntry], + pointwise_numel: sympy.Expr, + reduction_numel: sympy.Expr, + coalesce_analysis: CoalesceVarAnalysis, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses. + """ + tiling_var: Optional[sympy.Expr] = ( + None + if not coalesce_analysis.suggested_split + else coalesce_analysis.suggested_split.var + ) + + all_iter_vars = coalesce_analysis.norm_read_writes.index_vars + all_red_vars = coalesce_analysis.norm_read_writes.reduce_vars + ranges = coalesce_analysis.norm_read_writes.var_ranges + + pw_ranges = [ranges[v] for v in all_iter_vars] + red_ranges = [ranges[v] for v in all_red_vars] + + torch._check( + sympy_product(pw_ranges) == pointwise_numel, + lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}", + ) + + torch._check( + sympy_product(red_ranges) == reduction_numel, + lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}", + ) + + # score of a pointwise or reduction split + scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {} + + score_split: list[ + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] + ] = [] + + def process_node_vars( + vars_to_use: tuple[sympy.Expr, ...] = (), + use_split_var: bool = False, + is_pointwise: bool = False, + ) -> tuple[list[int], list[int]]: + """ + Generate a tiling, and a tiling score, given vars to use as splits. + """ + + ranges = pw_ranges if is_pointwise else red_ranges + target_numel = pointwise_numel if is_pointwise else reduction_numel + # Some kernels have no reduction ranges, and a reduction numel of 1 + if not ranges: + if target_numel: + return ([target_numel], []) + else: + return ([], []) + + key = (repr(vars_to_use), use_split_var, is_pointwise) + if out := scored_sub_split.get(key): + return out + + splitting_vars = all_iter_vars if is_pointwise else all_red_vars + + splits = [] + split_scores = [] + prod = 1 + prev_var_coalesced_score = 0 + + # iterate from non-dense to dense + for v, v_range in zip(splitting_vars, ranges): + if v not in vars_to_use: + prod *= v_range + prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get( + v, 0 + ) + continue + + if use_split_var and v == tiling_var: + var_tiling = coalesce_analysis.suggested_split + assert var_tiling is not None + + tile = var_tiling.tiling_factor + remainder = FloorDiv(v_range, var_tiling.tiling_factor) + + splits.append(prod * remainder) + split_scores.append(var_tiling.score) + + splits.append(tile) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + + prod = 1 + prev_var_coalesced_score = 0 + + continue + + prod *= v_range + splits.append(prod) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + prod = 1 + + if prod != 1 or (is_pointwise and len(splits) == 0): + splits.append(prod) + split_scores.append(prev_var_coalesced_score) + + # penalize splits that leave small blocks + # where we can't fully utilize full memory transaction + # TODO: incorporate exact bitwidth, and read/write + # coalesced write is 2x more important + for i in range(len(splits)): + s = V.graph.sizevars.size_hint(splits[i], fallback=32) + s = min(s, 8) + split_scores[i] = int(split_scores[i] * s / 8) + + scored_sub_split[key] = (splits, split_scores) + return (splits, split_scores) + + # add the default tiling + score_split.append( + ( + process_node_vars(is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if tiling_var: + score_split.append( + ( + process_node_vars( + (tiling_var,), use_split_var=True, is_pointwise=True + ), + process_node_vars(is_pointwise=False), + ) + ) + + # TODO, add tests, reduction splits if config.triton.tile_reductions + # TODO: we should ignore tiny increases in score for extra splits + overlapping_iter_vars = ( + all_iter_vars & coalesce_analysis.coalesced_by_var.keys() + ) + for v in overlapping_iter_vars: + score_split.append( + ( + process_node_vars((v,), is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if get_max_tiles(default=3) == 3 and reduction_numel == 1: + for vars_to_use in itertools.combinations(overlapping_iter_vars, 2): + score_split.append( + ( + process_node_vars(vars_to_use, is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + tilings: list[tuple[CandidateTiling, immutable_dict[str, sympy.Expr]]] = [] + for (pw_split, pw_score), (red_split, red_score) in score_split: + candidate = CandidateTiling( + cls.create_tiling(pw_split, red_split), + score=sum(pw_score) + sum(red_score), + ) + tiling_score = cls.create_tiling(pw_score, red_score) + tilings.append((candidate, tiling_score)) + + default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel]) + + # add a slight penalty for longer tilings that dont increase score much, + # and are poor sizes + bad_size_additional_tiling_penalty = 1.025 + good_size_tiling_penalty = 1.005 + + total_uncoalesced = sum(coalesce_analysis.uncoalesced_addrs.values()) + + def score_mod(t): + score_factor = 1.0 + for tile_size in t[0].tiling.values(): + if not CandidateTiling.is_good_size(tile_size): + score_factor = score_factor / bad_size_additional_tiling_penalty + else: + score_factor = score_factor / good_size_tiling_penalty + + # Add uncoalesced memory score to prevent small coalesced benefits + # from dominating large amounts of uncoalesced memory + uncoalesced_penalty = total_uncoalesced * 0.05 + + return -(t[0].score + uncoalesced_penalty) * score_factor + + # apply penalty for longer tilings that dont increase score much + for cand, tiling_score in sorted(tilings, key=score_mod): + if ( + cls.tiling_is_compatible( + node_schedule, pointwise_numel, reduction_numel, cand.tiling + ) + or cand.tiling == default_tiling + ): + # we always include default reduction numel == 1, dont include + tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0) + if tiling_len > get_max_tiles(default=3): + perf_hint_log.info( + "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles " + "set to %s. Consider increasing", + tiling_len, + torch._inductor.config.triton.max_tiles, + ) + continue + + return cand.tiling, tiling_score + + # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible` + # TODO - look into, occurs with dynamic shapes often + if cand.tiling == default_tiling: + return cand.tiling, tiling_score + + return default_tiling, None + + @classmethod + def tiling_is_compatible( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + tiling: dict[str, sympy.Expr], + ): + assert isinstance(tiling, dict) + return all( + SIMDKernel.is_compatible( + tiling.values(), node.get_ranges(), reduction_numel=reduction_numel + ) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ) + + @classmethod + def get_first_compatible_tiling( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ranked_tilings: list[dict[str, sympy.Expr]], + ): + for tiling in ranked_tilings: + if cls.tiling_is_compatible(node_schedule, numel, reduction_numel, tiling): + return tiling + + return None + + @classmethod + def select_tiling( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> dict[str, sympy.Expr]: + return cls.get_tiling_and_scores( + node_schedule, numel, reduction_numel, coalesce_analysis + )[0] + + @classmethod + def get_tiling_and_scores( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + # If this is a reduction, only tile reduction dims. + is_pointwise = reduction_numel == 1 + + # Tiled reductions are gated by a config flag. + default_tiling = cls.create_tiling([numel], [reduction_numel]) + + # Force tiling compatible with matmul dimensions + # when natively generating matmul without template calls. + for node in EnableReduction.filter(node_schedule): + if isinstance(node.node, ir.ComputedBuffer): + if ( + node.node.get_reduction_type() == "dot" + and config.triton.native_matmul + ): + # A[M,K] @ B[K,N] + # force tiling to be {'y':M, 'x':N, 'r0_':K} + node_ranges = node.get_ranges() + range_y_x = node_ranges[0] # (M,N) + range_r = node_ranges[1] # (K) + tiling = cls.create_tiling(range_y_x, range_r) + return tiling, None + + # # TODO: enable by default + if ( + torch._inductor.config.triton.coalesce_tiling_analysis + and coalesce_analysis + and not config.triton.prefer_nd_tiling + ): + return cls.compute_tiling_strategy( + node_schedule, numel, reduction_numel, coalesce_analysis + ) + + if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles( + default=2 + ) <= 1: + # Emit a perf hint in case we miss an opportunity to tile a reduction. + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if ( + not config.triton.tile_reductions + and len(cls.candidate_tilings(node, numel, reduction_numel)) > 0 + ): + perf_hint_log.info( + textwrap.dedent( + """ + Reduction over non-contiguous dims. + Consider setting config.triton.tile_reductions to True. + """ + ) + ) + break + + return default_tiling, None + + seen_names: OrderedSet[str] = OrderedSet() + candidate_tiles: Counter[CandidateTiling] = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel): + if candidate_tiling.name in seen_names: + continue + elif candidate_tiling.name is not None: + seen_names.add(candidate_tiling.name) + candidate_tiles[candidate_tiling] += candidate_tiling.score + + ranked_tilings: list[dict[str, sympy.Expr]] = [ + candidate_tiling.tiling + for candidate_tiling, score in candidate_tiles.most_common() + ] + + if get_max_tiles(default=2) >= 3 and is_pointwise: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. + + def convert_tiling_to_3d( + tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr] + ) -> Optional[dict[str, sympy.Expr]]: + a0, a1 = tiling0["x"], tiling0.get("y", 1) + b0, b1 = tiling1["x"], tiling1.get("y", 1) + + if ( + free_unbacked_symbols([a1, b1]) + or V.graph.sizevars.size_hint(a1 - b1) == 0 + ): + return None + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + (a0, a1), (b0, b1) = (b0, b1), (a0, a1) + + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if not V.graph.sizevars.statically_known_multiple_of(a1, b1): + return None + + new_tiling = { + "z": a0, + "y": FloorDiv(a1, b1), + "x": b1, + "r0_": tiling0["r0_"], + } + + return new_tiling + + for i in range(1, len(ranked_tilings)): + new_3d_tiling = convert_tiling_to_3d( + ranked_tilings[0], ranked_tilings[i] + ) + if new_3d_tiling is not None: + ranked_tilings = [new_3d_tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + # Optionally, prefer tiling into as many dimensions as possible. + # pyrefly: ignore [unbound-name] + if config.triton.prefer_nd_tiling: + ranked_tilings = ( + cls.get_nd_tilings(node_schedule, numel, reduction_numel) + + ranked_tilings + ) + + if tiling := cls.get_first_compatible_tiling( + node_schedule, numel, reduction_numel, ranked_tilings + ): + return tiling, None + + return default_tiling, None + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def generate_kernel_code_from_nodes( + self, nodes, benchmark_kernel=False, hint_override: Optional[int] = None + ): + if not any(n.is_template() for n in nodes): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiling = self.select_tiling(node_schedule, numel, rnumel) + kernel = self.kernel_type( + tiling, + features=SIMDKernelFeatures(node_schedule, numel, rnumel), + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with ( + config.patch("benchmark_kernel", benchmark_kernel), + V.set_kernel_handler(kernel), + ): + src_code = kernel.codegen_kernel() + else: + prologue, template, epilogue = nodes[0].get_prologue_template_epilogue( + nodes + ) + with config.patch("benchmark_kernel", benchmark_kernel): + src_code = self.codegen_template( + template, + epilogue, + prologue, + only_gen_src_code=True, + hint_override=hint_override, + ) + + # pyrefly: ignore [missing-attribute] + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return src_code + + def define_kernel(self, src_code, node_schedule, kernel): + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class CandidateTiling: + tiling: dict[str, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class CantSplit(Exception): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb38dda5a3660e090adc7013da94577507e8a89 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py @@ -0,0 +1,620 @@ +from __future__ import annotations + +import collections +import dataclasses +import functools +import itertools +import typing +from typing import Any, Optional, Union + +import sympy + +import torch + +from ...utils._ordered_set import OrderedSet +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import make_symbol, SymT +from ..dependencies import Dep, extract_loop_body_with_args, MemoryDep +from ..runtime.hints import ReductionHint +from ..scheduler import SchedulerNode +from ..utils import cache_on_self +from ..virtualized import V + + +if typing.TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from torch._inductor.tiling_utils import CoalesceVarAnalysis + + +class NodeScheduleMarker: + @staticmethod + def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]: + for item in it: + if not (item is DisableReduction or item is EnableReduction): + yield item # type: ignore[misc] + + @staticmethod + def is_reduction() -> bool: + return False + + +NodeScheduleEntry = Union[SchedulerNode, type[NodeScheduleMarker]] + + +class DisableReduction(NodeScheduleMarker): + """ + Marker to invoke `kernel.disable_reduction()`. This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction(NodeScheduleMarker): + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule: list[NodeScheduleEntry]) -> Iterable[SchedulerNode]: + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. + """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node # type: ignore[misc] + + +class SIMDKernelFeatures: + """ + An ordered schedule of nodes that will become a single kernel. + """ + + def __init__( + self, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr = sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ): + self.node_schedule = node_schedule + # numel excludes reduction_numel + self.numel: sympy.Expr = V.graph.sizevars.simplify(numel) + self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel) + self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {} + self.coalesce_analysis = coalesce_analysis + + @cache_on_self + def is_reduction(self) -> bool: + return self.reduction_numel != 1 + + @cache_on_self + def scheduler_nodes(self) -> Iterable[SchedulerNode]: + return tuple(NodeScheduleMarker.only_nodes(self.node_schedule)) + + def reduction_nodes(self) -> list[SchedulerNode]: + return [n for n in self.scheduler_nodes() if n.is_reduction()] + + @cache_on_self + def buf_accesses(self) -> dict[str, list[Dep]]: + """only needed for config.benchmark_kernel""" + buf_accesses = collections.defaultdict(list) + for node in self.scheduler_nodes(): + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + return buf_accesses + + @cache_on_self + def op_counts(self) -> collections.Counter[str]: + counts: collections.Counter[str] = collections.Counter() + for node in self.scheduler_nodes(): + counts.update(node._body.op_counts) + return counts + + def contains_op(self, op_name: str) -> bool: + """True if V.ops.{op_name} is used in node_schedule""" + return bool(self.op_counts().get(op_name)) + + def get_mutations(self) -> OrderedSet[str]: + mutations: OrderedSet[str] = OrderedSet() + for node in self.scheduler_nodes(): + for buf in node.get_outputs(): + mutations.update(buf.get_mutations()) + return mutations + + @cache_on_self + def select_index_dtype(self) -> torch.dtype: + # Gather all used buffer names + buffer_names: OrderedSet[str] = OrderedSet() + for node in self.scheduler_nodes(): + buffer_names.update(node.get_buffer_names()) + buffer_names.update(node.used_buffer_names()) + buffers = [V.graph.get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. + total_numel = self.numel * self.reduction_numel + + from .simd import SIMDScheduling + + if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): + return torch.int32 + return torch.int64 + + @cache_on_self + def get_reduction_hint(self) -> ReductionHint: + reductions = self.reduction_nodes() + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + + if ( + reduction_hint_val == ReductionHint.INNER + and self.has_non_contiguous_pw_in_reduction_kernel() + ): + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + return reduction_hint_val + + @cache_on_self + def buffer_read_counts(self) -> dict[str, int]: + """Counts how many times each buffer is read within the kernel""" + read_counts: dict[str, int] = collections.defaultdict(int) + + for node in self.scheduler_nodes(): + # node.read_writes.reads contains MemoryDep objects for each read + for read_dep in node.read_writes.reads: + read_counts[read_dep.name] += 1 + + return dict(read_counts) # Convert defaultdict to regular dict + + def has_non_contiguous_pw_in_reduction_kernel(self) -> bool: + pointwise_nodes = [ + n + for n in self.scheduler_nodes() + if not n.is_reduction() + and n.group[1][0] == self.numel * self.reduction_numel + ] + for node in pointwise_nodes: + # An index can be an integer when loading a random seed. + if not all( + not isinstance(dep, MemoryDep) + or dep.is_contiguous() + or isinstance(dep.index, (sympy.Integer, int)) + or dep.stride1_for_last_dim() + for dep in itertools.chain( + node.read_writes.reads, node.read_writes.writes + ) + ): + return True + return False + + @staticmethod + def reduction_hint(node: Any) -> ReductionHint: + assert node.is_reduction() + if node.node.data.reduction_hint != ReductionHint.INNER and all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + def memory_stats( + self, groups_dict: Optional[dict[str, sympy.Expr]] = None + ) -> MemoryStats: + """Analysis to generate features that can be used in heuristics""" + if groups_dict is None: + groups = (self.numel, self.reduction_numel) + elif groups_dict.keys() == OrderedSet(["x", "r0_"]): + groups = (groups_dict["x"], groups_dict["r0_"]) + else: + raise NotImplementedError(f"groups_dict={groups_dict!r}") + result = self._stats_cache.get(groups) + if result is None: + self._stats_cache[groups] = result = MemoryStats.compute( + MemoryEstimator(self, groups) + ) + return result + + +class MemoryEstimator: + """ + Estimate various properties of the kernel for use in heuristics. + We simulate the memory effects of CSE/buffer elimination in codegen. + """ + + kernel_sizes: tuple[sympy.Expr, ...] + outside_loop: MemoryEstimate + loops: list[MemoryEstimate] + persistent: MemoryEstimate + symbols: list[sympy.Symbol] + + def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]): + self.features = features + self.inside_reduction = features.is_reduction() + self.store_buffer_names: OrderedSet[str] = OrderedSet() + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.num_reductions_dims = 1 + self.groups = groups + self.symbols = [make_symbol(SymT.INDEX, i) for i in range(len(groups))] + # We are doing two estimates simultaneously: + # 1) the first is a for a non-persistent (aka looped) reduction, using self.outside_loop/self.loops + # we add an item to loops each corresponding to each reduction loop in the kernel + # outside_loop is only used for broadcasting or point-wise ops that don't use the reduction dimension + # 2) the second is for a persistent kernel, using self.persistent + # persistent kernels don't have loops, so we only have one MemoryEstimate() + # for point-wise ops the two estimates will be the same, they matter for reductions only + self.outside_loop = MemoryEstimate() + self.loops = [MemoryEstimate()] + self.persistent = MemoryEstimate() + self.simulate_codegen() + self.remove_kernel_local() + + def simulate_codegen(self) -> None: + from .simd import SIMDKernel + + kernel_size_outside_loop = (*self.groups[:-1], sympy.S.One) + kernel_size_inside_loop = tuple(self.groups) + self.kernel_sizes = kernel_size_inside_loop + + for node in self.features.node_schedule: + if node is DisableReduction: + self.inside_reduction = False + self.kernel_sizes = kernel_size_outside_loop + continue + elif node is EnableReduction: + self.inside_reduction = True + self.kernel_sizes = kernel_size_inside_loop + self.loops.append(MemoryEstimate()) + continue + assert isinstance(node, SchedulerNode) + rw = extract_loop_body_with_args( + node._body, + SIMDKernel.map_kernel_groups_to_node_sizes( + self.kernel_sizes, node.get_ranges(), self.set_ranges + ), + dict(zip(self.symbols, self.kernel_sizes)), + ) + + for dep in rw._reads: + if not isinstance(dep, MemoryDep): + continue + dep = dep.simplify_with_ranges() + if not self.persistent.writes.get(dep.name): # cache miss? + self.persistent.reads[dep.name].add(dep) + # the cache behavior of looped kernels is more complex than the persistent case above + # some operations are lifted outside the loop (if they don't use the reduction dimension) + # other operations are inside the loop, and can only be reused within the same loop + if not ( + self.outside_loop.writes.get(dep.name) + or self.loops[-1].writes.get(dep.name) + ): + self.scope(dep).reads[dep.name].add(dep) + if dep.name in self.store_buffer_names and self.loops[-1].reads.get( + dep.name + ): + self.must_keep_buffers.add(dep.name) + + for dep in rw._writes: + if not isinstance(dep, MemoryDep): + continue + dep = dep.simplify_with_ranges() + self.store_buffer_names.add(dep.name) + self.persistent.writes[dep.name].add(dep) + self.scope(dep).writes[dep.name].add(dep) + + def remove_kernel_local(self) -> None: + # Remove any kernel-local buffers + fused_node_names = OrderedSet( + [n.get_name() for n in self.features.scheduler_nodes()] + ) + for name in self.store_buffer_names: + if not self.persistent.reads.get( + name + ) and V.graph.scheduler.can_buffer_be_removed_through_fusion( + name, fused_node_names + ): + self.persistent.remove(name) + if name not in self.must_keep_buffers: + # we can also remove this from the looped kernel + self.outside_loop.remove(name) + for loop in self.loops: + loop.remove(name) + + if not self.loops[-1]: + self.loops.pop() # for pointwise ops + + def scope(self, dep: MemoryDep) -> MemoryEstimate: + """Determine how a read/write should be categorized""" + if self.inside_reduction and ( + self.has_reduction_var(dep.index) or dep.is_indirect() + ): + return self.loops[-1] + return self.outside_loop + + def has_reduction_var(self, index: sympy.Expr) -> bool: + for sym in self.symbols[-self.num_reductions_dims :]: + if isinstance(sym, sympy.Symbol) and sym in index.free_symbols: + return True + return False + + def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]: + assert len(self.kernel_sizes) == len(lengths) + return [ + self.make_flat_range(sym, numel, length) + for sym, numel, length in zip(self.symbols, self.kernel_sizes, lengths) + ] + + @staticmethod + def make_flat_range( + sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr] + ) -> list[sympy.Expr]: + if len(lengths) == 1 and numel == lengths[0]: + return [sym] + divisor = sympy.S.One + itervars = [] + for length in reversed(lengths): + if V.graph.sizevars.statically_known_equals(divisor * length, numel): + expr = FloorDiv(sym, divisor) + else: + expr = ModularIndexing(sym, divisor, length) + itervars.append(expr) + divisor = divisor * length + return [*reversed(itervars)] + + +@dataclasses.dataclass +class MemoryEstimate: + """Tracks the memory usage of a single loop in the generated kernel""" + + reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + default_factory=functools.partial(collections.defaultdict, OrderedSet) + ) + writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + default_factory=functools.partial(collections.defaultdict, OrderedSet) + ) + + def remove(self, name: str) -> None: + self.reads.pop(name, None) + self.writes.pop(name, None) + + def __bool__(self) -> bool: + return bool(self.reads or self.writes) + + def __repr__(self) -> str: + return f"""MemoryEstimate( + reads={[*itertools.chain.from_iterable(self.reads.values())]!r}, + writes={[*itertools.chain.from_iterable(self.writes.values())]!r} + )""" + + +@dataclasses.dataclass +class StatsForDim: + """Memory usage stats for a block dimension in the generated kernel (different from user dimensions)""" + + # the number of load/store ops + count_per_thread_contiguous: int = 0 + count_per_thread_broadcast: int = 0 + count_per_thread_non_contiguous: int = 0 # excludes broadcast + + # total bytes in each load/store op for a single element + bytes_per_thread_contiguous: int = 0 + bytes_per_thread_broadcast: int = 0 + bytes_per_thread_non_contiguous: int = 0 # excludes broadcast + + # total bytes read by entire kernel + bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero + bytes_non_contiguous: sympy.Expr = sympy.S.Zero + + def __add__(self, other: typing.Self) -> StatsForDim: + return StatsForDim( + count_per_thread_contiguous=self.count_per_thread_contiguous + + other.count_per_thread_contiguous, + count_per_thread_broadcast=self.count_per_thread_broadcast + + other.count_per_thread_broadcast, + count_per_thread_non_contiguous=self.count_per_thread_non_contiguous + + other.count_per_thread_non_contiguous, + bytes_per_thread_contiguous=self.bytes_per_thread_contiguous + + other.bytes_per_thread_contiguous, + bytes_per_thread_broadcast=self.bytes_per_thread_broadcast + + other.bytes_per_thread_broadcast, + bytes_per_thread_non_contiguous=self.bytes_per_thread_non_contiguous + + other.bytes_per_thread_non_contiguous, + bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast + + other.bytes_contiguous_or_broadcast, + bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous, + ) + + @property + def count_per_thread(self) -> int: + return ( + self.count_per_thread_contiguous + + self.count_per_thread_broadcast + + self.count_per_thread_non_contiguous + ) + + @property + def bytes_per_thread(self) -> int: + return ( + self.bytes_per_thread_contiguous + + self.bytes_per_thread_broadcast + + self.bytes_per_thread_non_contiguous + ) + + @property + def bytes(self) -> sympy.Expr: + return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous + + @property + def contiguous_score(self) -> float: + return 1.0 - self.count_per_thread_non_contiguous / max( + self.count_per_thread, 1 + ) + + +@dataclasses.dataclass +class StatsForLoop: + """Memory usage stats for single loop in the generated kernel""" + + # load/store ops + count_per_thread: int = 0 + bytes_per_thread: int = 0 + + def __add__(self, other: typing.Self) -> StatsForLoop: + return StatsForLoop( + count_per_thread=self.count_per_thread + other.count_per_thread, + bytes_per_thread=self.bytes_per_thread + other.bytes_per_thread, + ) + + +@dataclasses.dataclass +class StatsForReadsOrWrites: + """Memory usage stats that are collected for reads/writes/both""" + + dim: list[StatsForDim] + loop: list[StatsForLoop] + # total bytes contiguous in any dimension + bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero + bytes_non_contiguous: sympy.Expr = sympy.S.Zero + + def __add__(self, other: typing.Self) -> StatsForReadsOrWrites: + assert len(self.dim) == len(other.dim) + assert len(self.loop) == len(other.loop) + return StatsForReadsOrWrites( + dim=[a + b for a, b in zip(self.dim, other.dim)], + loop=[a + b for a, b in zip(self.loop, other.loop)], + bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast + + self.bytes_contiguous_or_broadcast, + bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous, + ) + + @property + def count_per_thread(self) -> int: + return self.dim[0].count_per_thread + + @property + def bytes_per_thread(self) -> int: + return self.dim[0].bytes_per_thread + + @property + def bytes(self) -> sympy.Expr: + return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous + + @classmethod + def compute( + cls, + loop_deps: list[dict[str, OrderedSet[MemoryDep]]], + index_symbols: list[sympy.Symbol], + ) -> typing.Self: + ndim = len(index_symbols) + result = cls(dim := [StatsForDim() for _ in range(ndim)], []) + for dep_group in loop_deps: + result.loop.append(loop_stats := StatsForLoop()) + for name, deps in dep_group.items(): + assert deps + contiguous_or_broadcast = [True] * ndim + numel = sympy.S.Zero + itemsize = V.graph.get_dtype(name).itemsize + loop_stats.count_per_thread += len(deps) + loop_stats.bytes_per_thread += itemsize * len(deps) + for dep in deps: + strides: list[sympy.Expr] = V.graph.sizevars.stride_vars( + dep.index, index_symbols + ) + for i in range(ndim): + if V.graph.sizevars.statically_known_equals(strides[i], 1): + dim[i].count_per_thread_contiguous += 1 + dim[i].bytes_per_thread_contiguous += itemsize + elif ( + V.graph.sizevars.statically_known_equals(strides[i], 0) + and not dep.is_indirect() + ): + dim[i].count_per_thread_broadcast += 1 + dim[i].bytes_per_thread_broadcast += itemsize + else: + dim[i].count_per_thread_non_contiguous += 1 + dim[i].bytes_per_thread_non_contiguous += itemsize + contiguous_or_broadcast[i] = False + numel += dep.get_numel() + if len(deps) > 1: + # can't read more elements than exist in the buffer + numel = sympy.Min(numel, V.graph.get_numel(name)) + nbytes = numel * itemsize + for i in range(ndim): + if contiguous_or_broadcast[i]: + dim[i].bytes_contiguous_or_broadcast += nbytes + else: + dim[i].bytes_non_contiguous += nbytes + if any(contiguous_or_broadcast): + result.bytes_contiguous_or_broadcast += nbytes + else: + result.bytes_non_contiguous += nbytes + if len(result.loop) > 1: + # the first loop represent the "outside of the loop" compute which could be long lived + result.loop = [result.loop[0] + x for x in result.loop[1:]] + return result + + +@dataclasses.dataclass +class StatsForKernelType: + """Memory usage stats that are collected for both persistent and looped kernels""" + + reads: StatsForReadsOrWrites + writes: StatsForReadsOrWrites + memory: StatsForReadsOrWrites + + @classmethod + def compute( + cls, loops: list[MemoryEstimate], estimator: MemoryEstimator + ) -> typing.Self: + reads = StatsForReadsOrWrites.compute( + [loop.reads for loop in loops], estimator.symbols + ) + writes = StatsForReadsOrWrites.compute( + [loop.writes for loop in loops], estimator.symbols + ) + return cls( + reads=reads, + writes=writes, + memory=reads + writes, + ) + + +@dataclasses.dataclass +class MemoryStats: + """Memory usage stats collected for each generated kernel""" + + persistent: StatsForKernelType + looped: StatsForKernelType + + def get(self, persistent: bool) -> StatsForKernelType: + return self.persistent if persistent else self.looped + + @classmethod + def compute(cls, estimator: MemoryEstimator) -> typing.Self: + persistent = StatsForKernelType.compute([estimator.persistent], estimator) + if len(estimator.loops) == 1 and not ( + estimator.outside_loop and estimator.loops[0] + ): + looped = persistent # loops/persistent is the same in this common case + else: + looped = StatsForKernelType.compute( + [estimator.outside_loop, *estimator.loops], estimator + ) + return cls( + persistent=persistent, + looped=looped, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py new file mode 100644 index 0000000000000000000000000000000000000000..7b931fb3bf47e74596e9053f58177a6faa180edd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py @@ -0,0 +1,433 @@ +import itertools +import logging +from collections.abc import Callable +from typing import Any, Union + +import torch +import torch._inductor.config as config +from torch._inductor import ir +from torch._inductor.codegen.common import KernelTemplate +from torch._inductor.ir import ( + Buffer, + FixedLayout, + get_free_symbols, + get_symbolic_inputs, + gm_original_output_strides, + ir_node_to_tensor, + Layout, +) +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.utils import do_bench_using_profiling +from torch._inductor.virtualized import V + + +log = logging.getLogger(__name__) + + +def inline_subgraph_to_ir_nodes( + gm: torch.fx.GraphModule, inputs: list[Any], name: str +) -> Any: + """Inline a subgraph by converting its FX operations to individual IR nodes. + + This converts a subgraph to multiple ComputedBuffer nodes (fusable), + enabling epilogue fusion with subsequent operations. + + Returns: + TensorBox containing the final operation result as individual IR nodes + """ + from torch._inductor.lowering import process_subgraph_nodes + + return process_subgraph_nodes(gm, inputs) + + +class SubgraphChoiceCaller(ir.ChoiceCaller): + """ + Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary + GraphModule. Compiles the Subgraph down to a module for benchmarking. + """ + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + description: str, + make_fx_graph: Callable[..., Any], + ) -> None: + super().__init__(name, input_nodes, layout, description) + + self.example_inputs = [] + with V.fake_mode: + for inp in self.input_nodes: + # Here there will be no unbacked symbols, as SubgraphBuffer does not support them + assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0 + assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0 + + inp.data.freeze_layout() # type: ignore[attr-defined] + self.example_inputs.append(ir_node_to_tensor(inp)) + + self.gm = make_fx_graph(*self.example_inputs) + gm_original_output_strides(self.gm) + + self.sym_inputs = get_symbolic_inputs(self.input_nodes) + + # Cache compiled module to avoid recompiling on every benchmark call + self._compiled_module: Any = None + self._compiled_sym_inputs: list[Any] | None = None + + def __str__(self) -> str: + return f"SubgraphCaller({self.name})" + + def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]: + """ + Compile the subgraph for benchmarking and return (module, sym_inputs). + + TODO: Add precompile() method to enable parallel compilation of all choices + before benchmarking. + """ + import torch._inductor.config as inductor_config + from torch._inductor.graph import GraphLowering + + safe_name = self.name.replace("::", "_").replace(".", "_") + + bm_graph_lowering = GraphLowering( + gm=self.gm, + example_inputs=self.example_inputs, + shape_env=V.graph._shape_env, + cpp_wrapper=V.graph.cpp_wrapper, + aot_mode=V.graph.aot_mode, + extern_node_serializer=V.graph.extern_node_serializer, + is_inference=V.graph.is_inference, + is_backward=V.graph.is_backward, + name=f"benchmark_{safe_name}", + ) + + for sym_inp in self.sym_inputs: + bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp + bm_graph_lowering.graph_input_names.append(sym_inp.name) + + sym_inputs = [ + # pyrefly: ignore [no-matching-overload] + int(V.graph.sizevars.shape_env.size_hint(sym_var)) + for sym_var in self.sym_inputs + ] + + if len(sym_inputs) == 0: + # Sanity check that args are same layout as example inputs + # Only do it if there are no symbolic inputs, otherwise + # the dynamic dim will be realized to the same size as args + for ar, example_inp in zip(args, self.example_inputs): + # Sanity check that args are same layout as example inputs + if isinstance(ar, torch.Tensor): + assert isinstance(example_inp, torch.Tensor) + assert ar.shape == example_inp.shape + assert ar.stride() == example_inp.stride() + + with V.set_graph_handler(bm_graph_lowering): + # Don't bother autotuning on Triton here + with inductor_config.patch( + max_autotune=False, + max_autotune_gemm=False, + max_autotune_gemm_backends="ATEN", + ): + bm_graph_lowering.run(*self.example_inputs) + mod = bm_graph_lowering.compile_to_module() + + return mod, sym_inputs + + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + """ + Regular benchmarking: compile and use benchmarker with warmup/rep. + """ + if self._compiled_module is None: + mod, sym_inputs = self._compile_for_benchmarking(*args) + self._compiled_module = mod + self._compiled_sym_inputs = sym_inputs + else: + mod = self._compiled_module + sym_inputs = self._compiled_sym_inputs + assert sym_inputs is not None # Type narrowing + + bm_func = mod.call + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) + return benchmarker.benchmark( + # Shallow clone args since bm_func may clear args + lambda: bm_func([*sym_inputs, *args]), + device=benchmarker.infer_device(*sym_inputs, *args), + ) + + def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None: + """ + Only run once with cached compiled module. + Called by benchmark_collective_choice which handles warmup + and timing with barrier synchronization across all ranks. + """ + if self._compiled_module is None: + mod, sym_inputs = self._compile_for_benchmarking(*args) + self._compiled_module = mod + self._compiled_sym_inputs = sym_inputs + else: + mod = self._compiled_module + sym_inputs = self._compiled_sym_inputs + assert sym_inputs is not None # Type narrowing + + bm_func = mod.call + bm_func([*sym_inputs, *args]) + + def hash_key(self) -> str: + return "-".join( + [ + self.name.rsplit("_", 1)[0], + *[str(inp.get_size()) for inp in self.input_nodes], + *[str(inp.get_stride()) for inp in self.input_nodes], + str(self.gm.graph), + ] + ) + + def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: + return ir.TensorBox.create( + ir.SubgraphBuffer( + layout=self.layout, + input_nodes=self.input_nodes, + gm=self.gm, + example_inputs=self.example_inputs, + subgraph_name=self.name, + ) + ) + + def info_dict(self) -> dict[str, Any]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "subgraph", + "kernel_name": self.name, + } + + def autoheuristic_id(self) -> str: + return f"subgraph_{self.name}" + + +class SubgraphTemplate(KernelTemplate): + """ + A template for subgraph evaluation to be used in autotuning. + + This class allows creating customized subgraphs that can be appended + as choices during the autotuning process, enabling the selection of + optimal implementations for complex operations. + """ + + index_counter = itertools.count() + + def __init__( + self, + name: str, + ): + """ + Initialize a subgraph template. + + Args: + name: The name of this template + graph: The FX graph + """ + super().__init__(name=name) + + def generate( # type: ignore[override] + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + make_fx_graph: Callable[..., Any], + description: str = "", + **kwargs: Any, + ) -> SubgraphChoiceCaller: + """ + Generate a SubgraphChoiceCaller instance for autotuning. + + Args: + name: The name for this subgraph choice + input_nodes: List of input nodes to the subgraph + layout: Memory layout information for the output + make_fx_graph: Callable that creates the FX graph for this subgraph + description: Optional description of this choice + **kwargs: Additional keyword arguments + + Returns: + SubgraphChoiceCaller: A callable object that can be used for autotuning + """ + + return SubgraphChoiceCaller( + name=f"{name}_{next(SubgraphTemplate.index_counter)}", + input_nodes=input_nodes, + layout=layout, + description=description, + make_fx_graph=make_fx_graph, + ) + + def generate_custom_op_choices( + self, + name: str, + decompositions: list[Callable[..., Any]], + input_nodes: list[Buffer], + non_tensor_args: list[dict[str, Any]], + default_impl: Callable[..., Any] | None = None, + ) -> list[SubgraphChoiceCaller]: + """ + Generate multiple SubgraphChoiceCaller instances for custom op autotuning. + + This method extends SubgraphTemplate to support custom op decompositions, + allowing multiple implementations to compete in autotuning. + + Args: + name: Base name for the choices + decompositions: List of decomposition functions to compete in autotuning + input_nodes: List of tensor inputs. All tensor arguments must be passed here. + non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition. + default_impl: Default implementation for layout inference + + Returns: + List of SubgraphChoiceCaller instances for autotuning + """ + if not decompositions: + return [] + + assert len(decompositions) == len(non_tensor_args), ( + f"decompositions and non_tensor_args must have same length, " + f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs" + ) + + # Infer layouts and ensure layout consistency for fair autotuning comparison + layouts = [ + self._infer_custom_op_layout(input_nodes, decomp, kwargs, default_impl) + for decomp, kwargs in zip(decompositions, non_tensor_args) + ] + + # Validate all decompositions produce equivalent layouts for fair comparison + self._validate_layout_equivalence(name, decompositions, layouts) + layout = layouts[0] # All layouts are now validated to be equivalent + + choices: list[SubgraphChoiceCaller] = [] + for decomp, decomp_kwargs in zip(decompositions, non_tensor_args): + # Create make_fx_graph function for this decomposition + import functools + + def make_fx_graph( + *args: Any, + decomp: Callable[..., Any] = decomp, + decomp_kwargs: dict[str, Any] = decomp_kwargs, + ) -> Any: + # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs + from torch.fx.experimental.proxy_tensor import make_fx + + from ..decomposition import select_decomp_table + + decomposition_table = select_decomp_table() + + return make_fx( + functools.partial(decomp, **decomp_kwargs), + decomposition_table=decomposition_table, + )(*args) + + # Generate descriptive name for this variant + variant_name = self._generate_variant_name(decomp, decomp_kwargs) + + choice = self.generate( + name=f"{name}_{variant_name}", + input_nodes=input_nodes, + layout=layout, + make_fx_graph=make_fx_graph, + description=f"CustomOp {decomp.__name__}", + ) + choices.append(choice) + + return choices + + def _generate_variant_name( + self, decomp: Callable[..., Any], kwargs: dict[str, Any] + ) -> str: + """Generate a descriptive name for a decomposition variant with its parameters.""" + base_name = decomp.__name__ + if not kwargs: + return base_name + param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(kwargs.items())) + return f"{base_name}_{param_suffix}" + + def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None: + """Validate that kwargs contains only non-tensor arguments.""" + for key, value in kwargs.items(): + assert not isinstance(value, (torch.Tensor, Buffer)), ( + f"kwargs['{key}'] contains tensor {type(value)}. " + f"Tensor arguments should be in input_nodes, not kwargs. " + f"Only scalar/non-tensor parameters should be in kwargs." + ) + + def _validate_layout_equivalence( + self, + op_name: str, + decompositions: list[Callable[..., Any]], + layouts: list[Layout], + ) -> None: + """Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning.""" + if not layouts: + return + + reference = layouts[0] + for i, layout in enumerate(layouts[1:], start=1): + if (layout.device, layout.dtype, layout.size, layout.stride) != ( + reference.device, + reference.dtype, + reference.size, + reference.stride, + ): + raise AssertionError( + f"Layout mismatch in custom op '{op_name}': " + f"decomposition '{decompositions[i].__name__}' produces " + f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) " + f"but '{decompositions[0].__name__}' produces " + f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})" + ) + + def _infer_custom_op_layout( + self, + input_nodes: list[Buffer], + function_decomposition: Callable[..., Any], + kwargs: dict[str, Any], + default_impl: Callable[..., Any] | None = None, + ) -> Layout: + """Infer output layout for custom ops using the default implementation when available. + Note that the Subgraph assumes custom ops return exactly one tensor output. + TODO: Add support for multiple output custom ops. + """ + import functools + + from torch._inductor.virtualized import V + + # Assert kwargs contain only non-tensor arguments + self._validate_non_tensor_kwargs(kwargs) + + with V.fake_mode: + example_inputs = [] + for inp in input_nodes: + raw_shape = inp.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + fake_tensor = torch.empty( + concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() + ) + example_inputs.append(fake_tensor) + + fn = functools.partial(function_decomposition, **kwargs) + output = fn(*example_inputs) + + # Assert single output + assert isinstance(output, torch.Tensor), ( + f"Expected single tensor output, got {type(output)}. " + f"Multi-output custom ops not yet supported in autotuning." + ) + + return FixedLayout( + device=output.device, + dtype=output.dtype, + size=output.shape, + stride=output.stride(), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6b671f15db31bfd8e2ddea7e556e9867b93227 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py @@ -0,0 +1,6263 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import textwrap +from abc import abstractmethod +from collections.abc import Callable, Iterable, Sequence +from functools import lru_cache +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union + +import sympy +from sympy.printing.precedence import PRECEDENCE + +import torch +import torch._logging +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity, preserve_rng_state +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._triton import ( + get_triton_version, + has_triton_package, + has_triton_stable_tma_api, +) + +from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. import config, ir, metrics, utils +from ..async_compile import AsyncCompile +from ..codecache import code_hash, get_path, PyCodeCache, write_atomic +from ..debug import set_kernel_post_grad_provenance_tracing +from ..ops_handler import DefaultHandler +from ..runtime import triton_heuristics +from ..runtime.benchmarking import benchmarker +from ..runtime.hints import ( + AutotuneHint, + DeviceProperties, + TRITON_MAX_BLOCK, + TRITON_MAX_RSPLIT, +) +from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2 +from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode +from ..shape_propagation import get_broadcasted_shape +from ..utils import ( + cache_on_self, + DelayMaybeLine, + DelayReplaceLine, + get_bounds_index_expr, + get_fused_kernel_name, + get_kernel_metadata, + is_welford_reduction, + Placeholder, + prefix_is_reduction, + sympy_dot, + sympy_product, + sympy_subs, + triton_type, + triton_version_uses_attrs_dict, + upcast_compute_type, +) +from ..virtualized import _ops as ops, ReductionType, StoreMode, V +from ..wrapper_benchmark import get_kernel_category_by_source_code +from .block_analysis import BlockPatternMatcher +from .common import ( + ArgName, + BackendFeature, + ConstexprArg, + CSE, + CSEVariable, + DeferredLine, + IndentedBuffer, + InplacedBuffer, + is_buffer_removed, + OpOverrides, + PythonPrinter, + RemovedArg, + SizeArg, + TensorArg, + WorkspaceArg, + WorkspaceZeroMode, +) +from .simd import ( + constant_repr, + IterationRanges, + IterationRangesEntry, + IterationRangesRoot, + PartialAccumulate, + SIMDKernel, + SIMDScheduling, +) +from .triton_utils import ( + config_of, + equal_1_arg_indices, + non_constexpr_signature, + should_unwrap_unspec_arg, + signature_to_meta, +) +from .wrapper import SymbolicCallArg + + +if TYPE_CHECKING: + from types import ModuleType + + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from ..ir import IRNode + from .common import BlockShapeType + from .simd_kernel_features import SIMDKernelFeatures + + _T = TypeVar("_T") + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +async_compile = AsyncCompile() + + +def get_triton_reduction_function(reduction_type): + use_helper = reduction_type in ("any", "max", "min", "prod") + module = "triton_helpers" if use_helper else "tl" + if reduction_type in ("max", "min"): + return f"{module}.{reduction_type}2" + else: + return f"{module}.{reduction_type}" + + +def is_sympy_integer_like(expr: object): + """ " + Is this expression a Sympy Integer or is it an integer sympy Expr + containing no free symbols. The latter case can happen with Identity expr. + """ + if not isinstance(expr, sympy.Expr): + return False + return isinstance(expr, sympy.Integer) or ( + expr.is_integer and len(expr.free_symbols) == 0 + ) + + +class OpDtypeSupport: + """ + Some Triton ops such as libdevice and tl.math only support float32 and float64. + This class records which dtypes are supported by specific IR ops. + """ + + supported_dtypes: dict[str, OrderedSet[torch.dtype]] = {} + convert_outputs: dict[str, bool] = {} + + @classmethod + def register_upcast(cls, func: Callable[..., str], convert_output: bool) -> None: + op_name = func.__name__ + cls.supported_dtypes[op_name] = OrderedSet([torch.float32, torch.float64]) + cls.convert_outputs[op_name] = convert_output + + +@lru_cache(None) +def gen_attr_descriptor_import() -> str: + """ + import AttrsDescriptor if the triton version is new enough to have this + class defined. + """ + if not has_triton_package(): + return "" + + import triton.compiler.compiler + + # Note: this works because triton.compiler.compiler imports AttrsDescriptor from triton.backends.compiler + # When support for the legacy AttrsDescriptor is removed then this import path should be changed. + if hasattr(triton.compiler.compiler, "AttrsDescriptor"): + return "from triton.compiler.compiler import AttrsDescriptor" + else: + return "" + + +@lru_cache(None) +def gen_common_triton_imports() -> str: + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor.runtime import triton_helpers, triton_heuristics + from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + """ + ) + return imports.getvalue() + + +class TritonSymbols: + """ + Stores sympy.Symbol instances and constants associated with triton codegen. + """ + + reduction_types = OrderedSet([SymT.R0_INDEX, SymT.R1_INDEX]) + block_types = OrderedSet([SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, *reduction_types]) + + block_offsets = { + symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True) + for symt in block_types + } + + block_sizes = { + symt: sympy.Symbol( + f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True + ) + for symt in block_types + } + + @classmethod + def get_block_shape(cls, expr: sympy.Expr) -> BlockShapeType: + # return block shape of sympy Expression + # e.g., + # tmp13 = y1 + # tmp14 = x0 - tmp13 + # + # get_block_shape(y1) = (YBLOCK,1,1) + # get_block_shape(x0-tmp13) = (YBLOCK,XBLOCK,1) + + expr_shape: BlockShapeType = () + expr_vars = expr.free_symbols + for var in expr_vars: + if symbol_is_type(var, SymT.TMP): + cse_var = V.kernel.cse.varname_map[var.name] + var_shape = cse_var.shape + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + var_shape = () + else: + symbol_matches = [ + symt for symt in cls.block_types if symbol_is_type(var, symt) + ] + assert len(symbol_matches) == 1, f"Ambiguous type: {var.name}" + + sym = symbol_matches[0] + ndim = V.kernel.triton_tensor_ndim() + shape = ["1"] * ndim + + tree_match = [ + tree + for tree in V.kernel.active_range_trees() + if prefix_str[sym] == tree.prefix + ] + assert len(tree_match) == 1, "# of Match expected to 1" + + shape[tree_match[0].tensor_dim] = str(cls.get_block_size(tree_match[0])) + var_shape = tuple(shape) + + # Union current variable shape + expr_shape = get_broadcasted_shape(expr_shape, var_shape) + + assert expr_shape is not None + + return expr_shape + + @classmethod + def get_block_size(cls, tree: IterationRanges) -> sympy.Symbol: + return cls.block_sizes[tree.symt] + + @classmethod + def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol: + return cls.block_offsets[tree.symt] + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: OrderedSet[str] + expand_str: Optional[str] + _has_rindex: bool + index: sympy.Expr + expand_shape: Optional[Sequence[Union[int, str]]] + + def has_mask(self) -> bool: + return bool(self.mask_vars) + + def has_indirect(self) -> bool: + return free_symbol_is_type(self.index, SymT.TMP) + + def has_rindex(self) -> bool: + return self._has_rindex + + def has_tmpmask(self) -> bool: + return any(str(mask).startswith("tmp") for mask in self.mask_vars) + + def has_rmask(self) -> bool: + return any(str(mask).startswith("r") for mask in self.mask_vars) + + @property + def mask_str(self) -> str: + # The sorted call is added to make sure the order is still + # deterministic if self.mask_vars contains mix of string + # and TritonCSEVariable + return ( + " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None" + ) + + +@dataclasses.dataclass +class BlockDescriptorOptions: + """ + This is a base class that describes a block descriptor used in Triton kernels. + It can be used to create either a tensor descriptor (with TensorDescriptorOptions) + or a block pointer (with BlockPtrOptions). + """ + + params: BlockParameters + constant_offset: sympy.Expr + order: list[int] + mask_vars: OrderedSet[str] + broadcast_shape: Sequence[sympy.Expr] + broadcasting_dims: list[bool] + final_shape: Sequence[sympy.Expr] + # If the BlockParameters have been sorted using a particular stride order + # transpose load / store blocks at runtime using the information in + # stride_sorter. + stride_sorter: BlockParameters.StrideSorter + _boundary_check: Optional[list[int]] = None + # Can we safely lift the constructor + # to the top of the kernel? + can_lift: bool = False + + @property + def shape(self) -> list[sympy.Expr]: + return self.params.shape + + @property + def block_shape(self) -> list[sympy.Expr]: + return self.params.block_shape + + @property + def strides(self) -> list[sympy.Expr]: + return self.params.strides + + @property + def offsets(self) -> list[sympy.Expr]: + return self.params.offsets + + @classmethod + def create( + cls, + *, + params: BlockParameters, + constant_offset: sympy.Expr, + range_trees: list[IterationRangesRoot], + mask_vars: OrderedSet[str], + get_max_block: Callable[[str], int], + stride_sorter_cls: type[BlockParameters.StrideSorter], + can_lift: bool = False, + ) -> BlockDescriptorOptions: + """Helper to create a BlockDescriptorOptions instance""" + + sizevars = V.graph.sizevars + + def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: + return [sizevars.lookup_precomputed_size(expr) for expr in exprs] + + # Look up precomputed sizes + params.shape = lookup_size(params.shape) + params.strides = lookup_size(params.strides) + + # Strip out dimensions of size 1. + # Size 1 dimensions are redundant since the triton kernel shape + # will be e.g. [YBLOCK, XBLOCK], so tl.reshape would just remove these + # dimensions anyway + singleton_dims = [ + sizevars.statically_known_equals(dim, 1) for dim in params.block_shape + ] + if all(singleton_dims): + # Handle a pure singletons, e.g. [1, 1] + singleton_dims[-1] = False + + # Drop singleton dimensions from the block descriptor. + params = params.remove_dims(singleton_dims) + + # Maybe reorder dimensions based on strides + # with tl.trans applied at load / store time + params, stride_sorter = params.maybe_sort_with_stride_order( + stride_sorter_cls=stride_sorter_cls, shape_env=V.graph._shape_env + ) + + # Strip out dimensions of stride 0. + # These will be restored with tl.broadcast_to. + broadcasting_dims = [ + sizevars.statically_known_equals(stride, 0) for stride in params.strides + ] + + # Record the post-broadcast shape before broadcasting dims are removed. + # The pre-broadcast shape is identical to this, except broadcasting dims are + # replaced with 1. + broadcast_shape = params.block_shape + + # Drop broadcasting dims from the block descriptor. + params = params.remove_dims(broadcasting_dims) + + # Compute the final shape, adjusting for special kernel types. + final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees] + if V.kernel.no_x_dim: + assert range_trees[0].prefix == "x" + final_shape.pop(0) + + reduction_ndim = V.kernel.num_reduction_dims + if ( + not V.kernel.inside_reduction + and len(params.strides) == len(V.kernel.numels) - reduction_ndim + and V.kernel.features.is_reduction() + ): + # Need to expand rank to match the rank used inside the reduction loop + final_shape += [sympy.S.One] * reduction_ndim + + try: + # Get permutation to sort strides in ascending order. + # This is used as the order argument in tl.make_block_ptr + order = utils.argsort_sym(V.graph._shape_env, params.strides) + except AssertionError: + # Symbolic shapes, failed to evaluate comparison expression + order = list(reversed(range(len(params.strides)))) + + result = cls( + params=params, + constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), + order=order, + mask_vars=mask_vars, + final_shape=final_shape, + broadcast_shape=broadcast_shape, + broadcasting_dims=broadcasting_dims, + stride_sorter=stride_sorter, + can_lift=can_lift, + ) + result.compute_boundary_check(get_max_block, range_trees) + return result + + def replace_offset( + self, expr: sympy.Expr, replacement: sympy.Expr, symt: SymT + ) -> sympy.Expr: + """ + Replaces instances of {symt}_offset with the new expression. + """ + roffset = TritonSymbols.block_offsets[symt] + return sympy_subs(expr, {roffset: replacement}) + + def remove_roffsets(self, expr: sympy.Expr) -> sympy.Expr: + for symt in TritonSymbols.reduction_types: + expr = self.replace_offset(expr, sympy.Integer(0), symt) + return expr + + def compute_boundary_check( + self, + get_max_block: Callable[[str], int], + range_trees: list[IterationRangesRoot], + ) -> None: + """List of indices to pass to tl.load(boundary_check=...)""" + sizevars = V.graph.sizevars + + # Substitute maximum block sizes in shape expressions. + # This works in multiple_of checks because block sizes are powers of 2. + block_to_max: dict[sympy.Expr, Any] = { + TritonSymbols.block_sizes[t.symt]: get_max_block(prefix_str[t.symt]) + for t in range_trees + } + + # Also see Note: Constant mask optimisation + # if ynumel / YBLOCK > max_ygrid, then the z dimension is used to handle + # the remaining programs that cannot fit into the y dimension. This means + # it's possible that more than the required number of programs are launched, + # possibly leading to out-of-bounds accesses. So even if ynumel divides YBLOCK, + # boundary checking is required in the dimensions that are based on YBLOCK + # e.g. for [YBLOCK // 16, YBLOCK, XBLOCK] dimensions 0 and 1 need boundary + # checks when max_ygrid is exceeded. + needs_overflow_grid = any(map(V.kernel.needs_yz_grid_overflow, range_trees)) + self._boundary_check = [ + idx + for idx in range(len(self.shape)) + if ( + not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero) + and ( + ( + needs_overflow_grid + and TritonSymbols.block_sizes[SymT.YBLOCK] + in self.block_shape[idx].free_symbols + ) + or ( + not sizevars.statically_known_multiple_of( + self.shape[idx], self.block_shape[idx] + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], + sympy_subs(self.block_shape[idx], block_to_max), + ) + ) + ) + and not ( + V.kernel.no_x_dim + and self.block_shape[idx] == TritonSymbols.block_sizes[SymT.XBLOCK] + ) + ) + ] + + def boundary_check(self) -> list[int]: + assert self._boundary_check is not None + return self._boundary_check + + def has_indirect(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_rindex(self) -> bool: + return any( + free_symbol_is_type(expr, TritonSymbols.reduction_types) + for expr in self.block_shape + ) + + def has_rmask(self) -> bool: + return self.has_rindex() + + def has_tmpmask(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_mask(self) -> bool: + return bool(self.boundary_check()) + + def codegen_broadcast_and_reshape( + self, + value: str, + initial_shape: Sequence[sympy.Expr], + final_shape: Sequence[sympy.Expr], + allow_implicit: bool, + for_store: bool, + ) -> str: + """ + Generate a broadcast and a reshape for the block descriptor. + This restores stride-0 dimensions which were removed from the block descriptor. + + Transposes are also applied to the input using self.stride_sorter: + if for_store is True: + - First Broadcast the value. Since self.broadcast_shape is stored in + descending stride order, it must be reverted to the original order + since the input value does not have dims with descending strides + - After, transpose the broadcasted value so that dimensions are in + descending stride order + - Finally reshape to the block shape + else (for load): + - First broadcast the value to self.broadcast_shape (strides are descending) + - Then transpose the value so that dimensions no longer have descending strides + - Finally reshape the block to the final kernel tile shape + """ + broadcast_shape = self.broadcast_shape + broadcasting_dims = self.broadcasting_dims + + # If the block parameters have been sorted by descending strides, + # permute the broadcasting parameters so that they are compatible + # with the value being stored. This is because the dimensions + # of the value being stored are not sorted in descending stride order, + # but the broadcasting parameters are based on the dims in sorted order + if for_store: + broadcast_shape = self.stride_sorter.revert(self.broadcast_shape) + broadcasting_dims = self.stride_sorter.revert(self.broadcasting_dims) + + # Reshape to add singletons. + pre_broadcast_shape = [ + sympy.S.One if is_broadcasting else dim + for dim, is_broadcasting in zip(broadcast_shape, broadcasting_dims) + ] + value = triton_reshape(value, initial_shape, pre_broadcast_shape) + + if ( + not self.stride_sorter.is_identity + and not for_store + and len(pre_broadcast_shape) == len(final_shape) + ): + # If all we need to do is transpose to match the final shape + # with implicit broadcasting then we don't need an explicit broadcast + # unless the caller requests it. So just test implicit broadcast support + # with the transposed pre broadcast shape + pre_broadcast_shape = self.stride_sorter.revert(pre_broadcast_shape) + + # Broadcast singletons. + # For loads, we can often implicitly broadcast singleton dimensions. + # We need an explicit broadcast for stores, or if the final reshape does more + # than add singletons. + sizevars = V.graph.sizevars + supports_implicit_broadcast = allow_implicit and ( + len(pre_broadcast_shape) == len(final_shape) + and all( + sizevars.statically_known_equals(pre_dim, 1) + or sizevars.statically_known_equals(pre_dim, post_dim) + for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape) + ) + ) + + if any(self.broadcasting_dims) and not supports_implicit_broadcast: + value = ( + f"tl.broadcast_to({value}, {V.kernel.index_to_str(broadcast_shape)})" + ) + + old_shape = self.broadcast_shape + if not self.stride_sorter.is_identity: + # if for_store the transform is + # (non-descending strides) broadcasted kernel tile shape + # -> (descending strides) block descriptor shape + # o/w if loading the transform is + # (descending strides) ((maybe implicitly) broadcasted block shape + # -> (non-descending) (maybe implicitly) broadcasted kernel tile shape + permute_dims = ( + self.stride_sorter.sort_idx + if for_store + else self.stride_sorter.revert_sort_idx + ) + value = f"tl.trans({value}, {permute_dims})" + old_shape = ( + self.broadcast_shape + if for_store + else self.stride_sorter.revert(self.broadcast_shape) + ) + + # Reshape to the final shape. + value = triton_reshape(value, old_shape, final_shape) + + return value + + +@dataclasses.dataclass +class TensorDescriptorOptions(BlockDescriptorOptions): + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_tensor_descriptor() + + Args: + name: variable name for pointer + roffset: unused, but kept for compatibility with BlockPtrOptions.format() + + Returns: + "tl.make_tensor_descriptor(...)" + """ + + f = V.kernel.index_to_str + args = [ + ( + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name + ), + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + ] + + return f"tl.make_tensor_descriptor({', '.join(args)})" + + +@dataclasses.dataclass +class BlockPtrOptions(BlockDescriptorOptions): + def replace_offset( + self, expr: sympy.Expr, replacement: sympy.Expr, symt: SymT + ) -> sympy.Expr: + """ + Replaces instances of {symt}_offset with the new expression. + """ + roffset = TritonSymbols.block_offsets[symt] + return sympy_subs(expr, {roffset: replacement}) + + def remove_roffsets(self, expr: sympy.Expr) -> sympy.Expr: + for symt in TritonSymbols.reduction_types: + expr = self.replace_offset(expr, sympy.Integer(0), symt) + return expr + + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_block_ptr() + + Args: + name: variable name for pointer + roffset: should rn_offset be included in offsets=..., for use with tl.advance() + + Returns: + "tl.make_block_ptr(...)" + """ + f = V.kernel.index_to_str + offsets = [*self.offsets] + if not roffset: + offsets = [self.remove_roffsets(offset) for offset in offsets] + args = [ + ( + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name + ), + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + f"order={f(self.order)}", + f"offsets={f(offsets)}", + ] + return f"tl.make_block_ptr({', '.join(args)})" + + def advance_roffset(self, symt: SymT) -> sympy.Expr: + """ + Codegen string to pass to tl.advance(name, ...). + + Advance is the difference between offsets in each loop iteration. + To compute it, we replace rN_offset with multiples of RN_BLOCK. + Since we expect rN_offset to vary in range(0, rN_numel, RN_BLOCK), the first + iteration has rN_offset=0, while the second has rN_offset=RN_BLOCK. + """ + rblock = TritonSymbols.block_sizes[symt] + advance = [ + ( + self.replace_offset(offset, rblock, symt) + - self.replace_offset(offset, sympy.S.Zero, symt) + ) + for offset in self.offsets + ] + return advance + + +def triton_reshape( + value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr] +) -> str: + """Workaround https://github.com/triton-lang/triton/issues/2836""" + assert isinstance(old_shape, list) and isinstance(new_shape, list) + + old_shape_str = [V.kernel.index_to_str(shape) for shape in old_shape] + new_shape_str = [V.kernel.index_to_str(shape) for shape in new_shape] + + if old_shape_str == new_shape_str: + return value + if [s for s in new_shape_str if s != "1"] != old_shape_str: + return f"tl.reshape({value}, [{', '.join(new_shape_str)}])" + # rewrite to [:, None] syntax, which is less buggy + idx = 0 + expand = [] + for size in new_shape_str: + if idx < len(old_shape_str) and size == old_shape_str[idx]: + expand.append(":") + idx += 1 + else: + assert size == "1" + expand.append("None") + assert idx == len(old_shape_str) + return f"{value}[{', '.join(expand)}]" + + +def enable_pdl_codegen(): + if not torch._inductor.config.triton.enable_pdl: + return False + major, _ = torch.cuda.get_device_capability(torch.cuda.current_device()) + return major >= 9 + + +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). We +# must override all of these, or it is potential silent correctness problem +class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_Float(self, expr: sympy.Expr) -> str: + if expr.is_integer: + # sympy considers 0.0 to be integer, but triton doesn't. + # this workaround prints the float as an integer + # xref: https://github.com/sympy/sympy/issues/26620 + ret = str(int(expr)) + elif config.is_fbcode() and torch.version.hip: + ret = f"{expr}" + else: + ret = f"tl.full([], {expr}, tl.float64)" + return ret + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5) + return f"{s}.to(tl.float64)" + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + quot_s = self._print(quot) + div_s = self._print(div) + return f"triton_helpers.remainder_integer({quot_s}, {div_s})" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + assert expr.is_integer + quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5) + quot_s = self._print(quot) + div_s = self._print(div) + return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _helper_sqrt(self, expr: sympy.Expr) -> str: + # work around for https://github.com/pytorch/pytorch/issues/165738 + if torch.xpu.is_available(): + return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))" + return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))" + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + if expr.args[0].is_Integer: + return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})" + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + def _print_Where(self, expr: sympy.Expr) -> str: + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"tl.where({c}, {p}, {q})" + + def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: + """ + Helper for max/min code generation. + cmp: > or < + """ + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + cls = type(expr) + a = self._print(cls(*expr.args[:mid])) + b = self._print(cls(*expr.args[mid:])) + + # Use a macro so we can propagate constexprs. + # https://github.com/triton-lang/triton/issues/3815 + a, b = tuple(f"({x})" for x in (a, b)) + assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'" + return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))" + + def _print_Min(self, expr: sympy.Expr) -> str: + return self._print_min_max_helper(expr, "<") + + def _print_Max(self, expr: sympy.Expr) -> str: + return self._print_min_max_helper(expr, ">") + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"tl_math.abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}" + + +texpr = TritonPrinter().doprint + + +def triton_compute_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type and upcast [b]float16 to float32""" + return triton_type(upcast_compute_type(dtype)) + + +def triton_store_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with fix for storing tl.bool""" + if dtype == torch.bool: + dtype = torch.int8 + return triton_type(dtype) + + +def upcast_acc_dtype(dtype: torch.dtype) -> torch.dtype: + """Implicit upcasts used for Triton reduction types""" + if is_integer_dtype(dtype) and dtype.is_signed and dtype.itemsize <= 4: + return torch.int32 + return upcast_compute_type(dtype) + + +def triton_acc_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with reduction upcasts""" + return triton_compute_type(upcast_acc_dtype(dtype)) + + +def low_precision_fp(dtype: torch.dtype) -> bool: + return dtype.itemsize <= 2 and dtype.is_floating_point + + +def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool: + if not isinstance(var, CSEVariable): + return False + + dtype = var.dtype + return low_precision_fp(dtype) if isinstance(dtype, torch.dtype) else False + + +class TritonCSEVariable(CSEVariable): + def __init__( + self, + name: str, + bounds: ValueRanges[Any], + dtype: torch.dtype, + shape: BlockShapeType = None, + ) -> None: + super().__init__(name, bounds, dtype, shape=shape) + # We'll use this to track which masks the variable needs when used for indirect indexing + self.mask_vars: OrderedSet[str] = OrderedSet() + assert dtype is not None, "TritonCSEVariable must have dtype" + assert shape is not None, "TritonCSEVariable must have shape" + + def update_on_args(self, name, args, kwargs): + for arg in args: + if isinstance(arg, TritonCSEVariable): + self.mask_vars.update(arg.mask_vars) + elif isinstance(arg, sympy.Symbol): + # most of the time index vars don't need masks associated with them + # however, when index vars are used to compute indices for indirect reads + # those reads should subsequently be masked, + for symt in TritonSymbols.block_types: + if symbol_is_type(arg, symt): + self.mask_vars.update([f"{prefix_str[symt]}mask"]) + break + + +def get_dtype_handler() -> DtypePropagationOpsHandler: + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + return DtypePropagationOpsHandler() + + +def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]: + """ + Codegen helper to upcast arguments to float32, depending on the config and dtype. + This decorates tl.math/libdevice codegen functions. + """ + + def needs_upcast(var) -> bool: + return ( + not config.triton.codegen_upcast_to_fp32 + and isinstance(var, CSEVariable) + and var.dtype in (torch.float16, torch.bfloat16) + ) + + def maybe_upcast_arg(var) -> str: + upcast_string = ".to(tl.float32)" if needs_upcast(var) else "" + return f"{var}{upcast_string}" + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + # Record that this function only supports float32 and float64. + OpDtypeSupport.register_upcast(func, convert_output) + + def wrapped(*args, **kwargs) -> str: + # Optionally upcast args to float32. + upcast_args = [maybe_upcast_arg(arg) for arg in args] + upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()} + + # Call the decorated function, optionally downcasting the result. + result = func(*upcast_args, **upcast_kwargs) + any_needs_upcast = convert_output and any( + needs_upcast(var) for var in itertools.chain(args, kwargs.values()) + ) + result_dtype = ( + None + if not any_needs_upcast + else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs) + ) + needs_downcast = result_dtype not in (torch.float32, None) + downcast_string = ( + f".to({triton_type(result_dtype)})" + if needs_downcast and result_dtype is not None + else "" + ) + return f"{result}{downcast_string}" + + return wrapped + + return decorator # type: ignore[return-value] + + +class TritonOverrides(OpOverrides): + """Map element-wise ops to Triton e.g., ops.to_dtype(x,...) -> x.to(...)""" + + _LOG_2_E = math.log2(math.e) + + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + def _get_min_elements_per_thread( + src_dtype: torch.dtype, dst_dtype: torch.dtype + ) -> int: + if src_dtype == dst_dtype: + # No data type conversion is needed. No requirements on min_elem_per_thread. + return 0 + + # fp8 data type conversions has min_elem_per_thread requirements. + # Refer to Triton implementations here: + # https://github.com/triton-lang/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. + fp8_dtypes = ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2. + assert not ( + src_dtype in fp8_dtypes + and dst_dtype in fp8_dtypes + and src_dtype != dst_dtype + ), "Conversions between float8_e5m2 and float8_e4m3fn is not supported!" + if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2: + return 4 + if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn: + return 2 + # No requirements on min_elem_per_thread. + return 0 + + if src_dtype is not None: + # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype). + # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions + # in the same kernel. + V.kernel.min_elem_per_thread = max( + _get_min_elements_per_thread(src_dtype, dtype), + V.kernel.min_elem_per_thread, + ) + + if dtype == torch.bool: + return f"({x} != 0)" + elif dtype == torch.uint8 and ( + src_dtype is not None and src_dtype.is_floating_point or src_dtype is None + ): + # to work around llvm uint conversion semantics that produces 0's for negative + # values when converting from floating types. + # optimization - if source type is known and it's not a floating type, then + # do not apply conversion to the intermediate type. + return f"{x}.to(tl.int16).to(tl.uint8)" + + if use_compute_types: + out_dtype = triton_compute_type(dtype) + else: + out_dtype = triton_store_type(dtype) + + return f"{x}.to({out_dtype})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + assert src_dtype.itemsize == dtype.itemsize + # We may promote float16 or bfloat16 to float32 and cause the + # bitwidth of dtype to be different from the input tensor (i.e. float32). + # In such as case, we will have to convert the input tensor to + # its src_type, perform bitcast, and then convert the bit-casted + # tensor back to float to ensure we use values with the right precision. + if x.dtype != src_dtype: + x = f"{x}.to({triton_type(src_dtype)})" + + out = f"{x}.to({triton_type(dtype)}, bitcast=True)" + if upcast_compute_type(dtype) != dtype: + out = f"{out}.to({triton_type(upcast_compute_type(dtype))})" + + return out + + @staticmethod + def _shaped_constant(value, dtype, shape): + type_ = torch._prims_common.dtype_to_type(dtype) + triton_val = constant_repr(type_(value)) + triton_type = triton_compute_type(dtype) + + if triton_type == "tl.float32": + # Float constants are always f32 in triton + return triton_val + + # NOTE: We use a tensor here in order to get the expected type. + # Otherwise, e.g. float64 constants would be truncated to float32. + if value < 0 and not dtype.is_signed: + triton_signed_type = f"tl.{triton_type[4:]}" + return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})" + else: + return f"tl.full({shape}, {triton_val}, {triton_type})" + + @classmethod + def constant(cls, value, dtype): + return cls._shaped_constant(value, dtype, shape=[]) + + @staticmethod + @maybe_upcast_float32() + def abs(x): + return f"tl_math.abs({x})" + + # TODO - register these ops as having divergent dtype + # output if doing graph pass to remove consecutive casts + + @staticmethod + def truediv(x, y): + x_dtype = getattr(x, "dtype", None) + y_dtype = getattr(y, "dtype", None) + + if ( + x_dtype == torch.float32 + and y_dtype == torch.float32 + and config.emulate_divison_rounding + ): + # x / y in Triton is lowered to div.full which is approx + # we want div_rn to adhere with eager + out = f"triton.language.div_rn({x}, {y})" + else: + out = f"({x} / {y})" + + # Workaround here since the functionality of div_rn has not ready on XPU. + # TODO: remove this workaround after https://github.com/intel/intel-xpu-backend-for-triton/issues/5306 + # resolved. + if torch.xpu.is_available(): + out = f"({x} / {y})" + + if low_precision_fp_var(x) or low_precision_fp_var(y): + out_dtype = get_dtype_handler().truediv(x, y) + if out_dtype in (torch.float16, torch.float32): + out = f"{out}.to({triton_type(out_dtype)})" + + return out + + @staticmethod + def mod(x, y): + out = f"({x} % {y})" + if low_precision_fp_var(x) or low_precision_fp_var(y): + out_dtype = get_dtype_handler().mod(x, y) + if out_dtype in (torch.float16, torch.float32): + out = f"{out}.to({triton_type(out_dtype)})" + return out + + @staticmethod + @maybe_upcast_float32() + def exp(x): + """ + When use_fast_math, use the ftz (flushing to zero) variant + of exponent computation. + + Check https://github.com/triton-lang/triton/issues/5735 for + more details. + """ + if config.use_fast_math: + return f"tl_math.exp({x})" + else: + return f"libdevice.exp({x})" + + @staticmethod + @maybe_upcast_float32() + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + @maybe_upcast_float32() + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + @maybe_upcast_float32() + def sqrt(x): + # work around for https://github.com/pytorch/pytorch/issues/165738 + if torch.xpu.is_available(): + return f"libdevice.sqrt({x})" + return f"tl.sqrt_rn({x})" + + @staticmethod + def relu(x): + bug = config.triton.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + # NB: this only triggers runtime error as long as input + # is not all zero + return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})' + elif bug == "accuracy": + return f"{x} + 1" + elif bug is None: + return ops.maximum(ops.constant(0, torch.int32), x) + else: + raise AssertionError( + f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"triton_helpers.minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"triton_helpers.maximum({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"tl.where({a}, {b}, {c})" + + @staticmethod + def dot(a, b): + """ + Triton code generation for lowering ops.dot to tl.dot. + + The logic is as follows: + + 1. Downcasting for performance + If the data was previously upcasted to fp32, we downcast back to the + original dtype (e.g., fp16 or bf16) for better performance. While + surrounding operations may run in fp32, matmul itself is executed at the + original precision to optimize throughput. + + 2. Handling non-constant reduction masks + If the reduction mask is not constant and there was any operation between + tl.load and tl.dot, we zero out regions outside the mask using + tl.where(r0_mask, val, 0). + This ensures that values outside the mask do not contribute to the dot + product, preventing incorrect results. + + 3. Shape alignment for tl.dot + We massage shapes to match the tl.dot requirement of (Y, R) x (R, X). + Current codegen eagerly broadcasts tl.arange to create unique axes. We + reshape, transpose, or broadcast to align with the (Y, R) x (R, X) shape. + We avoid using 3D dot ((Z, Y, R) x (Z, R, X)) because 3D tl.dot has + poor performance. During batched matmul (bmm), we keep ZBLOCK=1 and call + the 2D dot kernel instead. + """ + assert V.kernel.is_native_matmul + orig_a, orig_b = a, b + + def is_where_needed(var): + # Skip if the variable doesn't have a reduction mask + if not any(map(prefix_is_reduction, var.mask_vars)): + return False + + reduction_range = V.kernel.range_trees[-1] + assert reduction_range.is_reduction + + # Skip if reduction mask was already constant + if V.kernel._has_constant_mask(reduction_range): + return False + + # Skip if the variable is already zeroed outside the mask + # (e.g., from tl.load(..., other=0.0)) + # TODO : track the value of outside of mask region with cse + for k, v in V.kernel.cse._cache.items(): + if v == var and "tl.load" in k and "other=0.0" in k: + return False + + return True + + def where_cond(var): + default = ir.Reduction.default_value("dot", var.dtype) + reduction_mask = [ + f"{tree.prefix}mask" + for tree in V.kernel.range_trees + if tree.is_reduction + ] + + assert len(reduction_mask) == 1, "don't tile reduction when native matmul" + + where_var = TritonKernelOverrides.where(reduction_mask[0], var, default) + return V.kernel.cse.generate( + V.kernel.compute, where_var, dtype=var.dtype, shape=var.shape + ) + + # When computing expressions like ((A+1) @ (B+2)), + # native codegen will do + # + # a = tl.load(..., r0_mask, other=0.0) + # b = tl.load(..., r0_mask, other=0.0) + # tmp0 = a+1 + # tmp1 = b+2 + # tmp2 = tl.dot(tmp0, tmp1) + # + # This produces incorrect results because outside of r0_mask is not zero. + # So before calling tl.dot, apply tl.where to zero out values properly. + # TODO: Optimize - We don't need both operands to be zeroed except NaN * 0 + if is_where_needed(orig_a): + a = where_cond(a) + if is_where_needed(orig_b): + b = where_cond(b) + + def reshape_transpose_broadcast_for_dot( + value, + initial_shape: Sequence[sympy.Expr], + final_shape: Sequence[sympy.Expr], + ) -> str: + """ + Generate a reshape, transpose, and broadcast for the tl.dot. + tl.dot requires specific shape requirement : (Y,R) x (R,X) + but the current triton codegen eagerly broadcast the tl.arange so + it needs to be reshaped to meet the requirement. + + This is done by three steps. + 1. remove the empty dimension (dim with size 1) and make it 2d with tl.reshape + 2. permute the dimension if needed (e.g., (X,R) -> (R,X)) with tl.trans + 3. broadcast if needed with broadcast_to. + - This shows up when matmul operand is broadcasted with torch.expand/repeat. + - e.g., torch.rand((16,)).expand(16,16) @ B + + e.g., (Y,1,R), (Y,R) -> tl.reshape(var, (Y,R)) + e.g., (1,X,R), (R,X) -> tl.trans(tl.reshape(var, (X,R))) + e.g., (1,X,1), (R,X) -> tl.broadcast_to(tl.trans(tl.reshape(var, (X,1))), (R,X)) + + TODO : eventually we want to remove this function when lazy broadcasting arrives + """ + + # Triton 3d dot is slower than 2d dot, so we want to keep block shape in 2d + # by fixing ZBLOCK=1 in the autotune config + if ZBLOCK in initial_shape: + initial_shape = ["1" if dim == ZBLOCK else dim for dim in initial_shape] + + if final_shape == [YBLOCK, RBLOCK]: + assert XBLOCK not in initial_shape, ( + "left tl.dot operand cannot depend on x" + ) + + shape_2d = ["1", "1"] + if YBLOCK in initial_shape: + shape_2d[0] = YBLOCK + if RBLOCK in initial_shape: + shape_2d[1] = RBLOCK + + # reshape it into 2d + value = triton_reshape(value, initial_shape, shape_2d) + + # broadcast if needed + broadcast_needed = shape_2d != [YBLOCK, RBLOCK] + if broadcast_needed: + value = f"tl.broadcast_to({value}, ({YBLOCK}, {RBLOCK}))" + + elif final_shape == [RBLOCK, XBLOCK]: + assert YBLOCK not in initial_shape, ( + "right tl.dot operand cannot depend on y" + ) + + shape_2d = ["1", "1"] + if XBLOCK in initial_shape: + shape_2d[0] = XBLOCK + if RBLOCK in initial_shape: + shape_2d[1] = RBLOCK + + # reshape it into 2d (X,R) + value = triton_reshape(value, initial_shape, shape_2d) + + # transpose to (R,X) + value = f"tl.trans({value})" + + # broadcast if needed + broadcast_needed = shape_2d != [XBLOCK, RBLOCK] + if broadcast_needed: + value = f"tl.broadcast_to({value}, ({RBLOCK}, {XBLOCK}))" + else: + raise NotImplementedError + + return value + + assert len(V.kernel.dense_size_list()) >= 3, "tl.dot can only do mm and bmm" + + XBLOCK = str(TritonSymbols.block_sizes[SymT.XBLOCK]) + YBLOCK = str(TritonSymbols.block_sizes[SymT.YBLOCK]) + ZBLOCK = str(TritonSymbols.block_sizes[SymT.ZBLOCK]) + RBLOCK = str(TritonSymbols.block_sizes[SymT.R0_INDEX]) + + a = V.kernel.cse.generate( + V.kernel.compute, + reshape_transpose_broadcast_for_dot(a, list(a.shape), [YBLOCK, RBLOCK]), + dtype=a.dtype, + shape=(YBLOCK, RBLOCK), + ) + + b = V.kernel.cse.generate( + V.kernel.compute, + reshape_transpose_broadcast_for_dot(b, list(b.shape), [RBLOCK, XBLOCK]), + dtype=b.dtype, + shape=(RBLOCK, XBLOCK), + ) + + if torch.backends.cuda.matmul.fp32_precision == "tf32": + input_precision = "tf32" + else: + input_precision = "ieee" + + return f'tl.dot({a}, {b}, input_precision="{input_precision}")' + + @staticmethod + def inline_asm_elementwise( + *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + ): + triton_type = triton_compute_type(dtype) + input_refs = ", ".join([str(i) for i in inputs]) + if constraints is None: + constraints = ", ".join(["=r"] + ["r" for _ in inputs]) + return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950 + + @staticmethod + @maybe_upcast_float32() + def cos(x): + return f"tl_math.cos({x})" + + @staticmethod + @maybe_upcast_float32() + def sin(x): + return f"tl_math.sin({x})" + + @classmethod + def index_expr(cls, expr, dtype): + raise NotImplementedError("ops.index_expr not implemented outside a kernel") + + @staticmethod + def masked(mask, body, other): + raise NotImplementedError("ops.masked not implemented outside a kernel") + + @staticmethod + @maybe_upcast_float32() + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + @maybe_upcast_float32() + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + @maybe_upcast_float32() + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + @maybe_upcast_float32() + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + @maybe_upcast_float32() + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + @maybe_upcast_float32() + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + @maybe_upcast_float32() + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + @maybe_upcast_float32() + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + @maybe_upcast_float32() + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + @maybe_upcast_float32() + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + @maybe_upcast_float32() + def copysign(x, y): + return f"libdevice.copysign({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + @maybe_upcast_float32() + def erfinv(x): + return f"libdevice.erfinv({x})" + + @staticmethod + @maybe_upcast_float32() + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + @maybe_upcast_float32() + def log2(x): + return f"libdevice.log2({x})" + + @staticmethod + @maybe_upcast_float32() + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + offset = f"({offset}).to(tl.uint32)" + return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + raise NotImplementedError("ops.load_seed not implemented outside a kernel") + + @staticmethod + @maybe_upcast_float32() + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + @maybe_upcast_float32() + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + @maybe_upcast_float32() + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + @maybe_upcast_float32() + def tanh(x): + cse_var = V.kernel.cse.varname_map.get(x) + if cse_var and hasattr(cse_var, "dtype"): + dtype = cse_var.dtype + else: + dtype = None + if ( + config.use_fast_math + and torch.version.hip + and get_triton_version() > (3, 5) + and dtype != torch.float64 + and dtype is not None + ): + # Requires upstream Triton 3.6+ for latest fast_tanhf support + # https://github.com/triton-lang/triton/pull/8551 + return f"libdevice.fast_tanhf({x})" + else: + return f"libdevice.tanh({x})" + + @staticmethod + @maybe_upcast_float32() + def sigmoid(x): + return f"tl.sigmoid({x})" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return ( + f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0" + ) + + @staticmethod + @maybe_upcast_float32() + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + @maybe_upcast_float32() + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + @maybe_upcast_float32() + def log(x): + return f"tl_math.log({x})" + + @staticmethod + @maybe_upcast_float32(convert_output=False) + def isinf(x): + return f"libdevice.isinf({x}).to(tl.int1)" + + @staticmethod + @maybe_upcast_float32(convert_output=False) + def isnan(x): + return f"libdevice.isnan({x}).to(tl.int1)" + + @staticmethod + @maybe_upcast_float32() + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + @maybe_upcast_float32() + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def floordiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Similar to div_floor_kernel_cuda in pytorch core. + # Notice that // in triton behaves as truncdiv instead of floordiv + quot = f"{a} // {b}" + rem = f"{a} % {b}" + return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})" + + @staticmethod + def sign(x): + z = ops.constant(0, torch.int32) + left = ops.to_dtype((ops.lt(z, x)), torch.int8) + right = ops.to_dtype((ops.lt(x, z)), torch.int8) + sub = ops.sub(left, right) + return f"{sub}.to({x}.dtype)" + + @staticmethod + @maybe_upcast_float32() + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Notice that // in triton behaves as truncdiv instead of floordiv + return f"{a} // {b}" + + @staticmethod + @maybe_upcast_float32() + def ceil(x): + return f"libdevice.ceil({x})" + + +TritonOverrides._initialize_pointwise_overrides("triton") + + +class TritonKernelOverrides(TritonOverrides): + """Map element-wise ops to Triton within a TritonKernel + + Unlike TritonOverrides, these assume the code is going to be inserted into + the body of the main triton kernel and so it may use indexing and mask + variables which are assumed to already be defined in the current scope. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # happens in __init__ unlike _initialize_pointwise_overrides + # because the libdevice registrations are populated during lowerings + self._setup_libdevice_routing() + + @classmethod + @functools.cache + def _setup_libdevice_routing(cls): + """Set up routing to libdevice implementations for fp64 inputs.""" + + from torch._inductor.codegen.common import OpDecompositions + + for fn_name in torch._inductor.utils.op_requires_libdevice_fp64: + assert hasattr(cls, fn_name) + original_impl = getattr(cls, fn_name) + + def decomposition_router(x, _original_impl, _fn_name): + if x.dtype != torch.float64: + return _original_impl(x) + else: + return getattr(OpDecompositions, _fn_name)(x).value + + if fn_name == "sigmoid": + assert hasattr(OpDecompositions, "sigmoid") + fn = functools.partial( + decomposition_router, _original_impl=original_impl, _fn_name=fn_name + ) + fn.__name__ = fn_name # type: ignore[attr-defined] + setattr(cls, fn_name, staticmethod(fn)) + continue + + def dtype_router(x, _original_impl, _fn_name): + if x.dtype == torch.float64: + return f"libdevice.{_fn_name}({x})" + else: + return _original_impl(x) + + fn = functools.partial( + dtype_router, _original_impl=original_impl, _fn_name=fn_name + ) + fn.__name__ = fn_name # type: ignore[attr-defined] + setattr(cls, fn_name, staticmethod(fn)) + + @classmethod + def constant(cls, value, dtype): + # NOTE: Cannot use shape=[] as it's not supported by triton-rocm + # We could use shape=[1] instead but starting with the correct + # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR. + ndim = V.kernel.triton_tensor_ndim() + shape = [1] * ndim + return cls._shaped_constant(value, dtype, shape=shape) + + @classmethod + def index_expr(cls, expr, dtype): + indexing = V.kernel.indexing( + expr, block_ptr=False, tma_compatibility_checker=None + ) + assert isinstance(indexing, IndexingOptions) + + shape: BlockShapeType + if indexing.expand_shape: + shape = indexing.expand_shape + else: + shape = TritonSymbols.get_block_shape(indexing.index) + + # Our sympy expr printing casts to the current kernel index dtype. + # we only respect non int32-int64 dtypes and otherwise use current kernel indexing dtype + index_dtype = V.kernel.get_index_dtype_as_torch_dtype() + dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype + + # after we emit this var we cast it to the correct dtype + orig = config.test_configs.runtime_triton_dtype_assert + try: + config.test_configs.runtime_triton_dtype_assert = False + var = V.kernel.cse.generate( + V.kernel.compute, + indexing.index_str, + bounds=get_bounds_index_expr(expr), + dtype=dtype, + shape=shape, + ) + finally: + config.test_configs.runtime_triton_dtype_assert = orig + + if dtype not in (torch.int32, torch.int64): + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, dtype), + dtype=upcast_compute_type(dtype), + shape=var.shape, + ) + else: + # TODO: we are not always consistent in enforcing that the output of the index expr printing + # results in the indexing dtype. So if we detect that we have an input which might type promote + # to a dtype other than indexing dtype, add a cast. + # Trying to avoid + dtype = index_dtype + for index_var in expr.free_symbols: + if symbol_is_type(index_var, SymT.TMP): + dtype = torch.promote_types( + dtype, V.kernel.cse.varname_map[index_var.name].dtype + ) + + if dtype != index_dtype: + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, index_dtype), + dtype=index_dtype, + shape=var.shape, + ) + + var.mask_vars = indexing.mask_vars + return var + + @staticmethod + def masked(mask, body, other): + if mask is not None and torch.version.hip is not None: + mask = V.kernel.cse.generate( + V.kernel.compute, + f"{mask}.to(tl.int1)", + dtype=torch.bool, + shape=mask.shape, + ) + + nodes = body.graph.find_nodes(op="output") + assert nodes, "graph for body does not contain an output" + + need_where = False + # If we have a tl.load with a masking operator and no other value + # we can add the mask here and the other value to the tl.load + # operator to save the branching cost. + for node in nodes: + for arg in node.args: + if arg.target != "load" or should_unwrap_unspec_arg(arg.args[1]): + need_where = True + break + + value = None if need_where else other + + with V.kernel.mask_loads(mask, value=value) as new_mask: + result = body() + + if need_where: + # Remove once CSEVariables track the dtype + if result.bounds.is_bool: + other = bool(other) + # Take dtype from result to prevent accidental promotion + other = V.kernel.cse.generate( + V.kernel.compute, + f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", + bounds=ValueRanges.wrap(other), + dtype=result.dtype, + shape=result.shape, + ) + ret = ops.where(new_mask, result, other) + else: + ret = result + + ret.mask_vars.discard(new_mask) + return ret + + @staticmethod + def load_seed(name, offset): + var = V.kernel.args.input(name) + return ( + f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})" + ) + + @staticmethod + def frexp(x): + cache_key = f"frexp({x})" + if cse_val := V.kernel.cse.try_get(cache_key): + return cse_val + + mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape) + exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape) + V.kernel.compute.writeline( + f"{mantissa}, {exponent} = triton_helpers.frexp({x})" + ) + V.kernel.cse.put(cache_key, (mantissa, exponent)) + return (mantissa, exponent) + + @staticmethod + def partial_accumulate( + name: str, + reduction_type: str, + value: CSEVariable, + extra_meta: dict[str, Any], + ) -> None: + raise NotImplementedError + + +class HelperFunctions: + """An ordered set of helper functions.""" + + _templates_seen: dict[str, str] # Template code to function name + finalized_helpers: list[str] + + def __init__(self) -> None: + self._templates_seen = {} + self.finalized_helpers = [] + + def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str: + """This accepts a function definition with the function name + left as a format specifier e.g. + + @triton.jit + def {name}(arg0, arg1): + return arg0 + arg1 + + We add the templated code to the function set and return the name + assigned to that function. + + """ + existing_name = self._templates_seen.get(template_code) + if existing_name is not None: + # Don't duplicate existing helpers + return existing_name + + name = f"{base_name}{len(self.finalized_helpers)}" + self._templates_seen[template_code] = name + self.finalized_helpers.append(template_code.format(name=name)) + return name + + def __iter__(self): + return iter(self.finalized_helpers) + + def __getitem__(self, idx): + return self.finalized_helpers[idx] + + +@dataclasses.dataclass +class BlockParameters: + """ + Class representing ND block dimensions, for block pointer analysis. + """ + + shape: list[sympy.Expr] = dataclasses.field(default_factory=list) + block_shape: list[sympy.Expr] = dataclasses.field(default_factory=list) + strides: list[sympy.Expr] = dataclasses.field(default_factory=list) + offsets: list[sympy.Expr] = dataclasses.field(default_factory=list) + + @dataclasses.dataclass + class StrideSorter: + original_strides: list[int] + sort_idx: list[int] + revert_sort_idx: list[int] = dataclasses.field(init=False) + + def __post_init__(self): + assert len(self.original_strides) > 0 + assert len(self.sort_idx) == len(self.original_strides) + + identity_sort_idx = list(range(len(self.original_strides))) + self._is_identity = self.sort_idx == identity_sort_idx + + # Set revert_sort_idx + sorted_dims_by_strides_map = {k: i for i, k in enumerate(self.sort_idx)} + self.revert_sort_idx = [ + sorted_dims_by_strides_map[i] + for i in range(len(sorted_dims_by_strides_map)) + ] + + @property + def is_identity(self): + return self._is_identity + + @classmethod + @abstractmethod + def create( + cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv + ) -> BlockParameters.StrideSorter: + """Create a `StrideSorter` that can be used to sort block parameters.""" + + def sort(self, attr): + if not self.is_identity: + return [attr[i] for i in self.sort_idx] + return attr + + def revert(self, attr): + if not self.is_identity: + return [attr[i] for i in self.sort_idx] + return attr + + @dataclasses.dataclass + class IdentityStrideSorter(StrideSorter): + def __post_init__(self): + super().__post_init__() + + @classmethod + def create( + cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv + ) -> BlockParameters.StrideSorter: + return cls( + original_strides=original_strides, + sort_idx=list(range(len(original_strides))), + ) + + @dataclasses.dataclass + class TensorDecriptorStrideSorter(StrideSorter): + """ + Sorts BlockParameters dimensions with strides in descending order. + """ + + def __post_init__(self): + super().__post_init__() + + @classmethod + def create( + cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv + ) -> BlockParameters.StrideSorter: + """ + If the strides are not all known constants or if the strides are already + sorted in descending order, return identity sort. + + For example if block_shape @ strides is [ZBLOCK, XBLOCK, YBLOCK] @ [8, 1, 16] + The indices to sort the strides in descending order will be [2, 0, 1]. + The indices to revert back to the original order will be [1, 2, 0]. + """ + identity_sort = list(range(len(original_strides))) + try: + # TODO: even if the strides are not in descending order the strides + # may be tensor descriptor compliant + # i.e. innermost stride == 1 and outer strides 16 byte aligned + # We should benchmark the effect of applying a transpose to these + # cases vs leaving them unsorted. + sort_idx = utils.argsort_sym(shape_env, original_strides, reverse=True) + except AssertionError: + # Symbolic shapes, failed to evaluate comparison expression + sort_idx = identity_sort + + return cls( + original_strides=original_strides, + sort_idx=sort_idx, + ) + + def __add__(self, other: BlockParameters) -> BlockParameters: + """ + Concatenates block parameters. + """ + cls = type(self) + a, b = tuple(dataclasses.asdict(x) for x in (self, other)) + return cls(**{key: a[key] + b[key] for key in a}) + + def maybe_sort_with_stride_order( + self, stride_sorter_cls: type[StrideSorter], shape_env: ShapeEnv + ) -> tuple[BlockParameters, BlockParameters.StrideSorter]: + """ + Sort `BlockParameter` with stride_sorter_cls. Returns block parameters + as well as a `StrideSorter` which contains information on how the sort + can be reverted. + """ + stride_sorter = stride_sorter_cls.create(self.strides, shape_env=shape_env) + params = BlockParameters( + **{ + key: stride_sorter.sort(val) + for key, val in dataclasses.asdict(self).items() + } + ) + return params, stride_sorter + + def remove_dims(self, removable_dims: list[bool]) -> BlockParameters: + """ + Remove dimensions where removable_dims is True. + """ + + def filter_dims(it): + return [ + item + for item, is_removable in zip(it, removable_dims) + if not is_removable + ] + + return BlockParameters( + **{key: filter_dims(val) for key, val in dataclasses.asdict(self).items()}, + ) + + +class CooperativeReductionWorkspaceCache: + """ + The scratch space used for cooperative reductions can be reused + after two reduction loops. This keeps track of what can be reused. + """ + + def __init__(self, args): + self.args = args + self.current_loop = [] + self.prior_loop = [] + self.ready_for_reuse = collections.defaultdict(collections.deque) + self.loop_count = 0 + self.store_count = 0 + + def allocate(self, nbytes: sympy.Expr): + cached = self.ready_for_reuse.get(nbytes) + if cached: + return cached.popleft() + ws_name, _, ws_offset = self.args.workspace(nbytes, False) + self.current_loop.append((nbytes, ws_name, ws_offset)) + return (ws_name, ws_offset) + + def on_loop_end(self): + # Buffers can be reused after 2 loop ends + for nbytes, ws_name, ws_offset in self.prior_loop: + self.ready_for_reuse[nbytes].append((ws_name, ws_offset)) + self.prior_loop = self.current_loop + self.current_loop = [] + self.loop_count += 1 + + def increment_store_count(self): + prior = self.store_count + self.store_count += 1 + return prior + + +@dataclasses.dataclass +class FixedTritonConfig: + config: dict[str, int] + + def __getitem__(self, item): + return self.config[item] + + def __contains__(self, item): + return item in self.config + + +class TritonCSE(CSE[TritonCSEVariable, Union[str, tuple[str, str]]]): + """ + Subclasses CSE to apply the current load mask to the cache key to avoid CSEing + variables across separate masked blocks. + """ + + def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]: + if mask := V.kernel._load_mask: + return (cache_key, mask.name) + else: + return cache_key + + +@dataclasses.dataclass +class TMACompatibilityChecker: + """ + Checks if the TMA API can be used for load / store triton operations. + """ + + kernel: TritonKernel + dtype: torch.dtype + for_store: bool + force: bool + + def __post_init__(self): + self.failed_debug_prefix = "Cannot use TMA descriptor for load / store since: " + + # Also see Note: TMA API Restrictions for the below + def can_use_tma( + self, + ) -> bool: + if self.force: + return True + if not ( + V.graph.get_current_device_or_throw().type == "cuda" + and torch.cuda.get_device_capability()[0] >= 9 + and config.triton.use_tensor_descriptor + and config.assume_aligned_inputs + and has_triton_stable_tma_api() + # For CUDA The base ptr needs to be aligned + ): + log.debug( + ( + "%s Requires triton>=3.4.0, a CUDA device with cc>=9.0 and" + " `use_tensor_descriptor` and `assume_aligned_inputs` options enabled" + ), + self.failed_debug_prefix, + ) + return False + + # `no_x_dim` => XBLOCK=1, and for reductions this means only one element + # is to be stored . However the TMA API requires that + # the store will be 16 byte aligned, which is not attainable with a single + # element + if self.for_store and self.kernel.no_x_dim: + log.debug( + "%s stores with `no_x_dim` cannot load 16 bytes.", + self.failed_debug_prefix, + ) + return False + + return True + + def are_block_parameters_compatible( + self, + block_params: BlockParameters, + ) -> bool: + """ + Check if the block parameters are valid for TMA. + If force, we allow relying on symbolic hints equivalent + to what we check for Triton templates. + """ + if self.force: + strides = [ + V.graph.sizevars.symbolic_hint(st) for st in block_params.strides + ] + else: + strides = block_params.strides + + # The TMA API requires that the innermost stride is 1 + # and that the outer strides are 16 byte aligned + if not V.graph.sizevars.statically_known_equals(strides[-1], sympy.Integer(1)): + log.debug( + "%s TMA API requires innermost stride to be 1. Strides are: %s", + self.failed_debug_prefix, + strides, + ) + return False + + element_size = self.dtype.itemsize + for stride in strides[:-1]: + if not V.graph.sizevars.statically_known_equals( + ModularIndexing(stride * element_size, 1, sympy.Integer(16)), + sympy.Integer(0), + ): + log.debug( + "%s TMA API requires outer strides to be 16 byte aligned. Dtype bytes: %d, strides: %s", + self.failed_debug_prefix, + element_size, + strides, + ) + return False + + # Now compute the minimum value of the block type that is used + # in the innermost block size that can guarantee that 16 bytes of data + # can be loaded / stored. + # Start with finding the innermost block type + innermost_block_shape = block_params.block_shape[-1] + + # Pure singleton case + if V.graph.sizevars.statically_known_equals( + innermost_block_shape, sympy.Integer(1) + ): + log.debug( + "%s innermost block shape cannot load 16 bytes. Block shape: %s", + self.failed_debug_prefix, + block_params.block_shape, + ) + return False + + innermost_block_type = None + innermost_block_symt = None + for block_type_str in innermost_block_shape.free_symbols: + for block_symt in TritonSymbols.block_types: + if symbol_is_type(block_type_str, block_symt): + innermost_block_type = block_type_str + innermost_block_symt = block_symt + break + + assert innermost_block_type and innermost_block_symt, ( + f"{innermost_block_shape} expr must contain a single block type from {TritonSymbols.block_types}" + ) + + # For persistent reductions, the reduction block sizes are fixed at compile time + if self.kernel.persistent_reduction and not self.for_store: + # For a discontiguous tensor, a 1D block will be split across several + # dimensions, e.g. R0_BLOCK: + # block_shape=[XBLOCK, ((R0_BLOCK + 31)//32), Min(1, ((R0_BLOCK + 31)//32)), Min(32, R0_BLOCK)] + # The persistent R0_BLOCK will be a power of 2 that is at least r0_numel So it + # should be guaranteed that Min(32, R0_BLOCK) * element_size >= 16 + innermost_tree_prefix = prefix_str[innermost_block_symt] + tree_numel = None + for t in self.kernel.range_trees: + if t.is_reduction: + if t.prefix == innermost_tree_prefix: + tree_numel = t.numel + break + assert tree_numel is not None + persistent_rblock = self.kernel._get_persistent_RBLOCK(tree_numel) + innermost_block_bytes = ( + innermost_block_shape.subs({innermost_block_type: persistent_rblock}) + * element_size + ) + if not V.graph.sizevars.statically_known_geq( + innermost_block_bytes, sympy.Integer(16) + ): + log.debug( + "%s persistent reduction innermost block shape cannot load 16 bytes. Block shape: %s, persistent RBLOCK: %d", + self.failed_debug_prefix, + block_params.block_shape, + persistent_rblock, + ) + return False + + else: + # E.g. if the innermost block shape is Min(2, XBLOCK) + # then the TMA API can only be used if the dtype has an 8 byte element + # size so that 16 bytes of data can be loaded in the innermost dimension + try: + + def indexing_div_rep( + x: sympy.Expr, + y: sympy.Expr, + z: Optional[sympy.Expr] = None, + ) -> sympy.Expr: + div = x / y + if z: + div = div % z + return div + + solve_expr = innermost_block_shape * element_size - 16 + # Sympy cannot handle FloorDiv and ModularIndexing well, so simplify + solve_expr_simplified = solve_expr.replace( + FloorDiv, indexing_div_rep + ).replace(ModularIndexing, indexing_div_rep) + min_block_size = next_power_of_2( + int( + sympy.nsolve( + solve_expr_simplified, + innermost_block_type, + 1, + ) + ) + ) + + # TODO: min block size may be too large / introduce redundancy + if min_block_size > self.kernel.max_block( + prefix_str[innermost_block_symt] + ): + log.debug( + "%s the minimum block size to satisfy expression %s is too large: %d", + self.failed_debug_prefix, + solve_expr_simplified, + min_block_size, + ) + return False + + block_type_str = self.kernel.index_to_str(innermost_block_type) + # Check block sizes if the user has provided a fixed triton config + if self.kernel.fixed_config: + if min_block_size > self.kernel.fixed_config[block_type_str]: + log.debug( + "%s For block %s, fixed config block size %d is smaller " + "than the minimum required: %d", + self.failed_debug_prefix, + block_type_str, + self.kernel.fixed_config[block_type_str], + min_block_size, + ) + return False + else: + # Update the minimum block sizes that are passed to triton + # heuristics + self.kernel.tma_min_block_sizes[block_type_str] = max( + min_block_size, + self.kernel.tma_min_block_sizes.get(block_type_str, 1), + ) + + except ValueError: + log.debug( + "%s innermost block shape cannot load 16 bytes. Block params: %s", + self.failed_debug_prefix, + block_params.block_shape, + ) + return False + + return True + + def can_lift(self) -> bool: + """ + Can you lift the make_tensor_descriptor + call to the top of the kernel? This requires + being certain that all of the shape, stride, + and block_shape information is handled in arguments + or top level definitions. + + Right now we assume this is always possible if you force TMA. + """ + return self.force + + +class TritonKernel(SIMDKernel[TritonCSEVariable]): + """A class to represent a triton kernel and helpers to generate + triton kernel programmatically + """ + + overrides = TritonKernelOverrides # type: ignore[assignment] + helper_functions: HelperFunctions + kexpr: Callable[[sympy.Expr], str] = texpr + allow_block_ptr = True + tma_compatibility_checker_cls = TMACompatibilityChecker + transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None + + def __init__( + self, + tiling: dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + hint_override: Optional[int] = None, + **kwargs, + ) -> None: + self.optimize_mask: bool = optimize_mask + self.fixed_config = fixed_config + super().__init__(tiling, **kwargs) + self.cse = TritonCSE(self.newvar_prefix, self.suffix) + # Cache of values that can be reused for the prologue. + self.prologue_cache: dict[str, str] = {} + self.prologue: IndentedBuffer = IndentedBuffer() + self.post_loop_combine: IndentedBuffer = IndentedBuffer() + self.post_loop_store: IndentedBuffer = IndentedBuffer() + self.outside_loop_vars = OrderedSet[Any]() + self.min_elem_per_thread = min_elem_per_thread + self.block_ptr_id = itertools.count() + self.block_ptr_to_buffer = dict[str, str]() + self.helper_functions = HelperFunctions() + self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = ( + collections.defaultdict(dict) + ) + self.tma_min_block_sizes = dict[str, int]() + self.hint_override = hint_override + self._load_counts: collections.Counter[str] = collections.Counter() + self._load_index = 0 + + # A set of autotuning hints to pass as part of triton_meta + self.autotune_hints = OrderedSet[AutotuneHint]() + self.triton_meta: Optional[dict[str, Any]] = None + + if self.inside_reduction: + self.codegen_reduction_numels(self.body) + + if self.cooperative_reduction: + self.init_cooperative_reduction() + + self.codegen_range_tree() + + if self.cooperative_reduction: + self.init_cooperative_reduction_mask() + + self.has_load_with_contiguous_rdim = False + # We track the store name since a store can be canceled later + self.stores_with_contiguous_rdim: list[str] = [] + + @staticmethod + def _has_stride1_on_rdim(index) -> bool: + # These analysis is only needed in deterministic mode so far + # to filter triton configs. Return false immediately to avoid + # increasing compilation time when the mode is off. + if not ( + config.deterministic or config.test_configs.force_filter_reduction_configs + ): + return False + support_vars = index.free_symbols + reduce_vars = [ + var + for var in support_vars + if symbol_is_type(var, TritonSymbols.reduction_types) + ] + + if len(reduce_vars) == 0: + return False + + # for expression "x0 + 150528*((x1//(s27*s38))) + 3*(ModularIndexing(x1, 1, s38)) + 672*(ModularIndexing(x1, s38, s27))" + # stride_vars will results in DivisionByZero error + try: + stride_vars = V.graph.sizevars.stride_vars(index, reduce_vars, support_vars) + except ZeroDivisionError: + return False + + return any(stride == 1 for stride in stride_vars) + + @property + def has_store_with_contiguous_rdim(self) -> bool: + return not all( + is_buffer_removed(name) for name in self.stores_with_contiguous_rdim + ) + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return triton_type(dtype) + + def should_use_cooperative_reduction(self) -> bool: + return self.inside_reduction and V.choices.should_use_cooperative_reduction( + self.features + ) + + def init_cooperative_reduction(self): + """One time setup code for cooperative reductions.""" + assert self.cooperative_reduction + + # shift all the grids over since tl.program_id(0) is for rsplit + for tree in self.range_trees: + if tree.grid_dim is not None: + tree.grid_dim += 1 + + sem_count = self.numels["x"] + if self.fixed_config: + sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"]) + self.semaphores_name = self.args.semaphores(sem_count) + self.cooperative_reduction_workspace_cache = CooperativeReductionWorkspaceCache( + self.args + ) + self.body.splice( + """\ + RSPLIT_NEXT_POWER_OF_2: tl.constexpr = triton_helpers.constexpr_next_power_of_2(RSPLIT) + RSPLIT_IS_POWER_OF_2: tl.constexpr = RSPLIT == RSPLIT_NEXT_POWER_OF_2 + HAS_RSPLIT: tl.constexpr = RSPLIT > 1 + rsplit_id = tl.program_id(0) + num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK + rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK + rsplit_start = rsplit_chunk * rsplit_id + rsplit_end = rsplit_chunk * (rsplit_id + 1) + """, + ) + if any( + not self._has_constant_mask(tree) + for tree in self.range_trees + if tree.is_reduction + ): + self.body.writeline( + "rsplit_end = tl.where(rsplit_end < rnumel, rsplit_end, rnumel)" + ) + + def init_cooperative_reduction_mask(self): + rsplit_arange = "tl.arange(0, RSPLIT_NEXT_POWER_OF_2)" + if not self.no_x_dim: + rsplit_arange = f"{rsplit_arange}[None, :]" + self.body.writeline(f"rsplit_arange = {rsplit_arange}") + + if self._has_constant_xmask(): + self.body.splice( + """\ + if RSPLIT_IS_POWER_OF_2: + rsplit_mask: tl.constexpr = None + else: + rsplit_mask = rsplit_arange < RSPLIT + """ + ) + else: + assert not self.no_x_dim + self.body.writeline( + "rsplit_mask = xmask if RSPLIT_IS_POWER_OF_2 else ((rsplit_arange < RSPLIT) & xmask)" + ) + + def codegen_range_tree(self): + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + self.iteration_ranges_codegen_header(tree, self.body) + elif self.inside_reduction: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline( + f"{tree.prefix}base = {self.iteration_ranges_ranges_code(tree)}" + ) + + if self.inside_reduction: + if any(tree.is_loop for tree in self.range_trees): + # If the kernel contains loops, compute rbase. + rn_bases = self._get_reduction_symbols( + "base", integer=True, nonnegative=True + ) + rbase = self._flatten_reduction_indices(rn_bases) + self.body.splice(f"rbase = {self.index_to_str(rbase)}") + else: + # For looped reductions, indexing is deferred to the innermost loop. + self.codegen_reduction_indices(self.body) + + def need_numel_args(self): + """ + Indicate whether we need provide numel as arguments for the generated + kernel calls in the benchmark. + + Should be true for pointwise/reduction kernels but false for triton + matmul kernels. + """ + return True + + def should_use_persistent_reduction(self) -> bool: + return self.inside_reduction and V.choices.should_use_persistent_reduction( + self.features, self.cooperative_reduction + ) + + def want_no_x_dim(self): + return ( + self.persistent_reduction + and len(self.numels) == self.num_reduction_dims + 1 + and self.fixed_config + and self.fixed_config["XBLOCK"] == 1 + ) + + @property + def assert_function(self) -> str: + return "tl.device_assert" + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape: Optional[Union[str, tuple[str]]] = None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + tma_compatibility_checker: Optional[TMACompatibilityChecker] = None, + ): + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.prepare_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: OrderedSet[str] = OrderedSet() + for var in sorted(index_vars, key=operator.attrgetter("name")): + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or symbol_is_type( + var, TritonSymbols.reduction_types + ) + if override_mask: + pass + elif symbol_is_type(var, SymT.TMP): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + pass + else: + # var is one of xN, yN, r0_N or r1_N + prefix_matches = [ + prefix_str[symt] + for symt in TritonSymbols.block_types + if symbol_is_type(var, symt) + ] + if len(prefix_matches) == 0: + pass + assert len(prefix_matches) == 1, f"Ambiguous type: {var.name}" + mask_vars.add(f"{prefix_matches[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars: OrderedSet[str] = OrderedSet() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + ( + (block_ptr and self.allow_block_ptr and config.triton.use_block_ptr) + or ( + tma_compatibility_checker + and tma_compatibility_checker.can_use_tma() + ) + ) + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/triton-lang/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + + def match_affine_block( + index: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Matches expressions of the form: + idx = s * xindex + + This implies stride (s,), and shape (XBLOCK,). + """ + stride = BlockPatternMatcher.match_affine_block_expr( + index, range_tree.symbol() + ) + if stride is None: + return None + + return BlockParameters( + shape=[range_tree.numel], + block_shape=[TritonSymbols.get_block_size(range_tree)], + strides=[stride], + offsets=[TritonSymbols.get_block_offset(range_tree)], + ) + + def match_mod_div_block( + index: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing. + + Example expression to match: + sN * ((rindex//(d1 * ... * d(N-1)))) + + s1 * ModularIndexing(rindex, 1, d1) + + ... + + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1)) + + This iterates over a block of shape (dN, ..., d1) and stride + (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are + wildcards that we match. + + Note that dN does not appear in the expression, but we solve for it + using range tree numels and the other dims. + """ + + index_var = range_tree.symbol() + + # Bound the possible number of dims. We use the following heuristics: + # - At least one dim for each range tree node. + # - At least one dim for every FloorDiv or ModularIndexing op. + # - At least 2 dims to pattern match. + denom, modulo = sympy.symbols( + "denom modulo", + cls=functools.partial(sympy.Wild, exclude=[index_var]), + ) + num_dims = max( + 2, + # range_tree.nodes only includes the entries for the range tree + # len(range_tree.nodes) <= self.range_tree_nodes + len(range_tree.nodes), + ( + index.count(FloorDiv(index_var, denom)) + + index.count(ModularIndexing(index_var, denom, modulo)) + ), + ) + + match_result = BlockPatternMatcher.match_mod_div_block_expr( + index, index_var, range_tree.numel, num_dims + ) + if match_result is None: + return None + + ( + dims, + strides, + block_index_exprs, + ) = match_result + slice_numels = BlockPatternMatcher.get_slice_numels(dims) + + # Check for applicable iteration range sizes. + # When mapping a 1D block into an ND one, we need to know that + # the number of elements is not changed. This means the slice numels of + # the ND iteration range must evenly divide the length of the 1D block. + # There are two cases where we can guarantee this: + # 1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m, + # with n and m integers, then either numel is a multiple of XBLOCK, or numel + # is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.) + # 2. Numels are multiples of the maximum possible block size. + sizevars = V.graph.sizevars + max_block = self.max_block(range_tree.prefix) + if any( + not sizevars.statically_known_multiple_of(numel, max_block) + and not sizevars.statically_known_power_of_2(numel) + for numel in slice_numels + ): + return None + + # Compute the ND block shape from the linear block size. + # Use CielDiv to round leading dimensions up to 1. + # Non-leading dimensions are clamped to the size of the iteration range, + # while the leading dimension can exceed this to accommodate a larger + # block size. + linear_block_size = TritonSymbols.get_block_size(range_tree) + block_shape: list[sympy.Expr] = [ + CeilDiv(linear_block_size, slice_numels[0]) + ] + [ + sympy.Min(CeilDiv(linear_block_size, numel), dim) + for numel, dim in zip(slice_numels[1:], dims[1:]) + ] + + # Compute block offsets from {xyzr}offset and the matched expressions. + block_offsets: list[sympy.Expr] = [ + sympy_subs( + expr, {index_var: TritonSymbols.get_block_offset(range_tree)} + ) + for expr in block_index_exprs + ] + + return BlockParameters( + shape=dims, + block_shape=block_shape, + strides=strides, + offsets=block_offsets, + ) + + def match_block_subexpr( + expr: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Match a block indexing subexpression involving a single range tree. + """ + for match_func in ( + match_affine_block, + match_mod_div_block, + ): + match = match_func(expr, range_tree) + if match is not None: + return match + + return None + + def match_block_expr() -> Optional[BlockDescriptorOptions]: + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees() + + # Partition the index into subexpressions pertaining to each range tree. + # For example xindex * 5 + r0_index * 3 is partitioned to + # (xindex * 5, r0_index * 3). + index_subexprs = [ + BlockPatternMatcher.get_subexpr_involving_symbol( + index_relative_to_xyr_index, tree.symbol() + ) + for tree in range_trees + ] + + # Match each range tree's subexpression separately. + range_symbols = OrderedSet(tree.symbol() for tree in range_trees) + block_params = BlockParameters() + for tree, subexpr in zip(range_trees, index_subexprs): + # Reject mixed terms, e.g. xindex * r0_index. + # NB: the zero expression is allowed, for broadcasting. + if len(range_symbols.intersection(subexpr.free_symbols)) > 1: + return None + + # Match the subexpression for this range tree. + params = match_block_subexpr(subexpr, tree) + if params is None: + return None + block_params += params + + # Collect leftover terms as a constant offset. + offset = index_relative_to_xyr_index - sum(index_subexprs) + + # Form the block pointer or TMA descriptor. + self.filter_masks(mask_vars) + + options_class = ( + BlockPtrOptions + if config.triton.use_block_ptr + else TensorDescriptorOptions + ) + nonlocal tma_compatibility_checker + stride_sorter_cls: type[BlockParameters.StrideSorter] + if config.triton.use_block_ptr: + can_lift = False + stride_sorter_cls = BlockParameters.IdentityStrideSorter + else: + tma_compatibility_checker = cast( + TMACompatibilityChecker, tma_compatibility_checker + ) + can_lift = tma_compatibility_checker.can_lift() + + if ( + self.transpose_discontiguous_tensor_descriptors_override + is not None + ): + transpose_contiguous = ( + self.transpose_discontiguous_tensor_descriptors_override + ) + else: + transpose_contiguous = ( + config.triton.transpose_discontiguous_tensor_descriptor + ) + + # For templates: + # Only try transpose if we know the output shape + # in case we need to transpose the data. + if hasattr(self, "template_out_shape"): + transpose_contiguous &= copy_shape is not None + + stride_sorter_cls = ( + BlockParameters.TensorDecriptorStrideSorter + if transpose_contiguous + else BlockParameters.IdentityStrideSorter + ) + + options = options_class.create( + params=block_params, + constant_offset=offset, + range_trees=range_trees, + mask_vars=mask_vars, + get_max_block=self.max_block, + can_lift=can_lift, + stride_sorter_cls=stride_sorter_cls, + ) + if options_class == TensorDescriptorOptions: + tma_compatibility_checker = cast( + TMACompatibilityChecker, tma_compatibility_checker + ) + if not tma_compatibility_checker.are_block_parameters_compatible( + options.params + ): + return None + + return options + + # Return a block pointer, if indexing matches the pattern. + options = match_block_expr() + if options is not None: + return options + expand_str = None + expand_shape: BlockShapeType = None + index_str = self.index_to_str(index) + + def _get_expand_str(): + if copy_shape: + if isinstance(copy_shape, str): + return f"{copy_shape}.shape", None + else: + return "[" + ", ".join(str(c) for c in copy_shape) + "]", copy_shape + else: + return self.dense_size_str(), tuple(self.dense_size_list()) + + if is_sympy_integer_like(index): + # Integer indexing produces a size-1 scalar tensor with the same shape + # as the dense dimension. E.g, if dense_size = [YBLOCK, XBLOCK, R0_BLOCK], + # then we create tl.full([1, 1, 1], int). + # + # Exceptions: + # 1. If copy_shape is explicitly provided, use copy_shape expansion instead. + # 2. If the dense tensor has only one dimension (e.g., [XBLOCK]), + # broadcasting does not apply. For example: + # tl.arange(0, XBLOCK) + tl.full([1], int) # -> broadcasting error + # In this case, we fall back to dense indexing: + # tl.full([XBLOCK], int) + if copy_shape or len(self.dense_size_list()) == 1: + expand_str, expand_shape = _get_expand_str() + else: + expand_str = str([1] * len(self.dense_size_list())) + expand_shape = tuple([1] * len(self.dense_size_list())) + + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + if self.fixed_config and not self._has_constant_xmask(): + mask_vars = OrderedSet(["xmask"]) + else: + mask_vars = OrderedSet() + if self._load_mask: + mask_vars.add(self._load_mask) + return IndexingOptions( + index_str, + mask_vars, + expand_str, + has_rindex, + index, + expand_shape=expand_shape, + ) + + if need_dense and not have_dense: + if self.inside_reduction and self.is_native_matmul: + # This avoids full broadcasting (need_dense) when performing native matmul. + # For example, self._load_mask previously required tl.broadcast_to() in index_str. + # Due to the restrictions of tl.dot semantics, we only want to expand the block + # shape for the necessary axes. + # + # Previously: + # tmp1 = tl.load(ptr + tl.broadcast_to(r0, [YBLOCK, XBLOCK, R0_BLOCK]), + # r0_mask & tmp0 & xmask) + # + # Now: + # tmp1 = tl.load(ptr + tl.broadcast_to(r0, [1, 1, R0_BLOCK]), + # r0_mask & tmp0 & xmask) + # + # We achieve this by determining the required block shape through mask inspection. + # When a temporary variable appears in the mask (e.g., self._load_mask), we retrieve + # its true shape by inspecting tmp.mask_vars tracked by TritonCSEVariable. + # + # Caution: it may miss the correct block shape if the specific mask was constant + # and thus not tracked in TritonCSEVariable.mask_vars. + # + # TODO: Once the shape propagation PR lands, reimplement this logic: + # https://github.com/pytorch/pytorch/pull/152198 + mask_shape = mask_vars.copy() + if self._load_mask: + mask_shape.add(self._load_mask) + + xyzr = OrderedSet(["xmask", "ymask", "zmask", "r0_mask"]) + while not mask_shape.issubset(xyzr): + tmp_masks = mask_shape.difference(xyzr) + tmp = tmp_masks.pop() + assert isinstance(tmp, TritonCSEVariable) + mask_shape.discard(tmp) + mask_shape.update(tmp.mask_vars) + + # e.g., expand_list becomes ['ZBLOCK', 1, 1, 'R0_BLOCK'] + expand_list = ["1"] * len(self.dense_size_list()) + for mask in mask_shape: + assert isinstance(mask, str) + for tree in self.active_range_trees(): + if mask.startswith(tree.prefix): + dim = tree.tensor_dim + assert isinstance(dim, int) + expand_list[dim] = self.dense_size_list()[dim] + + expand_str = "[" + ",".join(map(str, expand_list)) + "]" + expand_shape = tuple(expand_list) + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + else: + expand_str, expand_shape = _get_expand_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + expand_shape_str, expand_shape = _get_expand_str() + index_str = f"tl.broadcast_to({index_str}, {expand_shape_str})" + mask_vars = dense_mask_vars + + if expand_shape is None: + if need_dense or have_dense: + _, expand_shape = _get_expand_str() + else: + expand_shape = () + + if override_mask: + mask_vars = OrderedSet([override_mask]) + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + return IndexingOptions( + index_str, + mask_vars, + expand_str, + has_rindex, + index, + expand_shape=expand_shape, + ) + + def codegen_block_ptr( + self, + name: str, + var: str, + indexing: Union[BlockPtrOptions, TensorDescriptorOptions], + other="", + ) -> tuple[str, str]: + """Generate a block pointer or tensor descriptor for Triton kernel operations. + + This method creates either a block pointer (for regular Triton operations) or + a tensor descriptor (for TMA operations) based on the indexing type. It handles + caching and reuse of descriptors for performance optimization. + + Args: + name: The name of the buffer/tensor being accessed + var: The variable name for the pointer + indexing: Block pointer options or tensor descriptor options containing + indexing information and boundary check settings + other: Additional parameters string (e.g., padding options) + + Returns: + A tuple containing: + - block_descriptor: The generated block pointer or tensor descriptor variable name + - other: Modified additional parameters string with boundary check options + """ + check = indexing.boundary_check() + if isinstance(indexing, TensorDescriptorOptions): + if check and other: + # The TMA API currently does not support padding values + # but the default is zero + assert other == ", other=0.0" + other = "" + else: + if not check: + # workaround https://github.com/triton-lang/triton/issues/2813 + other = "" + elif other: + assert other == ", other=0.0" + other = f", boundary_check={check!r}, padding_option='zero'" + else: + other = f", boundary_check={check!r}" + + if ( + self.inside_reduction + and self.range_trees[-1].is_loop + and indexing.has_rindex() + ) or indexing.can_lift: + if indexing.can_lift and var in self.prologue_cache: + # Check for epilogue subtiling to reuse the same + # tensor descriptor. + block_descriptor = self.prologue_cache[var] + else: + block_ptr_line = indexing.format(var, roffset=False) + block_var = self.cse.try_get(block_ptr_line) + + # Early return if block descriptor already exists + if block_var: + return str(block_var), other + + block_descriptor_id = next(self.block_ptr_id) + if isinstance(indexing, BlockPtrOptions): + block_descriptor = f"block_ptr{block_descriptor_id}" + else: + block_descriptor = f"tma_descriptor{block_descriptor_id}" + named_var = self.cse.namedvar( + block_descriptor, dtype=torch.uint64, shape=[] + ) + self.cse.put(block_ptr_line, named_var) + + line_body = DeferredLine(name, f"{block_descriptor} = {block_ptr_line}") + if indexing.can_lift: + self.prologue.writeline(line_body) + # Cache the descriptor for epilogue subtiling + self.prologue_cache[var] = block_descriptor + else: + self.body.writeline(line_body) + + if isinstance(indexing, BlockPtrOptions): + # Store for later use. If the buffer is removed the below advancements + # are no longer necessary + self.block_ptr_to_buffer[block_descriptor] = name + + # Generate block pointer advancements, for later use. + for symt in TritonSymbols.reduction_types: + advance_offsets = indexing.advance_roffset(symt) + + # Ignore identity advancements. + if all( + V.graph.sizevars.statically_known_equals( + offset, sympy.Integer(0) + ) + for offset in advance_offsets + ): + continue + + advancements = self.pointer_advancements[symt] + assert block_descriptor not in advancements, ( + f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'" + ) + advancements[block_descriptor] = advance_offsets + else: + block_descriptor = indexing.format(var) + return block_descriptor, other + + def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""): + # Stores require an explicit broadcast. We do this in two phases: + # 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK, + # YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads. + # 2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the + # result to the shape of the pointer. + value = f"tl.broadcast_to({value}, {indexing.final_shape})" + + # These dims no longer need broadcasting. + for idx, (dim, broadcast_dim) in enumerate( + zip(indexing.final_shape, indexing.broadcast_shape) + ): + if V.graph.sizevars.statically_known_equals(dim, broadcast_dim): + indexing.broadcasting_dims[idx] = False + + value = indexing.codegen_broadcast_and_reshape( + value, + indexing.final_shape, + indexing.block_shape, + allow_implicit=False, + for_store=True, + ) + + # workaround https://github.com/triton-lang/triton/issues/2814 + value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" + if isinstance(indexing, BlockPtrOptions): + return f"tl.store({block_ptr}, {value}{other})" + return f"{block_ptr}.store({V.kernel.index_to_str(indexing.offsets)}, {value})" + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + assert isinstance(expr, sympy.Expr) + indexing = self.indexing(expr, block_ptr=False, tma_compatibility_checker=None) + assert isinstance(indexing, IndexingOptions) + + index_str = indexing.index_str + mask_str = indexing.mask_str if indexing.has_mask() else None + size_str = texpr(self.rename_indexing(size)) if upper else None + + # expr is already wrapped + line = self.indirect_assert( + index_str, "0" if lower else None, size_str, mask_str + ) + + buffer = self.get_load_buffer(indexing) + self.cse.generate(buffer, line, assignment=False, dtype=torch.int32) + + def get_load_buffer(self, indexing): + if indexing.has_indirect() or indexing.has_tmpmask(): + # Masked loads must come after the mask is computed + return self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indexing.has_rindex() + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + return self.body + else: + return self.loads + + def _handle_pdl_before_load(self, wait_buffer): + GDC_WAIT = "tl.extra.cuda.gdc_wait()" + self._load_index += 1 + if self.inside_reduction: + wait_buffer = self.body + if enable_pdl_codegen(): + if self._load_index == 1: + wait_buffer.writeline(GDC_WAIT) + + def _handle_pdl_after_load(self, launch_buffer, result_var): + GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()" + if self.inside_reduction: + launch_buffer = self.post_loop_combine + if enable_pdl_codegen(): + current_load_index = self._load_index + launch_if_last_load = DelayMaybeLine( + lambda: current_load_index == self._load_index, + f"0; {GDC_LAUNCH} # gdc launch for {result_var}", + ) + self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32) + + def partial_accumulate( + self, name: str, reduction_type, val, extra_meta: dict[str, Any] + ): + self.saved_partial_accumulate.append( + PartialAccumulate(name, reduction_type, val) + ) + + def load(self, name: str, index: sympy.Expr): + """ + Load from the memory location 'name', offset by some indexing expression 'index'. + """ + var = self.args.input(name) + load_counts = self._load_counts + load_counts[name] += 1 + make_line: Callable[[str], Union[str, DelayReplaceLine]] = identity + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + dtype = V.graph.get_dtype(name) + indexing = self.indexing( + index, + block_ptr=True, + tma_compatibility_checker=self.tma_compatibility_checker_cls( + self, + dtype, + for_store=False, + force=False, + ), + ) + + if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim( + indexing.index + ): + self.has_load_with_contiguous_rdim = True + + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + + # Keep the variable in cache if were going to reuse it. Equiv., if any of the following hold + # 1) We are doing broadcasting + # 2) It is a non-coalesced load. The intuition is that if it's + # non-coalesced, we will likely load each element multiple times in + # practice. + # 3) It will be used later and it won't be CSE'd. Equiv., if all the following hold + # 3.1) We are in a reduction loop + # 3.2) Its not its last use + # 3.3) This load will not be lifted to the body + # + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + if self.is_broadcasted(original_index): + ep = ", eviction_policy='evict_last'" + elif not is_coalesced: + ep = ", eviction_policy='evict_last'" + elif self.inside_reduction and self.range_trees[-1].is_loop: + + def decide_later(): + if load_counts[name] > expected_count and ( + has_rindex or indirect_indexing + ): + return "evict_last" + return "evict_first" + + expected_count = load_counts[name] + ep = ", eviction_policy=''" + make_line = functools.partial(DelayReplaceLine, "", decide_later) + else: + ep = "" + + if (has_tmpmask or has_rindex) and indexing.has_mask(): + if self._load_other: + other = f", other={constant_repr(self._load_other)}" + else: + other = ", other=0.0" + else: + other = "" + + """Check if the buffer we're about to load, has + more than one read dependency + NOTE: enabled with env variable TORCHINDUCTOR_SKIP_L1 + """ + has_read_deps = True + if config.triton.skip_l1_cache: + buffer_read_counts = self.features.buffer_read_counts() + has_read_deps = buffer_read_counts[name] > 1 + """Skip L1 cache if we're (pretty?) sure the data is used only once + """ + skip_l1_cache = ( + not self.is_broadcasted(original_index) + and not self.inside_reduction + and not has_read_deps + and is_coalesced # for indirect loads is_coalesced is False? + ) + cachemod = "" + if skip_l1_cache: + cachemod = ", cache_modifier='.cg'" + + append_broadcast = None + shape: BlockShapeType = None + + if should_unwrap_unspec_arg(name): + line = var + # unwrapped bf16/fp16 0d tensors are passed in as float32 scalars + # see triton_utils.py:signature_of + if dtype in (torch.float16, torch.bfloat16): + if config.triton.codegen_upcast_to_fp32: + dtype = torch.float32 + else: + line += f".to({triton_type(dtype)})" + shape = () + + else: + if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)): + block_descriptor, other = self.codegen_block_ptr( + name, var, indexing, other + ) + if isinstance(indexing, BlockPtrOptions): + line = f"tl.load({block_descriptor}{other}{ep}{cachemod})" + else: + line = f"{block_descriptor}.load({V.kernel.index_to_str(indexing.offsets)})" + line = indexing.codegen_broadcast_and_reshape( + line, + indexing.block_shape, + indexing.final_shape, + allow_implicit=True, + for_store=False, + ) + shape = indexing.final_shape + elif is_sympy_integer_like(original_index): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = indexing.expand_str + shape = () + else: + line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" + + # The block shape of tl.load depends on the indexing expression. + # Inferring shape solely from the mask may miss cases where the mask is constant. + # Inferring from indexing.expand_shape alone may also fail when dense indexing is absent. + # so, iterate over variables in the indexexpr to accurately infer the block shape. + if indexing.expand_shape: + shape = indexing.expand_shape + else: + shape = TritonSymbols.get_block_shape(indexing.index) + + if ( + dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + line += ".to(tl.float32)" + dtype = torch.float32 + if dtype == torch.bool and torch.version.hip is None: + # Workaround for https://github.com/triton-lang/triton/issues/2151 + # tl.load returns int8 when loading from pointer to int1 + # NOTE: Currently causes hangs on bool UTs for ROCm + line += ".to(tl.int1)" + dtype = torch.bool + + load_buffer = self.get_load_buffer(indexing) + self._handle_pdl_before_load(load_buffer) + result_var = self.cse.generate( + load_buffer, make_line(line), dtype=dtype, shape=shape + ) + self._handle_pdl_after_load(load_buffer, result_var) + if result_var.use_count > 1: + load_counts[name] -= 1 # don't double count cache hit + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast: + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate( + load_buffer, line, dtype=dtype, shape=indexing.expand_shape + ) + if indexing.mask_vars: + if dtype.is_floating_point: + zero = "0.0" + elif dtype == torch.bool: + zero = "True" + else: + zero = "0" + other_val = ( + constant_repr(self._load_other) if self._load_other else zero + ) + line = f"tl.where({indexing.mask_str}, {result_var}, {other_val})" + result_var = self.cse.generate( + load_buffer, line, dtype=dtype, shape=result_var.shape + ) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + """ + store the 'value' to the memory location 'name', offset by some indexing expression 'index'. + """ + + var = self.args.output(name) + original_index = index + dtype = V.graph.get_dtype(name) + + tma_compatibility_checker = None + if mode is None or mode == "tma": + force = mode == "tma" + tma_compatibility_checker = self.tma_compatibility_checker_cls( + self, + dtype, + for_store=True, + force=force, + ) + indexing = self.indexing( + index, + dense_indexing=True, + block_ptr=mode is None, + tma_compatibility_checker=tma_compatibility_checker, + ) + + if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim( + indexing.index + ): + self.stores_with_contiguous_rdim.append(name) + + # Guard against write-after-read corruption in triton. + # See # https://github.com/triton-lang/triton/issues/1615 + # This triton bug means that a load which is broadcasted over multiple + # warps may see the result of a store that happens later in the triton + # program. The workaround is to add a barrier before storing, which + # enforces that all warps have already read the data. + is_inplace = name in self.args.inplace_buffers + is_broadcasted = self.is_broadcasted(original_index) + if is_inplace and is_broadcasted: + self.stores.writeline(DeferredLine(name, "tl.debug_barrier()")) + + if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)): + block_descriptor, other = self.codegen_block_ptr(name, var, indexing) + # block_ptr / tma descriptor stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_descriptor, value, other + ) + elif mode is None: + # If indexing is an integer and value has block shape larger than one, + # broadcasting fails. So, we manually broadcast indexing to the value shape. + # Without broadcast : + # tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32)), tmp4, xmask) # Fail + # + # With broadcast: + # tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32).broadcast_to((XBLOCK,1)), tmp4, xmask) + indexing_str = indexing.index_str + if ( + is_sympy_integer_like(index) + and value.shape is not None + and not all(str(x) == "1" for x in value.shape) + ): + value_shape = ", ".join(map(str, value.shape)) + indexing_str += f".broadcast_to({value_shape})" + line = f"tl.store({var} + ({indexing_str}), {value}, {indexing.mask_str})" + elif mode == "atomic_add": + self.atomic_add_found = True + indexing_str = indexing.index_str + if ( + is_sympy_integer_like(index) + and value.shape is not None + and not all(str(x) == "1" for x in value.shape) + ): + value_shape = ", ".join(map(str, value.shape)) + indexing_str += f".broadcast_to({value_shape})" + line = f"tl.atomic_add({var} + ({indexing_str}), {value}, {indexing.mask_str}, sem='relaxed')" + else: + raise NotImplementedError(f"store mode={mode}") + + exit_stack = contextlib.ExitStack() + if not self.inside_reduction and self.cooperative_reduction: + exit_stack.enter_context(self.guard_cooperative_store(name, self.stores)) + + self.stores.writeline(DeferredLine(name, line)) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + exit_stack.close() + + def device_assert_async(self, cond, msg) -> None: + self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})") + + def guard_cooperative_store(self, name, buffer): + """ + For cooperative reductions only one thread block should write out the result. + We rotate which thread block does each write for better parallelism + """ + idx = self.cooperative_reduction_workspace_cache.increment_store_count() + buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):")) + return buffer.indent() + + def _combine_masks(self, *variables: Optional[CSEVariable]): + masks = None + for elem in variables: + if elem is None: + continue + if hasattr(elem, "mask_vars"): + if masks is None: + masks = elem.mask_vars + else: + masks = masks | elem.mask_vars + return masks + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + + # Triton performance for bucketize_binary_search is much better when the number + # of threads equals the number of elements. + # If we're trying to use a bucketize kernel, we should make sure that an + # autotuning config with num_elements_per_warp=(warp_size) exists. + self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD) + + boundaries_ptr = self.args.input(boundaries[0]) + boundary_size = self.index_to_str(boundaries[1]) + boundaries_underlying_numel = self.index_to_str(boundaries[2]) + boundary_stride = self.index_to_str(boundaries[3]) + sorter_ptr = self.args.input(sorter[0]) if sorter else "None" + sorter_stride = self.index_to_str(sorter[1]) if sorter else "None" + + if indexing_dtype == torch.int32: + triton_dtype = "tl.int32" + elif indexing_dtype == torch.int64: + triton_dtype = "tl.int64" + else: + raise NotImplementedError( + "Bucketize only supports indexing with int32 and int64" + ) + + self._handle_pdl_before_load(self.compute) + result = self.cse.generate( + self.compute, + f"triton_helpers.bucketize_binary_search({values}, " + f"{boundaries_ptr}, {boundary_size}, {boundaries_underlying_numel}, {boundary_stride}, " + f"{boundary_indices}, " + f"{triton_dtype}, " + f"{right}, " + f"{sorter_ptr}, {sorter_stride}, " + f"{sorter_indices}, " + ")", + dtype=indexing_dtype, # type: ignore[attr-defined] + shape=values.shape, + ) + self._handle_pdl_after_load(self.compute, result) + + masks = self._combine_masks(values, boundary_indices, sorter_indices) + result.mask_vars = masks # type: ignore[attr-defined] + + return result + + def reduction_resize(self, value) -> str: + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + + nreduce = self.num_reduction_dims + sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce + return f"{value}[{', '.join(sizes)}]" + + def reduction_resize_and_shape(self, value, shape) -> tuple[str, BlockShapeType]: + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})", shape + + nreduce = self.num_reduction_dims + sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce + new_shape = ( + (*shape[: (ndims - nreduce)], *[1] * nreduce) if shape is not None else None + ) + return f"{value}[{', '.join(sizes)}]", new_shape + + def reduction_collapse_dims( + self, buffer, value: CSEVariable, dtype: torch.dtype + ) -> CSEVariable: + """ + Reshape to RBLOCK, collapsing all reduction dims. + """ + # This is not needed for 1D reductions. + if self.num_reduction_dims == 1: + return value + + target_ndim = self.triton_tensor_ndim() - self.num_reduction_dims + initial_shape = self.dense_size_list() + target_shape = initial_shape[:target_ndim] + ["RBLOCK"] + return self.cse.generate( + buffer, + triton_reshape(str(value), initial_shape, target_shape), + dtype=dtype, + shape=tuple(target_shape), + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """ + codegen reduction of value to Triton according the reduction_type + """ + + def maybe_upcast(value: CSEVariable) -> CSEVariable: + # Math reductions in FP16/BF16 are less accurate because the Triton compiler does not + # automatically promote to FP32 for accumulation. Additionally, max/min reductions + # do not support FP16/BF16. We manually promote to FP32 here. + return ( + ops.to_dtype(value, torch.float32) + if value.dtype + in [ + torch.float16, + torch.bfloat16, + ] + else value + ) + + original_dtypes = [val.dtype for val in pytree.tree_leaves(value)] + value = pytree.tree_map(maybe_upcast, value) + if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes): + # Only promote FB16/BF16; do not promote other integer/boolean dtypes + src_dtype = torch.promote_types(src_dtype, torch.float32) + dtype = torch.promote_types(dtype, torch.float32) + + assert self.inside_reduction + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix[0] + + # When we do native matmtul codegen, + # we don't want to keep the R0_BLOCK/R1_BLOCK in the accumulator. + # so instead of naively calling dense_size_str(), we filter out + # reduction block from accumulator and only keep (Y,X). + # In bmm (Z,Y,R)x(Z,R,X) case, we also remove z dimension from accumulator + # because 3d (Z,Y,X) tl.dot is somehow slower than 2d tl.dot. + # Instead, we force ZBLOCK to be always 1 during autotune. + dense_size_str: str + if self.is_native_matmul: + dense_sizes = self.dense_size_list() + assert len(dense_sizes) >= 3 + xy_sizes_only = [size for size in dense_sizes if "X" in size or "Y" in size] + dense_size_str = f"[{', '.join(xy_sizes_only)}]" + value_shape = tuple(xy_sizes_only) + else: + dense_size_str = self.dense_size_str() + value_shape = tuple(self.dense_size_list()) + + # Say we have + # tmp0 = ops.constant(1, torch.int64) + # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0) + # tmp0 in the triton code is either a scalar, or single-element tensor + # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1 + # To avoid this, we broadcast to the expected shape first. + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, + f"tl.broadcast_to({v}, {dense_size_str})", + dtype=v.dtype, + shape=value_shape, + ), + value, + ) + + logical_index = None + if reduction_type in ("argmin", "argmax"): + if isinstance(value, tuple): + value, logical_index = value + + dim = self.triton_tensor_ndim() - self.num_reduction_dims + root_op: str + + def final_reduction( + buffer, + value: CSEVariable, + result_type: Optional[torch.dtype], + ) -> tuple[str, Optional[torch.dtype], BlockShapeType]: + """ + Helper to generate a reduction call, e.g. tl.sum. + """ + triton_reduction_fn = get_triton_reduction_function(reduction_type) + + value = self.reduction_collapse_dims(buffer, value, dtype) + if reduction_type == "dot": + # Native matmul is a special case because accumulator shape is fixed to (Y,X) + is_bmm = len(self.dense_size_list()) == 4 + assert value.shape is not None + if is_bmm: + result = f"{value}[None,:,:,None]" # (Y,X) to (Z=1,Y,X,R=1) + shape = [1, *value.shape, 1] + else: + result = f"{value}[:,:,None]" # (Y,X) to (Y,X,R=1) + shape = [*value.shape, 1] + else: + result, shape = self.reduction_resize_and_shape( # type: ignore[assignment] + f"{triton_reduction_fn}({value}, {dim})", value.shape + ) + + if result_type is not None: + result = f"{result}.to({self.dtype_to_str(result_type)})" + else: + result_type = value.dtype + + return result, result_type, shape + + def final_reduction_define( + buffer, + result_var: CSEVariable, + value: CSEVariable, + result_type: Optional[torch.dtype], + ) -> None: + """ + Generate a reduction and assign it to an existing variable. + """ + value, _, _ = final_reduction(buffer, value, result_type) + buffer.splice(f"{result_var} = {value}") + + def final_argreduce(buffer, result_var, value, index): + value = self.reduction_collapse_dims(buffer, value, dtype) + index = self.reduction_collapse_dims(buffer, index, dtype) + buffer.splice( + f"""\ + {result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f"{result_var}_idx")} + """ + ) + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + acc_type = triton_acc_type(src_dtype) + torch_acc_type = upcast_acc_dtype(src_dtype) + result_shape = list(self.dense_size_list()) + result_shape[dim] = "1" + result_var: Any = self.cse.newvar( + dtype=torch_acc_type, shape=tuple(result_shape) + ) + result_var.mask_vars = OrderedSet( + var for var in masks if not prefix_is_reduction(var[0]) + ) + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + + def update_constant_dtype(constant, src_dtype, dst_dtype): + "update reduction constant mask value to match dst_dtype" + + # int is the only mask which may not fit within lower bitwidth, + # because float uses inf/-inf + if src_dtype.is_floating_point or src_dtype == torch.bool: + return constant + + if src_dtype == dst_dtype or constant == 0: + return constant + + if constant == torch.iinfo(src_dtype).max: + return torch.iinfo(dst_dtype).max + elif constant == torch.iinfo(src_dtype).min: + return torch.iinfo(dst_dtype).min + else: + return constant + + def _mask_value(value, default) -> CSEVariable: + default = update_constant_dtype(default, src_dtype, value.dtype) + default_str = self._map_tuple_or_scalar(constant_repr, default) + + return self.cse.generate( + self.compute, + where_cond(value, default_str), + dtype=value.dtype, + shape=value.shape, + ) + + masked_value: Union[CSEVariable, Sequence[CSEVariable]] + if reduction_type == "online_softmax_reduce": + # Don't generate mask value for online_softmax since we + # will fallback below + pass + elif isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] # type: ignore[arg-type] + elif reduction_type == "dot": + # Here, we don't perform the masking. + # Masking w/ where condition in native matmul is handled in ops.dot codegen. + # Since tl.dot performs reduction within the triton block, + # masking should happen before the tl.dot is called. + masked_value = self.cse.generate(self.compute, value, dtype=value.dtype) + else: + masked_value = _mask_value(value, default) + + if reduction_type in ("argmax", "argmin"): + assert isinstance(masked_value, CSEVariable) + accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() + if logical_index: + accumulator_index = f"({str(logical_index)}).to({self.dtype_to_str(accumulator_dtype)})" + else: + accumulator_index = str( + self.cse.generate( + self.compute, + f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + dtype=accumulator_dtype, + shape=masked_value.shape, + ) + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + result_var.dtype = accumulator_dtype + elif reduction_type == "welford_reduce": + if self.cooperative_reduction: + # cooperative reductions require full welford for correctness + result_var = self.welford_reduce( + result_var, reduction_type, value, where_cond, acc_type, dtype + ) + else: + # For persistent reductions, don't bother with + # welford's algorithm since it uses more registers, and + # taking two reductions doesn't increase memory usage. + result_var = self.welford_reduce_fallback(dtype, value) + elif reduction_type == "welford_combine": + assert isinstance(masked_value, Sequence) + (mean, m2, weight) = masked_value + result_var = tuple( + self.cse.generate(self.compute, value, dtype=dtype, shape=shape) + for value, shape in self._welford( + self.compute, mean, m2, weight, dim, dtype + ) + ) + elif reduction_type == "online_softmax_reduce": + # All data is loaded to register anyway, no need to do + # online softmax + result_var = self.prepare_softmax_twopass_fallback(dtype, value) + else: + assert isinstance(masked_value, CSEVariable) + _result, _dtype, _shape = final_reduction( + self.compute, masked_value, masked_value.dtype + ) + result_var = self.cse.generate( + self.compute, _result, dtype=_dtype, shape=_shape + ) + else: + accumulator = self.cse.namedvar( + f"_{result_var}", + dtype=torch_acc_type, + shape=tuple(self.dense_size_list()), + ) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + if reduction_type == "dot": + dense_sizes = self.dense_size_list() + assert len(dense_sizes) >= 3 + xy_sizes_only = [ + size for size in dense_sizes if "X" in size or "Y" in size + ] + accumulator.shape = tuple(xy_sizes_only) + dense_size_str = f"[{', '.join(xy_sizes_only)}]" + self.body.writeline( + f"{accumulator} = tl.full({dense_size_str}, {default}, {acc_type})" + ) + else: + self.body.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in ("argmax", "argmin"): + accumulator_index = f"_{result_var}_index" + index_dtype = self.features.select_index_dtype() + self.body.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, " + f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + # Use logical_index if it was unpacked, otherwise fall back to physical index + index_var = ( + f"({str(logical_index)}).to({self.dtype_to_str(index_dtype)})" + if logical_index is not None + else f"{reduction_range_prefix}index" + ) + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {index_var} + ) + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)} + """ + ) + final_argreduce( + self.post_loop_combine, result_var, accumulator, accumulator_index + ) + elif is_welford_reduction(reduction_type): + result_var = self.welford_reduce( + result_var, reduction_type, value, where_cond, acc_type, dtype + ) + elif reduction_type == "online_softmax_reduce": + accumulator_max = f"_{result_var}_max" + accumulator_sum = f"_{result_var}_sum" + + # setup accumulator + self.body.writeline( + f"{accumulator_max} = tl.full({self.dense_size_str()}, float('-inf'), {acc_type})" + ) + self.body.writeline( + f"{accumulator_sum} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + + # combine + # Note, we pass config.use_fast_math to the JITFunction + # since a triton kernel can not access a config. + self.compute.splice( + f""" + {accumulator_max}_next, {accumulator_sum}_next = triton_helpers.online_softmax_combine( + {accumulator_max}, {accumulator_sum}, {value}, {config.use_fast_math} + ) + """ + ) + + # mask + self.compute.splice( + f""" + {accumulator_max} = {where_cond(f"{accumulator_max}_next", accumulator_max)} + {accumulator_sum} = {where_cond(f"{accumulator_sum}_next", accumulator_sum)} + """ + ) + + # reduce. Similar to the final reduction for coopereative + # reduction + result_max = result_var + result_sum = self.cse.newvar(dtype=dtype, shape=result_max.shape) + + result_var = self.online_softmax_reduce_final_reduction( + self.post_loop_combine, + result_max, + result_sum, + accumulator_max, + accumulator_sum, + dim, + dtype, + ) + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + if reduction_type == "dot": + self.compute.writeline(f"{accumulator} = {updated}") + else: + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + # This is only really used for aten.any. It changes the + # final reduction of a non-persistent reduction from + # tmp5 = triton_helpers.max(_tmp5, 1)[:, None] + # to + # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) + # which is needed because tl.reduce doesn't support tl.int1 + accumulator = self.cse.generate( + self.post_loop_combine, + f"{accumulator}.to(tl.int8)", + dtype=torch.int8, + shape=accumulator.shape, + ) + + final_reduction_define( + self.post_loop_combine, result_var, accumulator, None + ) + + if self.cooperative_reduction: + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + exit_stack = contextlib.ExitStack() + for buf in (self.post_loop_combine, self.post_loop_store): + # only do cooperative reduction combines if we have more than one thread block + buf.writeline("if HAS_RSPLIT:") + exit_stack.enter_context(buf.indent()) + + if reduction_type in ("argmax", "argmin"): + self.post_loop_combine.writeline( + f"{result_var}_bval = {self.reduction_resize(f'{result_var}_val')}" + ) + peer_val = self.codegen_cooperative_reduction_peer_combine( + f"{result_var}_bval", src_dtype, default + ) + index_dtype = self.features.select_index_dtype() + peer_idx = self.codegen_cooperative_reduction_peer_combine( + result_var, index_dtype, torch.iinfo(index_dtype).max + ) + final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx) + elif is_welford_reduction(reduction_type): + assert reduction_type == "welford_reduce" + result_mean, result_m2, result_weight = result_var + peer_mean = self.codegen_cooperative_reduction_peer_combine( + result_mean, + upcast_acc_dtype(src_dtype), + default[0], # type: ignore[index] + ) + peer_m2 = self.codegen_cooperative_reduction_peer_combine( + result_m2, + upcast_acc_dtype(src_dtype), + default[1], # type: ignore[index] + ) + peer_weight = self.codegen_cooperative_reduction_peer_combine( + result_weight, + upcast_acc_dtype(src_dtype), + default[2], # type: ignore[index] + ) + self.welford_reduce_final_reduction( + self.post_loop_store, + result_mean, + result_m2, + result_weight, + peer_mean, + peer_m2, + peer_weight, + dim, + dtype, + ) + elif reduction_type == "online_softmax_reduce": + result_max, result_sum = result_var + assert isinstance(default, Sequence) + peer_max = self.codegen_cooperative_reduction_peer_combine( + result_max, upcast_acc_dtype(src_dtype), default[0] + ) + peer_sum = self.codegen_cooperative_reduction_peer_combine( + result_sum, upcast_acc_dtype(src_dtype), default[1] + ) + self.online_softmax_reduce_final_reduction( + self.post_loop_store, + result_max, + result_sum, + peer_max, + peer_sum, + dim, + dtype, + ) + else: + peers = self.codegen_cooperative_reduction_peer_combine( + result_var, upcast_acc_dtype(src_dtype), default + ) + final_reduction_define(self.post_loop_store, result_var, peers, None) + exit_stack.close() + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + assert all(isinstance(x, TritonCSEVariable) for x in result_var) + self.outside_loop_vars.update(result_var) + + # Match output dtype with input dtype + if reduction_type in ("welford_reduce", "online_softmax_reduce"): + assert len(original_dtypes) == 1 + original_dtypes = len(result_var) * original_dtypes + + assert len(result_var) == len(original_dtypes) + for var, orig_dtype in zip(result_var, original_dtypes): + assert orig_dtype is not None + if var.dtype != orig_dtype: + self.post_loop_combine.writeline( + f"{var} = {var}.to({triton_compute_type(orig_dtype)})" + ) + else: + assert isinstance(result_var, TritonCSEVariable) + self.outside_loop_vars.add(result_var) + + # Match output dtype with input dtype + if result_var.dtype != original_dtypes[0]: + assert original_dtypes[0] is not None + self.post_loop_combine.writeline( + f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})" + ) + + return result_var + + def _online_softmax_reduce( + self, buffer, accumulator_max, accumulator_sum, dim, dtype: torch.dtype + ): + accumulator_max = self.reduction_collapse_dims(buffer, accumulator_max, dtype) + accumulator_sum = self.reduction_collapse_dims(buffer, accumulator_sum, dtype) + result_max, result_sum = [str(self.cse.newvar(dtype=dtype)) for _ in range(2)] + buffer.splice( + f""" + {result_max}, {result_sum} = triton_helpers.online_softmax_reduce( + {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math}) + {result_max} = {self.reduction_resize(f"{result_max}")} + {result_sum} = {self.reduction_resize(f"{result_sum}")} + """ + ) + + return result_max, result_sum + + def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype): + """ + Helper to codegen triton_helpers.welford. + """ + mean, m2, weight = ( + self.reduction_collapse_dims(buffer, value, dtype) + for value in (mean, m2, weight) + ) + welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" + + def reduced_shape(shape): + return tuple(shape[0:dim] + shape[dim + 1 :]) + + welford_results = [ + self.cse.newvar(dtype=dtype, shape=reduced_shape(value.shape)) + for value in (mean, m2, weight) + ] + buffer.writeline(f"{', '.join([str(r) for r in welford_results])} = {welford}") + + return tuple( + self.reduction_resize_and_shape(value, value.shape) + for value in welford_results + ) + + def welford_reduce( + self, result_var, reduction_type, value, where_cond, acc_type, dtype + ): + """Helper to codegen a welford reduction""" + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + accumulator = TritonCSEVariable( + f"{result_var}_mean", + shape=tuple(self.dense_size_list()), + dtype=acc_type, + bounds=ValueRanges.unknown(), + ) + accumulator_m2 = TritonCSEVariable( + f"{result_var}_m2", + shape=tuple(self.dense_size_list()), + dtype=acc_type, + bounds=ValueRanges.unknown(), + ) + accumulator_weight = TritonCSEVariable( + f"{result_var}_weight", + shape=tuple(self.dense_size_list()), + dtype=acc_type, + bounds=ValueRanges.unknown(), + ) + self.body.writeline( + f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + if reduction_type == "welford_combine": + mean, m2, weight = value + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine( + {accumulator}, {accumulator_m2}, {accumulator_weight}, + {mean}, {m2}, {weight} + ) + """ + ) + else: + assert reduction_type == "welford_reduce" + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce( + {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0 + ) + """ + ) + self.compute.splice( + f"""\ + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)} + {accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)} + """ + ) + result_mean = result_var + return self.welford_reduce_final_reduction( + self.post_loop_combine, + result_mean, + None, + None, + accumulator, + accumulator_m2, + accumulator_weight, + dim, + dtype, + ) + + def welford_reduce_final_reduction( + self, + buffer, + result_mean, + result_m2, + result_weight, + mean, + m2, + weight, + dim, + dtype, + ): + """Helper to codegen call to triton_helpers.welford""" + values = list(self._welford(buffer, mean, m2, weight, dim, dtype)) + + result_exprs = [result_mean, result_m2, result_weight] + for i, (result_expr, (value, shape)) in enumerate(zip(result_exprs, values)): + if result_expr is None: + result_expr = self.cse.newvar(dtype=dtype, shape=shape) + result_exprs[i] = result_expr + buffer.splice(f"{result_expr} = {value}") + + return tuple(result_exprs) + + def online_softmax_reduce_final_reduction( + self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype + ): + accumulator_max = self.reduction_collapse_dims(buffer, peer_max, dtype) + accumulator_sum = self.reduction_collapse_dims(buffer, peer_sum, dtype) + buffer.splice( + f""" + {result_max}, {result_sum} = triton_helpers.online_softmax_reduce( + {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math}) + {result_max} = {self.reduction_resize(f"{result_max}")} + {result_sum} = {self.reduction_resize(f"{result_sum}")} + """ + ) + return result_max, result_sum + + def max_rsplit(self): + if self.fixed_config: + return self.fixed_config["RSPLIT"] + return TRITON_MAX_RSPLIT + + def codegen_cooperative_reduction_peer_combine( + self, result_var, dtype, default_val + ) -> CSEVariable: + """ + Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different + column. After the barrier, every thread block loads the completed value so that it can compute the final + value independently. + """ + xnumel = self.numels["x"] + mask = "xindex < xnumel" if not self._has_constant_xmask() else None + + nbytes = xnumel * dtype.itemsize * self.max_rsplit() + ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes) + + self.post_loop_combine.splice( + f""" + {result_var}_ws = ({ws_name} + {self.index_to_str(ws_offset)}).to(tl.pointer_type({triton_type(dtype)})) + tl.store({result_var}_ws + (xindex * RSPLIT + rsplit_id), {result_var}, {mask}) + """, + strip=True, + ) + peers = self.create_cse_var( + f"{result_var}_peers", + shape=["XBLOCK", "RSPLIT"], + dtype=dtype, + bounds=ValueRanges.unknown(), + ) + self.post_loop_store.writeline( + f"{peers} = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), " + f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))" + ) + return peers + + def store_reduction( + self, + name: str, + index: sympy.Expr, + value: CSEVariable, + ): + assert self.inside_reduction + self.inside_reduction = False + dtype = V.graph.get_dtype(name) + indexing = self.indexing( + index, + block_ptr=True, + tma_compatibility_checker=self.tma_compatibility_checker_cls( + kernel=self, + dtype=dtype, + for_store=True, + force=False, + ), + ) + self.inside_reduction = True + var = self.args.output(name) + + exit_stack = contextlib.ExitStack() + if self.cooperative_reduction: + exit_stack.enter_context( + self.guard_cooperative_store(name, self.post_loop_store) + ) + + if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)): + self.post_loop_store.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + assert isinstance(indexing, IndexingOptions) + + indexing_str = indexing.index_str + if ( + is_sympy_integer_like(index) + and value.shape is not None + and not all(str(x) == "1" for x in value.shape) + ): + value_shape = ", ".join(map(str, value.shape)) + indexing_str += f".broadcast_to({value_shape})" + + self.post_loop_store.writeline( + DeferredLine( + name, + f"tl.store({var} + ({indexing_str}), {value}, {indexing.mask_str})", + ) + ) + + exit_stack.close() + + def _lift_helper( + self, fn, values: tuple[CSEVariable, ...], dtypes: tuple[torch.dtype, ...] + ) -> str: + # Lift IR function for scan operations into a triton function + # in the global namespace + helper = IndentedBuffer() + helper.writeline("@triton.jit") + cse = CSE() + + args = [ + tuple( + cse.namedvar(f"arg{i}_{n}", dtype=dtype, shape=value.shape) + for n, (value, dtype) in enumerate(zip(values, dtypes)) + ) + for i in range(2) + ] + signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args)) + helper.writeline(f"def {{name}}({signature}):") + + overrides = TritonOverrides() + + # Build a name that changes depending on fn to workaround a triton bug + # where the combine_fn to reduce and scan is not hashed, and so different + # scan ops may collide in the triton cache. + # This is fixed with the latest triton pin, but not the triton-rocm pin. + helper_name = "_triton_helper_fn" + + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + from torch._inductor.shape_propagation import ShapePropagationOpsHandler + + shape_handler = ShapePropagationOpsHandler() + dtype_handler = DtypePropagationOpsHandler() + + class CSEProxy(DefaultHandler): + def _default( + self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + nonlocal helper_name + helper_name += f"_{name}" + + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + + output_shape = getattr( + shape_handler, + name, + )(*args, **kwargs) + + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + dtype=output_dtype, + shape=output_shape, + ) + + with helper.indent(), V.set_ops_handler(CSEProxy()): + outputs = fn(*args) + outputs = ", ".join(str(output) for output in outputs) + helper.writeline(f"return {outputs}") + + return self.helper_functions.add(helper.getvalue(), base_name=helper_name) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] + ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + """ + Perform an associative scan on 'values'. + """ + assert self.inside_reduction + assert not self.cooperative_reduction, "TODO" + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + broadcasted_values = [] + accumulators = [] + + dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) + cse_compute = functools.partial(self.cse.generate, self.compute) + combine_helper_fn = self._lift_helper(combine_fn, values, dtypes) + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + for value, dtype in zip(values, dtypes): + value_dtype = self.cse.generate( + self.compute, + f"{value}.to({triton_compute_type(dtype)})", + dtype=dtype, + shape=value.shape, + ) + value = self.cse.generate( + self.compute, + f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", + dtype=dtype, + shape=tuple(self.dense_size_list()), + ) + broadcasted_values.append(value) + + acc_type = triton_acc_type(dtype) + + if not self.persistent_reduction: + reduced_size = self.dense_size_list() + reduced_size[-1] = "1" + accumulator = self.cse.newvar(dtype=dtype, shape=reduced_size) + reduced_size_str = f"[{', '.join(reduced_size)}]" + + default = "float('nan')" if dtype.is_floating_point else "-1" + self.body.writeline( + f"{accumulator} = tl.full({reduced_size_str}, {default}, {acc_type})" + ) + + accumulators.append(accumulator) + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, values, masks, dtypes): + n = len(values) + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [ + self.cse.newvar(dtype=dtype, shape=value.shape) + for (dtype, value) in zip(dtypes, values) + ] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.put(cache_key, result_var) + return tuple(result_vars) + + partial_scan_vars = cse_multiple( + f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", + broadcasted_values, + masks, + dtypes, + ) + + if not self.persistent_reduction: + # tl.reduce doesn't work for non-commutative operators, so instead + # of repeating the scan op as a reduction, we use sum to select the + # last scan value + def _partial_scan_shape(var): + if var.shape is None: + return None + else: + shape = list(var.shape) + shape[-1] = "1" + return shape + + partial_reduce_vars = [ + cse_compute( + f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)", + dtype=upcast_compute_type(partial_scan_var.dtype), + shape=_partial_scan_shape(partial_scan_var), + ) + for partial_scan_var in partial_scan_vars + ] + accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars)) + full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) + result_vars = [ + cse_compute( + f"tl.where(roffset > 0, {full_scan}, {partial_scan})", + dtype=partial_scan.dtype, + shape=partial_scan.shape, + ) + for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) + ] + for acc_next, accumulator, partial_reduce in zip( + accs_next, accumulators, partial_reduce_vars + ): + self.compute.writeline( + f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})" + ) + else: + result_vars = partial_scan_vars + + for result_var in result_vars: + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = OrderedSet(masks) + + return tuple(result_vars) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert not self.cooperative_reduction, "TODO" + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.sort not supported inside ops.masked" + assert self.persistent_reduction, ( + "ops.sort is only supported in persistent reductions" + ) + + cse_compute = functools.partial(self.cse.generate, self.compute) + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) + assert len(dtypes) == len(values) + broadcasted_values = [ + cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", + dtype=dtypes[i], + shape=tuple(self.dense_size_list()), + ) + for i, value in enumerate(values) + ] + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, broadcasted_values, masks, dtypes): + n = len(broadcasted_values) + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [ + self.cse.newvar(dtype=dtype, shape=value.shape) + for dtype, value in zip(dtypes, broadcasted_values) + ] # type: ignore[attr-defined] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.put(cache_key, result_var) + return tuple(result_vars) + + assert self.range_trees[-1].is_reduction + rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel" + + if len(values) == 2: + line = ( + f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," + f" {rnumel}, {dim}, stable={stable}, descending={descending})" + ) + result_vars = cse_multiple(line, broadcasted_values, masks, dtypes) + else: + raise AssertionError("Unhandled sort") + + for result_var, input_var in zip(result_vars, values): + result_var.mask_vars = masks # type: ignore[attr-defined] + result_var.bounds = input_var.bounds + + return tuple(result_vars) + + def codegen_prologue(self, code: IndentedBuffer): + """ + Generate the output from prologue. This should be + extracted from the subgraph, which is why this is + partitioned from codegen_body. + """ + if not self.prologue: + return + + code.splice(self.prologue) + self.prologue.clear() + self.prologue_cache.clear() + + def codegen_body(self): + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if not ( + self.indexing_code + or self.loads + or self.stores + or self.compute + or self.post_loop_combine + or self.post_loop_store + ): + return + + loop_trees = [tree for tree in self.range_trees if tree.is_loop] + if self.mix_order_reduction: + assert self.persistent_reduction, ( + "Mix order reduction requires persistent reduction" + ) + accumname2var = {} + for idx, partial_accum in enumerate(self.saved_partial_accumulate): + reduction_type = partial_accum.reduction_type + default = ir.Reduction.default_accumulator(reduction_type, torch.float) + default = self._map_tuple_or_scalar(constant_repr, default) + name = f"accum{idx}" + self.body.writeline( + f"{name} = tl.full([R0_BLOCK], {default}, tl.float32)[None, :]" + ) + accumname2var[name] = self.cse.namedvar( + name, dtype=torch.float, shape=("1", "R0_BLOCK") + ) + self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)") + self.body.writeline( + "for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):" + ) + with self.body.indent(offset=1): + # generate xmask if it's not constant + if not self._has_constant_xmask(): + entry = self.range_trees[0] + assert entry.prefix == "x" + x = entry.prefix + self.body.writeline(f"{x}mask = {entry.name} < {x}numel") + self.body.splice(self.indexing_code) + self.body.writelines( + [ + "xindex += XBLOCK", + ] + ) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.body.splice(self.post_loop_store) + + # no need to sum if XBLOCK == 1, or does that matter? + for idx, partial_accum in enumerate(self.saved_partial_accumulate): + var = partial_accum.value + name = f"accum{idx}" + combine_fn = ir.get_reduction_combine_fn( + partial_accum.reduction_type, torch.float + ) + triton_reduction_function = get_triton_reduction_function( + partial_accum.reduction_type, + ) + newval = self.cse.generate( + self.body, + f"{triton_reduction_function}({var}, 0)", + dtype=var.dtype, + shape=("R0_BLOCK",), + ) + import unittest + + with unittest.mock.patch.object(self, "compute", self.body): + updated = combine_fn( + accumname2var[name], + newval, + ) + self.body.writeline(f"{name} = {updated}") + + for idx in range(len(self.saved_partial_accumulate)): + self.body.writeline( + f"tl.store(ws_ptr + (tl.program_id(0) + {idx} * tl.num_programs(0)) * r0_numel + r0_index, accum{idx}, r0_mask)" + ) + + elif self.inside_reduction and len(loop_trees) > 0: + # Write the loop headers. + for level, tree in enumerate(loop_trees): + with self.body.indent(offset=level): + prefix = tree.prefix + loop_start = "rsplit_start" if self.cooperative_reduction else "0" + loop_end = ( + "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" + ) + self.body.writeline( + f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" + ) + with self.body.indent(offset=level + 1): + self.iteration_ranges_codegen_header(tree, self.body) + + # The innermost loop performs the reduction. + with self.body.indent(offset=len(loop_trees)): + self.codegen_reduction_indices(self.body) + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + # Write loop suffixes. + for level, tree in reversed([*enumerate(loop_trees)]): + with self.body.indent(offset=level + 1): + # Advance pointers at the end of each loop. + for block_ptr, advancement in self.pointer_advancements[ + tree.symt + ].items(): + # Subtract any advancements made in the previous loop level. + if level < len(loop_trees) - 1: + prev_tree = loop_trees[level + 1] + prev_advancement = self.pointer_advancements[ + prev_tree.symt + ][block_ptr] + prev_block = TritonSymbols.get_block_size(prev_tree) + prev_num_iter = CeilDiv(prev_tree.numel, prev_block) + advancement = [ + cur - prev * prev_num_iter + for cur, prev in zip(advancement, prev_advancement) + ] + + self.body.writeline( + DeferredLine( + self.block_ptr_to_buffer[block_ptr], + f"{block_ptr} = tl.advance({block_ptr}, {V.kernel.index_to_str(advancement)})", + ) + ) + + # Invalidate any cache entries that came from inside the loop. + self.cse.invalidate(self.outside_loop_vars) + tree.cache_clear() + else: + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.body.splice(self.post_loop_combine) + if self.cooperative_reduction and ( + self.post_loop_combine or self.post_loop_store + ): + sem_ptr = f"{self.semaphores_name} + tl.program_id(1)" + self.body.splice( + f""" + if HAS_RSPLIT: + triton_helpers.x_grid_barrier({sem_ptr}) + """, + strip=True, + ) + self.cooperative_reduction_workspace_cache.on_loop_end() + if not self.mix_order_reduction: + self.body.splice(self.post_loop_store) + self.indexing_code.clear() + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_combine.clear() + self.post_loop_store.clear() + + def kernel_benchmark_extra_args(self) -> list[str]: + args = [] + if self.need_numel_args(): + numel_args: list[sympy.Expr] = [] + self.add_numel_to_call_args("", numel_args, []) + for arg in numel_args: + if isinstance(arg, int): + args.append(str(arg)) + elif isinstance(arg, SymbolicCallArg): + hint = V.graph.sizevars.size_hint( + arg.inner_expr, + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + args.append(str(hint)) + elif isinstance(arg, sympy.Expr): + hint = V.graph.sizevars.size_hint( + arg, + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + args.append(str(hint)) + else: + raise ValueError(f"Unsupported numel argument type: {type(arg)}") + return args + + def codegen_kernel_benchmark(self, num_gb: Optional[float]) -> IndentedBuffer: + """ + Generates Python code for benchmarking this Triton kernel. + - Creates example inputs (random tensors, constants, sizes). + - Runs the kernel on the current GPU/stream. + - Prints runtime (ms) and throughput (GB/s) using `num_gb`. + Args: + num_gb (float): The number of gigabytes to use for throughput calculation. + Returns: + IndentedBuffer: A buffer containing the generated Python benchmark code. + """ + result = IndentedBuffer() + _argdefs, call_args, signature, _ = self.args.python_argdefs() + + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + size = V.graph.sizevars.size_hints( + buf.get_size(), + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + stride = V.graph.sizevars.size_hints( + buf.get_stride(), + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + result.writeline( + f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + size = V.graph.sizevars.size_hints( + const_tensor.size(), + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + stride = V.graph.sizevars.size_hints( + const_tensor.stride(), + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + result.writeline( + f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint( + arg_sig.expr, + hint_override=self.hint_override, + fallback=config.unbacked_symint_fallback, + ) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. + if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint( + arg_sig.count, hint_override=self.hint_override + ) + result.writeline( + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + var_names.extend(self.kernel_benchmark_extra_args()) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + current_device = V.graph.get_current_device_or_throw() + index = current_device.index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type}, rep=40)" # noqa: B950 line too long + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self): + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) + ) + + def _get_heuristic(self): + if self.fixed_config: + return "fixed_config" + elif self.cooperative_reduction: + return "cooperative_reduction" + elif self.persistent_reduction: + assert self.inside_reduction + return "persistent_reduction" + elif self.inside_reduction: + return "reduction" + return "pointwise" + + @staticmethod + def inductor_meta_common(): + inductor_meta = { + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + "assert_indirect_indexing": config.assert_indirect_indexing, + "autotune_local_cache": config.autotune_local_cache, + "autotune_pointwise": config.triton.autotune_pointwise, + "autotune_remote_cache": config.autotune_remote_cache, + "force_disable_caches": config.force_disable_caches, + "dynamic_scale_rblock": config.dynamic_scale_rblock, + "max_autotune": config.max_autotune, + "max_autotune_pointwise": config.max_autotune_pointwise, + "min_split_scan_rblock": config.triton.min_split_scan_rblock, + "spill_threshold": config.triton.spill_threshold, + "store_cubin": config.triton.store_cubin, + "deterministic": config.deterministic, + "force_filter_reduction_configs": config.test_configs.force_filter_reduction_configs, + } + + if config.write_are_deterministic_algorithms_enabled: + inductor_meta["are_deterministic_algorithms_enabled"] = ( + torch.are_deterministic_algorithms_enabled() + ) + + if torch.version.hip is not None: + inductor_meta["is_hip"] = True + if config.is_fbcode(): + inductor_meta["is_fbcode"] = True + if config.profile_bandwidth: + inductor_meta["profile_bandwidth"] = config.profile_bandwidth + inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex + inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output + inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = ( + config.profile_bandwidth_with_do_bench_using_profiling + ) + if config.coordinate_descent_tuning: + inductor_meta["coordinate_descent_tuning"] = ( + config.coordinate_descent_tuning + ) + inductor_meta["coordinate_descent_search_radius"] = ( + config.coordinate_descent_search_radius + ) + inductor_meta["coordinate_descent_check_all_directions"] = ( + config.coordinate_descent_check_all_directions + ) + return inductor_meta + + def codegen_kernel(self, name=None) -> str: + """ + Convert the TritonKernel from Inductor SIMD IR to triton code, including inductor triton heuristics, imports, + metadata, and benchmarking infra. + """ + + code = IndentedBuffer() + + size_hints = {} + for prefix, numel in self.numels.items(): + if prefix_is_reduction(prefix) and not self.inside_reduction: + continue + + numel_hint = V.graph.sizevars.symbolic_hint(numel) + if not isinstance(numel_hint, (int, sympy.Integer)): + # This default heuristic hint was picked carefully: it is + # large, to ensure that we don't shrink the block size (since + # if you don't have many elements, it'd be wasteful to pick a + # large block size). Since we don't know how many elements we + # might have, we should be OK with some inefficiency to make + # sure we handle the large case well. 8192 is the largest + # block size we support, so we pick that. + # + # If we have a better hint for unbacked SymInts (e.g., because + # a user told us, or we are tracking upper bounds) we could + # use that here. + size_hint = 8192 + else: + size_hint = next_power_of_2(int(numel_hint)) + size_hints[prefix] = size_hint + + if name is None: + code.splice(gen_common_triton_imports()) + device_type = V.graph.get_current_device_or_throw().type + if device_type == "cpu": + code.splice("triton_helpers.set_driver_to_cpu()") + else: + code.splice("triton_helpers.set_driver_to_gpu()") + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + # maps actual expression to SizeArg if it is in sizevars replacements + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + # mypy is unhappy about the sympy.Expr + # type for the key of the dict below + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + mutated_args: OrderedSet[str] = OrderedSet() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add( + cast(InplacedBuffer, self.args.inplace_buffers[mutation]).inner_name + ) + if mutation in self.args.output_buffers: + mutation_arg = self.args.output_buffers[mutation] + assert not isinstance(mutation_arg, RemovedArg) + mutated_args.add(mutation_arg) + + # Note: [Workspace Mutation] + # workspace arguments are mutated, but are not marked as mutations in self.mutations + # because their buffers are added during codegen, and aren't tracked during + # lowering/scheduling. So we add them as mutated_args explicitly below. + # + # In the logic below, we only mark the workspaces a mutated if they are marked with + # zero_fill: that's because, if we don't expect the buffer to be pre-filled with + # zeros, then, although we still mutate the data, we don't care about those + # mutations because we don't make any assumptions about the contents of the + # workspace buffer. Similarly, ZERO_PER_GRAPH requires the kernel to return + # the buffer back to its original state. + for argname, arg in zip(argdefs, signature): + if ( + isinstance(arg, WorkspaceArg) + and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL + ): + mutated_args.add(argname.name) + + mutated_args = sorted(mutated_args) + + for tree in self.active_range_trees(): + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + argdefs.append(ArgName(sizearg.name)) + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + + def add_constexpr_arg(arg_name): + # new versions (but not old versions) of Triton need constexprs included in the signature + if triton_version_uses_attrs_dict(): + signature.append(ConstexprArg(arg_name)) + argdefs.append(ArgName(arg_name, is_constexpr=True)) + + for tree in self.range_trees: + if tree.is_reduction and self.persistent_reduction: + # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels + continue + if tree.tensor_dim is None: + continue + + add_constexpr_arg(f"{tree.prefix.upper()}BLOCK") + + if self.cooperative_reduction: + add_constexpr_arg("RSPLIT") + + if self.mix_order_reduction: + add_constexpr_arg("RSPLIT_SIZE") + add_constexpr_arg("NUM_STAGES") + + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype, argdefs=argdefs + ) + triton_meta: dict[str, Any] = { + "signature": triton_meta_signature, + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + "constants": {}, + "native_matmul": ( + torch._inductor.config.triton.native_matmul + and ("tl.dot" in str(self.body) or "tl.dot" in str(self.compute)) + ), + } + + # Skip memory optimization for forward of the training loop where we expect + # every new node will increase the peak memory and our greedy approach would + # introduce a lot of unnecessary cpu copies. + optimize_mem = V.graph.is_inference or V.graph.is_backward + + inductor_meta = { + "grid_type": self._get_grid_type().__name__, + # Triton will not accept an OrderedSet for autotune_hints + "autotune_hints": set(self.autotune_hints), # noqa: set_linter + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "optimize_mem": optimize_mem, + "no_x_dim": self.no_x_dim, + "atomic_add_found": self.atomic_add_found, + "num_load": self.num_load, + "num_store": self.num_store, + "num_reduction": self.num_reduction, + **self.inductor_meta_common(), + } + + if self.mix_order_reduction: + inductor_meta["RSPLIT_SIZE"] = self.rsplit_size + + if config.deterministic or config.test_configs.force_filter_reduction_configs: + inductor_meta["has_loadstore_with_contiguous_rdim"] = ( + self.has_load_with_contiguous_rdim + or self.has_store_with_contiguous_rdim + ) + + # Bail on 3d tiling, which has more complicated coalesce patterns + looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction + tiling_scores = self.tiling_scores + two_d_red = len(self.tiling) == 2 + if looped_red and two_d_red: + memory_stats = self.features.memory_stats(self.tiling) + dim_stats = memory_stats.persistent.memory.dim[0] + mem_ops_per_thread = dim_stats.count_per_thread + + if ( + tiling_scores is not None + and "x" in tiling_scores + and "r0_" in tiling_scores + ): + # large rblock inhibits xblock size, dont attempt if there is a decent amount of + # reads coalesced by xblock + r_coalesce_ratio = tiling_scores["r0_"] / max(tiling_scores["x"], 1) + contiguous_red = r_coalesce_ratio >= 8.0 + else: + from torch._inductor.runtime.hints import ReductionHint + + contiguous_red = ( + self.features.get_reduction_hint() == ReductionHint.INNER + ) + + looped_mem = memory_stats.looped.memory.bytes + persistent_mem = memory_stats.persistent.memory.bytes + # check that we save significant memory by doing persistent + saved_bytes_ratio = V.graph.sizevars.size_hint( + looped_mem, fallback=config.unbacked_symint_fallback + ) / max( + V.graph.sizevars.size_hint( + persistent_mem, fallback=config.unbacked_symint_fallback + ), + 1, + ) + + # TODO - rnumel should be reasonably close to power of 2 + if ( + # significant memory bandwidth savings + saved_bytes_ratio >= 1.3 + and contiguous_red + # TODO - need more detailed register analysis + and V.graph.sizevars.statically_known_leq( + self.features.reduction_numel, 32768 + ) + # We will already generate a persistent config in this case + and V.graph.sizevars.statically_known_gt( + self.features.reduction_numel, 2048 + ) + and mem_ops_per_thread <= 10 + ): + inductor_meta["add_persistent_rblock"] = True + + if self.tiling_scores: + inductor_meta["tiling_scores"] = self.tiling_scores + + if self.tma_min_block_sizes: + inductor_meta["tma_min_block_sizes"] = self.tma_min_block_sizes + + if self.cooperative_reduction: + inductor_meta["persistent_reduction"] = self.persistent_reduction + + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + if num_gb is not None: + inductor_meta["kernel_num_gb"] = num_gb + if config.benchmark_kernel: + flops = self.estimate_flops() + if flops is not None: + inductor_meta["kernel_flop"] = flops + + triton_meta["configs"] = [config_of(signature)] + + if enable_pdl_codegen(): + triton_meta["launch_pdl"] = True + + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + for arg_num in equal_1_arg_indices(signature): # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts + + self.triton_meta = triton_meta + + self.codegen_prologue(self.body) + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + if self.fixed_config: + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + config={self.fixed_config.config!r}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + elif self.inside_reduction: + reduction_hint = self.features.get_reduction_hint() + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if ( + len(non_constexpr_signature(signature)) == 4 + ): # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + @staticmethod + def _get_persistent_RBLOCK(rnumel): + rnumel = V.graph.sizevars.simplify(rnumel) + if isinstance(rnumel, (sympy.Integer, int)): + val = int(rnumel) + val = next_power_of_2(val) + else: + val = 2 + while not V.graph.sizevars.statically_known_leq(rnumel, val): + if val > 16 * 1024: + raise ValueError(f"Failed to find static RBLOCK for {rnumel}") + val *= 2 + + return val + + return val + + @staticmethod + def has_persistent_RBLOCK(rnumel): + try: + TritonKernel._get_persistent_RBLOCK(rnumel) + return True + except ValueError: + return False + + def codegen_static_numels(self, code): + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + + We would add + xnumel = 4096 + r0_numel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. + """ + + def is_static_integer(expr: sympy.Expr) -> bool: + return isinstance(expr, (sympy.Integer, int)) + + for tree in self.range_trees: + if not tree.is_reduction or self.inside_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if is_static_integer(simplified_tree_numel): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + + if tree.is_reduction and self.persistent_reduction: + if self.cooperative_reduction: + numel = self.kexpr(self.rename_indexing(tree.numel)) + val = f"triton_helpers.constexpr_next_power_of_2(({numel} + RSPLIT - 1) // RSPLIT)" + else: + val = self._get_persistent_RBLOCK(tree.numel) + if self.is_native_matmul: + # tl.dot only supports shapes >= 16 + val = max(val, 16) + + code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}") + + if tree.prefix == "x" and self.no_x_dim: + code.writeline("XBLOCK: tl.constexpr = 1") + + def _get_grid_type(self) -> type[triton_heuristics.GridExpr]: + n = sum([int(not tree.is_reduction) for tree in self.range_trees]) + if self.mix_order_reduction: + assert n == 1 + return triton_heuristics.MixOrderReductionGrid + elif self.cooperative_reduction: + assert n == 1 + return triton_heuristics.CooperativeReductionGrid + elif n == 1: + return triton_heuristics.Grid1D + elif n == 2: + if any(map(self.needs_yz_grid_overflow, self.range_trees)): + return triton_heuristics.Grid2DWithYZOverflow + return triton_heuristics.Grid2D + elif n == 3: + return triton_heuristics.Grid3D + raise ValueError(f"Unsupported number of dimensions: {n}") + + def add_numel_to_call_args(self, name, call_args, arg_types): + # TODO(jansel): if there are constants, we shouldn't bother passing them as args + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree) + + if not tree.is_reduction or self.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def call_kernel( + self, name: str, node: Optional[IRNode] = None, deallocate_ws: bool = True + ): + wrapper = V.graph.wrapper_code + wrapper.write_triton_header_once() + _, call_args, _, arg_types = self.args.python_argdefs() + self.add_numel_to_call_args(name, call_args, arg_types) + + for ws in self.args.workspace_args: + wrapper.generate_workspace_allocation(ws) + + wrapper.generate_kernel_call( + name, + call_args, + triton=True, + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + + if deallocate_ws: + self.deallocate_workspaces() + + def codegen_nan_check(self) -> None: + wrapper = V.graph.wrapper_code + _, call_args, arg_signatures, _ = self.args.python_argdefs() + for arg, arg_signature in zip(call_args, arg_signatures): + if isinstance(arg_signature, TensorArg): + if V.graph.cpp_wrapper: + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) + else: + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable: + return TritonCSEVariable(*args, **kwargs) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" + + # mix order reduction introduces an extra loop across the x + # dimension + if entry.root.is_loop or (self.mix_order_reduction and entry.prefix == "x"): + self.indexing_code.writeline(line) + else: + # lift non-reduction stores outside loop + self.body.writeline(line) + + def iteration_ranges_ranges_code(self, entry: IterationRangesRoot) -> str: + assert entry.tensor_dim is not None + size = self.indexing_size_str(entry.tensor_dim) + index_dtype = self.index_dtype + suffix = f".to({index_dtype})" if index_dtype != "tl.int32" else "" + if ( + self.cooperative_reduction + and self.persistent_reduction + and entry.is_reduction + ): + suffix = f"{suffix} + rsplit_start" + return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{suffix}" + + def iteration_ranges_scalar_code( + self, entry: IterationRangesRoot, value: Any + ) -> str: + index_dtype = self.index_dtype + ndim = self.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def iteration_ranges_get_pid(self, entry: IterationRangesRoot) -> str: + assert entry.grid_dim is not None + key = f"tl.program_id({entry.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if self.needs_yz_grid_overflow(entry): + # For ynumel larger than max_ygrid, we need to use zdim. + # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z). + # So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset. + key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))" + pid = entry.pid_cache.get(key, key) + if self.index_dtype != "tl.int32": + return f"{pid}.to({self.index_dtype})" + return pid + + def needs_yz_grid_overflow(self, entry: IterationRangesRoot) -> bool: + return ( + entry.grid_dim == 1 + and not entry.has_zdim + and not self.cooperative_reduction + and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid()) + ) + + def max_block(self, prefix: str) -> int: + if self.fixed_config: + return self.fixed_config[f"{prefix.upper()}BLOCK"] + return TRITON_MAX_BLOCK[prefix.upper()] + + def _has_constant_mask(self, tree: IterationRangesRoot) -> bool: + if self.is_native_matmul: + # tl.dot requires the shape to be >= 16, + # so when matmul shape is smaller than 16, we always keep the mask. + if V.graph.sizevars.statically_known_lt(tree.numel, 16): + return False + + if not self.optimize_mask: + return False + + if self.fixed_config and f"{tree.prefix.upper()}BLOCK" in self.fixed_config: + if self.fixed_config[f"{tree.prefix.upper()}BLOCK"] == 1: + return True + else: + if V.graph.sizevars.statically_known_equals(tree.numel, 1): + return True + + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.is_reduction and self.persistent_reduction: + max_block = self._get_persistent_RBLOCK(tree.numel) + elif tree.prefix == "x" and self.no_x_dim: + max_block = 1 + else: + max_block = self.max_block(tree.prefix) + + if tree.is_reduction and self.cooperative_reduction: + max_block = max_block * self.max_rsplit() + + # [Note: Constant mask optimisation] + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # If this tree is for the y dimension, we should only use a constant + # mask if it can be guaranteed that: + # 1. (ynumel / YBLOCK) < max_ygrid or + # 2. (ynumel / YBLOCK) % max_ygrid == 0 + # Because YBLOCK is not constant, use a conservative heuristic: + # only use a constant mask if ynumel < max_ygrid. + # It's faster to avoid masking at all. But it is sound to always + # mask. + if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): + return ( + tree.grid_dim != 1 + or tree.has_zdim + or V.graph.sizevars.statically_known_leq(tree.numel, get_max_y_grid()) + ) + + return False + + def _has_constant_xmask(self) -> bool: + xtree = self.range_trees[0] + assert xtree.prefix == "x" + return self._has_constant_mask(xtree) + + def filter_masks(self, mask_vars: OrderedSet[str]) -> None: + for tree in self.range_trees: + if self._has_constant_mask(tree): + mask_vars.discard(f"{tree.prefix}mask") + + # can be added as an override_mask + mask_vars.discard("None") + + @cache_on_self + def get_reduction_prefixes(self) -> list[str]: + return [ + prefix_str[symt] + for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims] + ] + + def codegen_reduction_numels(self, buffer: IndentedBuffer) -> None: + """ + Generates code that flattens ND reduction numels, block sizes, etc. into 1D. + """ + # rnumel = r0_numel * ... * r(n-1)_numel + reduction_trees = [tree for tree in self.range_trees if tree.is_reduction] + rnumel = " * ".join(sorted(f"{tree.prefix}numel" for tree in reduction_trees)) + buffer.splice(f"rnumel = {self.kexpr(rnumel)}") + + # RBLOCK = R0_BLOCK * ... * R(N-1)_BLOCK + rn_blocks = [ + TritonSymbols.block_sizes[tree.symt] + for tree in self.range_trees + if tree.is_reduction + ] + rblock = sympy_product(rn_blocks) + buffer.splice(f"RBLOCK: tl.constexpr = {self.kexpr(rblock)}") + + def _get_reduction_symbols(self, suffix: str, **kwargs) -> list[sympy.Symbol]: + """ + Helper to initialize symbols like rn_numel, rn_base, etc. + """ + rn_prefixes = self.get_reduction_prefixes() + return [sympy.Symbol(f"{prefix}{suffix}", **kwargs) for prefix in rn_prefixes] + + @cache_on_self + def _get_reduction_index_coeffs(self) -> list[sympy.Expr]: + """ + Compute coefficients to convert ND reduction indices to linear indices. + For example: + rindex = r0_index * r1_numel * ... * rn_numel + ... + rn_index. + """ + rn_prefixes = self.get_reduction_prefixes() + rn_numels = self._get_reduction_symbols("numel", integer=True, positive=True) + return [ + sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1) + ] + [sympy.Integer(1)] + + def _flatten_reduction_indices(self, multi_inds: list[sympy.Expr]) -> sympy.Expr: + """ + Compute linear reduction indices from N dimensional ones. + """ + coeffs = self._get_reduction_index_coeffs() + return sympy_dot(coeffs, multi_inds) + + def codegen_reduction_indices(self, buffer: IndentedBuffer) -> None: + """ + Generates code that converts ND reduction indices into linear indices. + """ + # Gather relevant numels, indices, etc. + rn_offsets = self._get_reduction_symbols( + "offset", integer=True, nonnegative=True + ) + rn_inds = self._get_reduction_symbols("index", integer=True, nonnegative=True) + + # Compute roffset and rindex. + roffset = self._flatten_reduction_indices(rn_offsets) + buffer.splice(f"roffset = {self.index_to_str(roffset)}") + rindex = self._flatten_reduction_indices(rn_inds) + buffer.splice(f"rindex = {self.index_to_str(rindex)}") + + def iteration_ranges_codegen_header( + self, entry: IterationRangesRoot, code: IndentedBuffer + ) -> None: + x = entry.prefix + if entry.is_loop: + code.writeline(f"{entry.name} = {x}offset + {x}base") + elif entry.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}") + code.writeline(f"{x}offset = 0") + else: + if entry.tensor_dim is not None: + line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}" + else: + line = self.iteration_ranges_scalar_code(entry, f"{x}offset") + + block_size = ( + f"{x.upper()}BLOCK" if not self.mix_order_reduction else "RSPLIT_SIZE" + ) + code.writelines( + [ + f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {block_size}", + f"{entry.name} = {line}", + ] + ) + if self._has_constant_mask(entry): + code.writeline(self.create_constant_mask(entry)) + elif not (x == "x" and self.mix_order_reduction): + # mix order reduction should generate xmask inside the loop + code.writeline(f"{x}mask = {entry.name} < {x}numel") + + +class TritonScheduling(SIMDScheduling): + kernel_type: type[Any] = TritonKernel + backend_features = OrderedSet( + [ + BackendFeature.FOREACH, + BackendFeature.BUCKETIZE, + BackendFeature.INPLACE_BUFFERS, + BackendFeature.MASKED_SCATTER_WITH_INDEX, + BackendFeature.SCAN, + BackendFeature.SORT, + BackendFeature.TRITON_TEMPLATES, + BackendFeature.TUPLE_REDUCTION, + ] + ) + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + if scheduler is None or not hasattr(scheduler, "nodes"): + return + for node in scheduler.nodes: + if isinstance(node, (SchedulerNode, FusedSchedulerNode)): + node.debug_device_str = debug_triton_code + + @classmethod + def get_backend_features(cls, device: torch.device): + if ( + config.triton.cooperative_reductions + or config.triton.force_cooperative_reductions + ): + return OrderedSet( + [*cls.backend_features, BackendFeature.REDUCE_TO_SINGLE_ELEMENT] + ) + return cls.backend_features + + def codegen_comment(self, node_schedule, kernel_name=None): + wrapper = V.graph.wrapper_code + origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper) + if origins: + wrapper.make_comment(origins) + + if config.debug_fusion: + from torch._inductor.scheduler import ( + BaseSchedulerNode, + ForeachKernelSchedulerNode, + ) + + if not any( + isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule + ): + # We probably should look what are the nodes inside a foreach + # schedule node + node_names = [ + n.get_name() + for n in node_schedule + if isinstance(n, BaseSchedulerNode) + ] + wrapper.make_comment( + f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" + ) + + if kernel_name: + debug_handle = set_kernel_post_grad_provenance_tracing( + node_schedule, # type: ignore[arg-type] + kernel_name, + ) + wrapper.write_provenance_debug_handle(kernel_name, debug_handle) + + def define_kernel(self, src_code, node_schedule, kernel): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + if config.aot_inductor.model_name_for_generated_files: + # When AOTI compiles multiple submodules, we need to use the model name to + # distinguish kernel related symbols. + kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}" + + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "#") + + _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + compile_wrapper = IndentedBuffer() + + if async_compile.use_process_pool(): + # The process pool is warm, we can shell out to workers right away. This + # allows us to save the result in async_compile.CompiledTritonKernels, + # so that the second time we call async_compile.triton, we do no work. + async_compile.triton(subs_name, src_code) + + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + + compile_wrapper.splice(src_code, strip=True) + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reduction and check if + # padding helps with the perf kernel by kernel. + if metrics.is_metric_table_enabled("kernel_metadata"): + metrics.log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name + + def benchmark_fused_nodes(self, nodes, n_spills_threshold=8) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True) + mod = PyCodeCache.load(src_code) + return self.benchmark_codegened_module( + mod, n_spills_threshold, node_names=OrderedSet(n.get_name() for n in nodes) + ) + + def benchmark_codegened_module( + self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None + ) -> tuple[float, str]: + """Benchmark an already compiled module""" + device_interface = get_interface_for_device(V.graph.device_type) + with ( + preserve_rng_state(), + device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined] + ): + ms = None + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def store_cache(): + path = cache_file_path() + write_atomic(path, str(ms)) + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return float(fd.read()) + return None + + node_names = ( + node_names if node_names is not None else OrderedSet(["unknown"]) + ) + log.debug( + "kernel src code for %s written to: %s", + node_names, + mod.__file__, + ) + ms = load_cache() + if ms is not None: + return ms, mod.__file__ + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + # call once to trigger the compilation + try: + call(wrapped_jit_function.clone_args(*args)[0]) + except Exception as e: + if config.triton.disallow_failing_autotune_kernels_TESTING_ONLY: + raise + log.debug( # noqa: G200 + "Exception (%s) in compiling fused nodes %s", + e, + node_names, + ) + ms = float("inf") + store_cache() + return ms, mod.__file__ + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + # n_spills does not necessarily mean it's not profitable to fuse, + # and sometimes it can be inaccurate + if launchers[0].n_spills > n_spills_threshold: + # skip benchmarking the kernel if there are register spills + ms = float("inf") + else: + device = V.graph.get_current_device_or_throw() + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. + ms = benchmarker.benchmark( + lambda: call(wrapped_jit_function.clone_args(*args)[0]), + device=device, + ) + # overhead of cloning args gives bias for fusing the kernel + # in the case of mutating/in-placeable second fusion + # TODO - would be better as a hook in triton do_bench that reset + # the input values between benchmarking + if len(wrapped_jit_function.mutated_arg_names) > 0: + ms = ms - benchmarker.benchmark( + lambda: wrapped_jit_function.clone_args(*args), + device=str(device), + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run", + node_names, + ms, + ) + store_cache() + return ms, mod.__file__ + + def create_kernel_choices( # type: ignore[override] + self, + kernel_features: SIMDKernelFeatures, + kernel_args: list[Any], + kernel_kwargs: dict[str, Any], + ) -> list[TritonKernel]: + is_scan = kernel_features.contains_op("scan") + is_split_scan = is_scan and any( + node.is_split_scan() for node in kernel_features.scheduler_nodes() + ) + kernel_type: type[TritonKernel] = self.kernel_type + if is_split_scan: + from .triton_split_scan import TritonSplitScanKernel + + kernel_type = TritonSplitScanKernel + + if is_scan: + # TODO(jansel): scan does not yet work with cooperative reductions + kernel_kwargs["override_cooperative_reduction"] = False + + # ops.sort only works with persistent reduction, and is not bandwidth bound anyway + # so taking the hit of non-coalesced loads is okay + if kernel_features.contains_op("sort"): + kernel_kwargs["override_persistent_reduction"] = True + kernel_kwargs["override_cooperative_reduction"] = False + + if not TritonKernel.has_persistent_RBLOCK(kernel_features.reduction_numel): + # Cannot use persistent reduction with unknown dynamic rnumel + assert not kernel_kwargs.get("override_persistent_reduction") + kernel_kwargs["override_persistent_reduction"] = False + + kernel_kwargs = V.choices.triton_kernel_kwargs( + kernel_type, kernel_features, kernel_args, kernel_kwargs + ) + kernel = kernel_type(*kernel_args, **kernel_kwargs) + return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs) + + def add_multi_kernel_choices( + self, + kernel: TritonKernel, + kernel_args: list[Any], + kernel_kwargs: dict[str, Any], + ) -> list[TritonKernel]: + kernels: list[TritonKernel] = [kernel] + if not config.triton.multi_kernel: + return kernels + + optional_persistent = kernel.persistent_reduction and not kernel_kwargs.get( + "override_persistent_reduction" + ) + optional_cooperative = kernel.cooperative_reduction and not kernel_kwargs.get( + "override_cooperative_reduction" + ) + if optional_persistent: + kernels.append( + self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_persistent_reduction=False, + ) + ) + if optional_cooperative: + rnumel = kernel.features.reduction_numel + # for larger sizes non-cooperative gets very slow + if V.graph.sizevars.statically_known_leq(rnumel, 65536): + kernels.append( + other := self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_cooperative_reduction=False, + ) + ) + if optional_persistent and other.persistent_reduction: + kernels.append( + self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_cooperative_reduction=False, + override_persistent_reduction=False, + ) + ) + + if len(kernels) > 1: + for kernel2 in kernels[1:]: + # Keep buffers needed by the non-persistent reduction so both kernels have the same arguments + kernel2.must_keep_buffers = kernel.must_keep_buffers + # persistent kernels must be generated last so must_keep_buffers works right + kernels.sort(key=lambda k: k.persistent_reduction) + return kernels + + def benchmark_combo_kernel(self, node_list): + mod: ModuleType + ms: float + ms_clone: float + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return tuple(float(e) for e in fd.read().split()) + return (None, None) + + def store_cache(): + path = cache_file_path() + write_atomic(path, str(ms) + " " + str(ms_clone)) + + total_ms, file_list = 0, [] + total_clone_ms: float = 0.0 + removed_buffers_orig = V.graph.removed_buffers + V.graph.removed_buffers = OrderedSet(removed_buffers_orig) + inplaced_to_remove_orig = V.graph.inplaced_to_remove + V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig) + enable_autotune = config.combo_kernels_autotune > 0 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0 + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes=node_list, + custom_part_algorithm=True, + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + only_gen_src_code=True, + ) + + for src_code, _, node_group in kernel_code_list: + fused_node_lists = [node.get_nodes() for node in node_group] + names = [n.get_name() for nodes in fused_node_lists for n in nodes] + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + mod = PyCodeCache.load(src_code) + + log.debug( + "kernel src code for %s written to: %s", + names, + mod.__file__, + ) + ms, ms_clone = load_cache() + if ms is not None: + total_ms += ms # type: ignore[assignment] + total_clone_ms += ms_clone + file_list.append(mod.__file__) + continue + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + call(wrapped_jit_function.clone_args(*args)[0]) + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = ms_clone = float("inf") + else: + device = V.graph.get_current_device_or_throw() + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. + ms = benchmarker.benchmark( + lambda: call(wrapped_jit_function.clone_args(*args)[0]), + device=device, + ) + ms_clone = benchmarker.benchmark( + lambda: wrapped_jit_function.clone_args(*args)[0], + device=device, + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs", + OrderedSet(n.get_name() for n in node_group), + ms, + ms_clone, + ) + store_cache() + total_ms += ms + total_clone_ms += ms_clone + file_list.append(mod.__file__) + V.graph.removed_buffers = removed_buffers_orig + V.graph.inplaced_to_remove = inplaced_to_remove_orig + return total_ms, total_clone_ms, file_list + + +def debug_triton_code(node: BaseSchedulerNode) -> list[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + + device = node.get_device() + assert device is not None + backend = node.scheduler.get_backend(device) + assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), ( + f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" + ) + + with V.graph.set_current_device(device): + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. + old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes( + node.get_nodes() + ).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..74ed7d3797396deb07d5957d67af5bb22930e11a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py @@ -0,0 +1,1037 @@ +import itertools +import logging +import textwrap +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, cast, Optional, Union + +import sympy +from sympy import Integer, Symbol + +from torch.utils._ordered_set import OrderedSet + +from .. import config, metrics +from ..runtime.hints import DeviceProperties +from ..runtime.runtime_utils import next_power_of_2 +from ..runtime.triton_heuristics import ( + RoundRobinComboKernelGrid, + SequentialComboKernelGrid, +) +from ..scheduler import BaseSchedulerNode +from ..utils import Placeholder, triton_version_uses_attrs_dict +from ..virtualized import V +from .common import ( + ArgName, + ConstexprArg, + DeferredLine, + IndentedBuffer, + InplacedBuffer, + Kernel, + PythonPrinter, + RemovedArg, + SizeArg, + WorkspaceArg, +) +from .simd import prefix_is_reduction, SIMDScheduling +from .simd_kernel_features import SIMDKernelFeatures +from .triton import gen_common_triton_imports, TritonKernel +from .triton_utils import config_of, equal_1_arg_indices, signature_to_meta + + +log = logging.getLogger(__name__) +pexpr = PythonPrinter().doprint +LARGE_NUMELS = 512e5 +BLOCK_UTILIZATION = 0.8 + + +def _default_custom_combo_kernel_horizontal_partition( + nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: dict[BaseSchedulerNode, TritonKernel], + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], +) -> list[list[BaseSchedulerNode]]: + """Horizontally partition the given list of nodes into a list of list of nodes where each sublist + represents a partition. Nodes in different partitions are implemented in different combo kernels. + Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args. + + Input arguments: + nodes: a list of fused scheduler nodes to partition. + triton_scheduling: TritonScheduling instance. + kernel_map: a map from node to its kernel. + node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel). + Output: + a list of list of nodes with each sublist representing a partition. + + The default algorithm is to partition nodes based on the following rules: + 1) nodes with the same number of block dimensions are grouped together. + 2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes. + 3) large reduce nodes are separated from other nodes. + """ + + assert len(nodes) >= 1 + + # first partition nodes based on number of block dimensions + tilings = [node_info_map[n][1] for n in nodes] + + max_dims = max(len(t) for t in tilings) + nodes_per_ndim: list[list[BaseSchedulerNode]] = [] + for i in range(2, max_dims + 1): + group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i] + reduction = [ + n + for n in group_per_dim + if kernel_map[n].inside_reduction + and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim) + ] + not_reduction = [n for n in group_per_dim if n not in reduction] + # rnumel > 2048 usually has long execution time + # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes + long_reduction = [ + n + for n in reduction + if ( + V.graph.sizevars.shape_env.has_hint(n.group[-1][-1]) + and V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type] + ) + ] + short_reduction = [n for n in reduction if n not in long_reduction] + if long_reduction: + log.debug( + "ComboKernels: %d long reduction nodes are separated", + len(long_reduction), + ) + large_pointwise = [ + n + for n in not_reduction + if not kernel_map[n].inside_reduction + and len(kernel_map[n].numels) == 2 + and V.graph.sizevars.shape_env.has_hint(kernel_map[n].numels["x"]) + and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS + ] + if large_pointwise: + # TODO benchmark the performance when large pointwise nodes combining with others + log.debug( + "ComboKernels: %d large pointwise nodes are separated", + len(large_pointwise), + ) + not_reduction = [n for n in not_reduction if n not in large_pointwise] + nodes_per_ndim.extend([node] for node in large_pointwise) + + nodes_per_ndim.extend( + g for g in (not_reduction, short_reduction, long_reduction) if g + ) + + assert sum(len(p) for p in nodes_per_ndim) == len(nodes) + return nodes_per_ndim + + +_custom_combo_kernel_horizontal_partition_algorithm: Callable[ + [ + list[BaseSchedulerNode], + SIMDScheduling, + dict[BaseSchedulerNode, TritonKernel], + dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + ], + list[list[BaseSchedulerNode]], +] = _default_custom_combo_kernel_horizontal_partition + + +def set_custom_combo_kernel_horizontal_partition( + algorithm: Callable[ + [ + list[BaseSchedulerNode], + SIMDScheduling, + dict[BaseSchedulerNode, TritonKernel], + dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + ], + list[list[BaseSchedulerNode]], + ], +) -> None: + """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions + are implemented in different combo kernels. Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args. + + The algorithm should take a list of nodes and return a list of list of nodes. + + The default algorithm is to partition nodes based on number of block dimensions. + """ + global _custom_combo_kernel_horizontal_partition_algorithm + _custom_combo_kernel_horizontal_partition_algorithm = algorithm + + +@dataclass +class PartitionState: + partitions: list[list[BaseSchedulerNode]] + cur_partition: list[BaseSchedulerNode] + cur_count: int + + def finalize(self) -> None: + if self.cur_partition: + self.partitions.append(self.cur_partition) + + +class ComboKernel(Kernel): + @staticmethod + def _update_partition( + partition_state: PartitionState, + node_rw_count: int, + node_info: BaseSchedulerNode, + ) -> None: + if partition_state.cur_count + node_rw_count > config.combo_kernel_max_num_args: + partition_state.partitions.append(partition_state.cur_partition) + partition_state.cur_partition = [node_info] + partition_state.cur_count = node_rw_count + else: + partition_state.cur_count += node_rw_count + partition_state.cur_partition.append(node_info) + + @staticmethod + def _base_horizontal_partition( + subkernel_nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + custom_algorithm: bool, + ) -> list[list[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + # TODO support combination of kernels with different block dimensions + assert len(subkernel_nodes) >= 1 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm + ) + + ndim_to_partition_state: dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + yelem_to_partition_state: dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + + for node in subkernel_nodes: + _node_schedule, tiled_groups, _numel, _rnumel = node_info_map[node] + node_info = node + + read_writes = node.read_writes + read_write_count = len(read_writes.reads) + len(read_writes.writes) + + ndim = len(tiled_groups) + assert ndim >= 2, f"Combokernel not support tile {tiled_groups}" + if not mixed_sizes and ndim == 3: + y_elem = tiled_groups["y"] + partition_state = yelem_to_partition_state[y_elem] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + else: + assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}" + partition_state = ndim_to_partition_state[ndim] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + + all_partitions = [] + for partition_state in ndim_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + for partition_state in yelem_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + + return all_partitions + + @staticmethod + def horizontal_partition( + nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: dict[BaseSchedulerNode, TritonKernel], + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + custom_algorithm: bool = False, + ) -> list[list[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum) + for each subkernel node where each sublist forms a ComboKernel. It horizontally partitions nodes into + sublists in the following way: + 1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True + 2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is + guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same + 2D or 1D blocking strategy. + """ + if custom_algorithm: + raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm( + nodes, triton_scheduling, kernel_map, node_info_map + ) + else: + raw_partitions = [nodes] + + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + all_partitions = [] + for raw_partition in raw_partitions: + all_partitions.extend( + ComboKernel._base_horizontal_partition( + raw_partition, triton_scheduling, node_info_map, custom_algorithm + ) + ) + return all_partitions + + class SequentialDispatch: + """ + The dispatcher which dispatches the subkernels in a sequential manner: + the blocks are first dispatched to the 1st subkernel (until it is filled), + then to the 2nd subkernel, and so on. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. + """ + + grid_expr = SequentialComboKernelGrid + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + if num == 0: + cls._calculate_xblocks(kernel, code) + code.splice(f"if pid < num_xblocks_{num}:") + with code.indent(): + code.splice("pid_offset = pid") + else: + code.splice(f"elif pid < num_xblocks_{num}:") + with code.indent(): + code.splice(f"pid_offset = pid - num_xblocks_{num - 1}") + + @classmethod + def _calculate_xblocks( + cls, kernel: "ComboKernel", code: IndentedBuffer + ) -> None: + x_numels_list = kernel.x_numels_list + for i in range(len(x_numels_list)): + xnumels, no_x_dim = ( + (x_numels_list[i], False) + if isinstance(x_numels_list[i], str) + and cast(str, x_numels_list[i])[0] != "-" + or ( + isinstance(x_numels_list[i], int) + and cast(int, x_numels_list[i]) > 0 + ) + else (kernel.min_x_blocks_list[i], True) + ) + xblock_str = ( + f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}" + ) + if i == 0: + code.splice(f"num_xblocks_{i} = {xblock_str}") + else: + code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}") + + class RoundRobinDispatch: + """ + The dispatcher which dispatches the subkernels in a round robin manner: + the blocks are interleavedly dispatched to each subkernel to execute them + in parallel. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. + """ + + grid_expr = RoundRobinComboKernelGrid + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + num_kernels = len(kernel.sub_kernels) + if num == 0: + cond = "if" + else: + cond = "elif" + code.splice(f"{cond} pid % {num_kernels} == {num}:") + with code.indent(): + code.splice(f"pid_offset = pid // {num_kernels}") + + def __init__( + self, enable_autotune: bool = False, mixed_sizes: bool = False + ) -> None: + super().__init__() + self.sub_kernels: list[TritonKernel] = [] + self.iter_vars_count = itertools.count() + self.grids: list[list[int]] = [] + self.min_x_blocks_list: list[Union[int, str]] = [] + self.x_numels_list: list[Union[int, str]] = [] + self.enable_autotune = enable_autotune + self.mixed_sizes = mixed_sizes + self.dispatch_class: Optional[ + type[Union[ComboKernel.SequentialDispatch, ComboKernel.RoundRobinDispatch]] + ] = None + self.block_args: list[str] = [] + # there following are used when autotuning is disabled + self.block_size_1d = 1024 # Try tuning this value + self.block_size_2d = 32 + self.num_warps = 8 + self.block_size_reduce = 256 + self.dynamic_shape_args: list[str] = [] + + def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: + sub_kernel = triton_kernel + # pyrefly: ignore [bad-assignment] + metrics.generated_kernel_count -= 1 + sub_kernel.args = self.args + sub_kernel.iter_vars_count = self.iter_vars_count + sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids + self.sub_kernels.append(sub_kernel) + return sub_kernel + + @staticmethod + def create_triton_kernel( + tiling: dict[str, sympy.Expr], + features: SIMDKernelFeatures, + optimize_mask: bool, + ) -> TritonKernel: + """ + Only allow optimize_mask=True when 1) sequential dispatch is used, + 2) numels except x dimension are the same for each sub kernel. + """ + return TritonKernel( + tiling, + features=features, + pid_cache={"tl.program_id(0)": "pid_offset"}, + optimize_mask=optimize_mask, + # foreach kernels don't work with cooperative reductions + override_cooperative_reduction=False, + ) + + def codegen_static_numels_sub_kernel( + self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int + ) -> list[str]: + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + + We would add + xnumel = 4096 + rnumel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. + """ + grid = [] + uniquify_block_sizes = [] + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + else: + assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args + uniquify_block_sizes.append(f"{tree.prefix}numel") + + # pyrefly: ignore [missing-argument] + if not tree.is_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + grid.append(int(simplified_tree_numel)) + else: + # pyrefly: ignore [bad-argument-type] + grid.append(f"{tree.prefix}numel_{num}") + + if tree.is_reduction and sub_kernel.persistent_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + val = int(simplified_tree_numel) + else: + raise RuntimeError( + "Dynamic shape on reduction dimension is not supported" + ) + val = next_power_of_2(val) + code.writeline( + f"{tree.prefix.upper()}BLOCK_{num}: tl.constexpr = {val}" + ) + uniquify_block_sizes.append(f"{tree.prefix.upper()}BLOCK") + + if tree.prefix == "x" and sub_kernel.no_x_dim: + code.writeline(f"XBLOCK_{num}: tl.constexpr = 1") + uniquify_block_sizes.append("XBLOCK") + self.grids.append(grid) + return uniquify_block_sizes + + def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None: + """ + Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks. + Grid calculation needs to make sure that they are assigned with enough number of blocks. + """ + min_x_blocks: Union[int, str] = 0 + x_numels: Union[int, str] = 0 + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if tree.prefix == "x": + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + if sub_kernel.no_x_dim: + min_x_blocks = x_numels + x_numels = ( + # pyrefly: ignore [unsupported-operation] + -min_x_blocks + if isinstance(x_numels, int) + # pyrefly: ignore [redundant-cast] + else "-" + cast(str, x_numels) + ) + else: + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + self.min_x_blocks_list.append(min_x_blocks) + self.x_numels_list.append(x_numels) + + def select_heuristics(self, sub_kernel: TritonKernel) -> tuple[str, dict[str, int]]: + size_hints = { + prefix: next_power_of_2( + V.graph.sizevars.size_hint( + numel, fallback=config.unbacked_symint_fallback + ) + ) + for prefix, numel in sub_kernel.numels.items() + if not prefix_is_reduction(prefix) or sub_kernel.inside_reduction + } + if sub_kernel.persistent_reduction: + assert sub_kernel.inside_reduction + heuristics = "persistent_reduction" + elif sub_kernel.inside_reduction: + heuristics = "reduction" + else: + heuristics = "pointwise" + return heuristics, size_hints + + def select_combo_heuristics( + self, heuristics_list: list[str], size_hints_list: list[dict[str, int]] + ) -> tuple[str, dict[str, int], TritonKernel]: + if not self.enable_autotune: + return "foreach", size_hints_list[0], self.sub_kernels[0] + if "reduction" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "reduction" else 0, + ) + return heuristics_list[i], size_hints_list[i], self.sub_kernels[i] + elif "pointwise" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "pointwise" else 0, + ) + # modify size_hint to avoid oom check fail (may be a false alarm) + num_pointwise = len([e for e in heuristics_list if e == "pointwise"]) + num_reduction = len([e for e in heuristics_list if e == "reduction"]) + num_persistent_reduction = len( + [e for e in heuristics_list if e == "persistent_reduction"] + ) + assert num_reduction == 0, ( + "combining pointwise and reduction are not supported yet." + ) + heuristics = ( + "pointwise_with_reduction" + if num_persistent_reduction > 0 + else "pointwise" + ) + if len(heuristics_list) - num_pointwise >= 4: + size_hints = size_hints_list[i] + size_hints["x"] = min(128, size_hints["x"]) + return heuristics, size_hints_list[i], self.sub_kernels[i] + else: + return heuristics_list[0], size_hints_list[0], self.sub_kernels[0] + + def get_mutated_args_sub_kernels(self) -> list[str]: + mutated_args: OrderedSet[str] = OrderedSet() + for sub_kernel in self.sub_kernels: + for mutation in sub_kernel.mutations: + if mutation in sub_kernel.args.input_buffers: + mutated_args.add(sub_kernel.args.input_buffers[mutation]) + if ( + mutation in sub_kernel.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in sub_kernel.removed_buffers + ): + mutated_args.add( + cast( + InplacedBuffer, sub_kernel.args.inplace_buffers[mutation] + ).inner_name + ) + if mutation in sub_kernel.args.output_buffers: + arg = sub_kernel.args.output_buffers[mutation] + assert not isinstance(arg, RemovedArg) + mutated_args.add(arg) + return sorted(mutated_args) + + def select_dispatch_strategy(self) -> None: + if self.dispatch_class is not None: + return + # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch + # Not mixed sizes on y dim technically is ok to use round robin as wells. + if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list): + # str in x_numels_list means a dynamic shape + self.dispatch_class = ComboKernel.SequentialDispatch + return + # A negative x_blocks_list element means the kernel is not tunable, + # i.e., no_x_dim = True + x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list] + total = max(x_numels_list) * len(x_numels_list) + needed = sum(x_numels_list) + if needed / total > BLOCK_UTILIZATION: + # Introduced overhead (masked blocks) is less than 20% + self.dispatch_class = ComboKernel.RoundRobinDispatch + else: + self.dispatch_class = ComboKernel.SequentialDispatch + + def jit_line( + self, + heuristics: str, + size_hints: dict[str, int], + selected_kernel: TritonKernel, + signature: list[Any], + argdefs: list[ArgName], + pointwise_with_reduce: bool = False, + ) -> str: + can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) + size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + for i, sub in enumerate(self.sub_kernels): + self.min_x_blocks_sub_kernel(sub, i) + self.select_dispatch_strategy() + triton_meta = { + "signature": signature_to_meta( + signature, size_dtype=size_dtype, argdefs=argdefs + ), + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + "constants": {}, + } + + for arg_num in equal_1_arg_indices(signature): + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + + # pyrefly: ignore [unsupported-operation] + triton_meta["configs"] = [config_of(signature)] + mutated_args = self.get_mutated_args_sub_kernels() + dispatch = self.dispatch_class + assert dispatch is not None + inductor_meta = { + "grid_type": dispatch.grid_expr.__name__, + "combo_grid_meta": self.combo_grid_meta(), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + **TritonKernel.inductor_meta_common(), + } + + sub_kernel = selected_kernel + if heuristics == "foreach": + heuristics_line = f""" + @triton_heuristics.foreach( + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + elif sub_kernel.inside_reduction: + reduction_hint = sub_kernel.features.get_reduction_hint() + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + + return heuristics_line + + def codegen_blocks(self, code: IndentedBuffer) -> None: + for block in self.block_args: + assert block in ( + "XBLOCK", + "YBLOCK", + "R0_BLOCK", + ), f"{block} is not supported without autotuning" + if "YBLOCK" in self.block_args: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}") + code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}") + else: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}") + if "R0_BLOCK" in self.block_args: + code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}") + code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}") + + def get_block_args(self) -> list[ConstexprArg]: + """ + Calculate blocks from sub_kernels and range_trees. + **Update self.block_args** + Return the block args + """ + block_names = {} + for sub_kernel in self.sub_kernels: + # TODO: we assume all sub_kernels have the same block size + for tree in sub_kernel.range_trees: + # pyrefly: ignore [missing-argument] + if tree.is_reduction and ( + not sub_kernel.inside_reduction or sub_kernel.persistent_reduction + ): + continue + if tree.prefix == "x" and sub_kernel.no_x_dim: + continue + block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix + self.block_args = list(block_names.keys()) + + return [ConstexprArg(x) for x in block_names] + + def add_numel_to_args( + self, argdefs: list[ArgName], signature: list[Any] + ) -> list[ArgName]: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.active_range_trees(): + if not isinstance(tree.numel, (Integer, int)): + # only if it is a dynamic shape + sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel) + signature.append(sizearg) + argdefs.append(ArgName(f"{tree.prefix}numel_{num}")) + self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}") + return argdefs + + def add_numel_to_call_args( + self, name: str, call_args: list[Any], arg_types: list[Any] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.range_trees: + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if isinstance(tree.numel, (Integer, Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr( + name, tree, suffix=str(num) + ) + # pyrefly: ignore [missing-argument] + if not tree.is_reduction or sub_kernel.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def kernel_benchmark_extra_args(self) -> list[str]: + extra_args = [] + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.range_trees: + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + # pyrefly: ignore [missing-argument] + if not tree.is_reduction or sub_kernel.inside_reduction: + extra_args.append( + str( + V.graph.sizevars.size_hint( + tree.numel, fallback=config.unbacked_symint_fallback + ) + ) + ) + return extra_args + + def codegen_kernel(self, name: Optional[str] = None) -> str: + # TODO: is it correct to use the first sub kernel's heuristics? + heuristics_list, size_hints_list = [], [] + for subkernel in self.sub_kernels: + h, s = self.select_heuristics(subkernel) + heuristics_list.append(h) + size_hints_list.append(s) + heuristics, size_hints, selected_kernel = self.select_combo_heuristics( + heuristics_list, size_hints_list + ) + pointwise_with_reduction, heuristics = ( + (True, "pointwise") + if heuristics == "pointwise_with_reduction" + else (False, heuristics) + ) + code = IndentedBuffer() + + code.splice(gen_common_triton_imports()) + if config.benchmark_combo_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + seen_helpers: OrderedSet[str] = OrderedSet() + for sub_kernel in self.sub_kernels: + for helper in sub_kernel.helper_functions: + if helper not in seen_helpers: + code.writeline("") + code.splice(helper) + seen_helpers.add(helper) + + argdefs, _, signature, _ = self.args.python_argdefs() + argdefs = self.add_numel_to_args(argdefs, signature) + block_args = self.get_block_args() + if self.enable_autotune: + argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args]) + if triton_version_uses_attrs_dict(): + signature.extend(block_args) + + code.splice( + self.jit_line( + heuristics, + size_hints, + selected_kernel, + pointwise_with_reduce=pointwise_with_reduction, + signature=signature, + argdefs=argdefs, + ) + ) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + + with code.indent(): + code.splice("pid = tl.program_id(0)") + if not self.enable_autotune: + self.codegen_blocks(code) + + for num, sub_kernel in enumerate(self.sub_kernels): + assert self.dispatch_class is not None + self.dispatch_class.codegen_pid_range(self, num, code) + with code.indent(): + uniquify = self.codegen_static_numels_sub_kernel( + code, sub_kernel, num + ) + sub_kernel.codegen_body() + uniquified_body = self.uniquify_block_sizes( + sub_kernel.body, num, uniquify + ) + code.splice(uniquified_body) + + code.splice("else:") + with code.indent(): + code.splice("pass") + + if config.benchmark_combo_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb=0)) + + return code.getvalue() + + def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: + """ + Generates Python code for benchmarking this combo kernel. + - Creates example inputs (random tensors, constants, sizes). + - Runs the kernel on the current GPU/stream. + - Prints runtime (ms) and throughput (GB/s) using `num_gb`. + Args: + num_gb (float): The number of gigabytes to use for throughput calculation. + Returns: + IndentedBuffer: A buffer containing the generated Python benchmark code. + """ + result = IndentedBuffer() + _argdefs, call_args, signature, _ = self.args.python_argdefs() + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + size = V.graph.sizevars.size_hints( + buf.get_size(), fallback=config.unbacked_symint_fallback + ) + stride = V.graph.sizevars.size_hints( + buf.get_stride(), fallback=config.unbacked_symint_fallback + ) + result.writeline( + f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + size = V.graph.sizevars.size_hints( + const_tensor.size(), fallback=config.unbacked_symint_fallback + ) + stride = V.graph.sizevars.size_hints( + const_tensor.stride(), fallback=config.unbacked_symint_fallback + ) + result.writeline( + f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. + if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) + # for benchmark harness, we ignore arg_sig.zero_mode and always zero it + result.writeline( + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + if self.dynamic_shape_args: + var_names.extend(self.kernel_benchmark_extra_args()) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + device = V.graph.get_current_device_or_throw() + index = V.graph.get_current_device_or_throw().index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type},rep=40)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self) -> str: + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) + ) + + def uniquify_block_sizes( + self, code: IndentedBuffer, num_kernel: int, uniquify: list[str] + ) -> IndentedBuffer: + if not uniquify: + return code + modified = IndentedBuffer(initial_indent=code._indent) + for line in code._lines: + if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]): + modified_line = line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + modified.writeline(modified_line) + elif isinstance(line, DeferredLine) and ( + blocks := [e for e in uniquify if e in line.line] + ): + modified_line = line.line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + new_line = DeferredLine(line.name, modified_line) + modified.writeline(new_line) + else: + modified.writeline(line) + return modified + + def call_kernel(self, code: IndentedBuffer, name: str) -> None: + _, call_args, _, arg_types = self.args.python_argdefs() + + wrapper = V.graph.wrapper_code + assert self.dispatch_class is not None + if self.dynamic_shape_args: + self.add_numel_to_call_args(name, call_args, arg_types) + + wrapper.generate_kernel_call( + name, + call_args, + triton=True, + arg_types=arg_types, + ) + + def combo_grid_meta(self) -> dict[str, Any]: + dynamic_shape = bool(self.dynamic_shape_args) + num_kernels = len(self.sub_kernels) + min_blocks = ( + max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None + ) + + if not self.enable_autotune: + if "YBLOCK" in self.block_args: + default_config = { + "XBLOCK": self.block_size_2d, + "YBLOCK": self.block_size_2d, + } + else: + default_config = {"XBLOCK": self.block_size_1d} + else: + default_config = None + + meta = { + "num_kernels": num_kernels, + "min_blocks": min_blocks, + "default_config": default_config, + } + + for num, sub_kernel in enumerate(self.sub_kernels): + meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim + for tree in sub_kernel.range_trees: + # pyrefly: ignore [missing-argument] + if not tree.is_reduction: + numel_name = f"{tree.prefix}numel_{num}" + if numel_name in self.dynamic_shape_args: + meta[numel_name] = None + else: + meta[numel_name] = int(V.graph.sizevars.simplify(tree.numel)) + + return meta diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..0abee5439393980560347aa07f6baf3f24f3e35f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py @@ -0,0 +1,224 @@ +# mypy: allow-untyped-defs +import functools +from typing import Union + +import sympy + +from torch._inductor import config +from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction +from torch._inductor.codegen.triton import ( + triton_compute_type, + TritonCSEVariable, + TritonKernel, +) +from torch._inductor.runtime.triton_heuristics import SplitScanGrid +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv + +from ..utils import sympy_product + + +class TritonSplitScanKernel(TritonKernel): + """Generates a triton kernel that supports ops.scan calls while also splitting + the reduction dimension over multiple triton programs. + + For this kernel, loop numels will always take the form ``(xdim, rdim)`` + and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication + between blocks occurs within a global memory workspace buffer, which + must be zero-filled before launching the kernel. + + Note that generation for ``ops.reduction`` is not supported. + + For details of the communication strategy, see + https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + """ + + def __init__( + self, + tiling: dict[str, sympy.Expr], + pid_cache=None, + fixed_config=None, + **kwargs, + ) -> None: + assert pid_cache is None, "not supported" + assert fixed_config is None, "not supported" + super().__init__( + tiling, + **kwargs, + ) + self.no_x_dim = True + + def should_use_persistent_reduction(self) -> bool: + return False + + def should_use_cooperative_reduction(self) -> bool: + return False + + def initialize_range_tree(self, pid_cache): + prefixes = ["y", "x", "r0_"] + assert len(self.numels) <= len(prefixes), ( + "z dimension not supported for split scan" + ) + active_prefixes = prefixes[len(prefixes) - len(self.numels) :] + + grid_dims = {"r0_": 0, "x": 1, "y": 2} + for prefix in active_prefixes: + numel = self.numels[prefix] + tensor_dim = 0 if prefix_is_reduction(prefix) else None + grid_dim = grid_dims[prefix] + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numel, + prefix, + grid_dim, + self, # type: ignore[arg-type] + pid_cache=pid_cache, + is_loop=False, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim=False, + ) + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise NotImplementedError("NYI TritonSplitDimKernel reductions") + + def scan(self, dtypes, combine_fn, values): + """ + Perform an associative scan on 'values'. + """ + import triton.language as tl + + (dtype,) = dtypes + (value,) = values + + compute_type = triton_compute_type(dtype) + compute_type_triton = getattr(tl, compute_type[3:]) + + element_nbits = compute_type_triton.primitive_bitwidth + + scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64" + scratch_type_triton = getattr(tl, scratch_type[3:]) + scratch_elems_per_block = 3 if element_nbits == 64 else 1 + scratch_nbytes_per_block = scratch_elems_per_block * ( + scratch_type_triton.primitive_bitwidth // 8 + ) + + cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype) + cse_compute = functools.partial(self.cse.generate, self.compute) + + assert len(self.numels) == 2, "Unexpected tiling" + min_rblock = config.triton.min_split_scan_rblock + reduction_numel = sympy_product( + numel + for prefix, numel in self.numels.items() + if prefix_is_reduction(prefix) + ) + pointwise_numel = sympy_product( + numel + for prefix, numel in self.numels.items() + if not prefix_is_reduction(prefix) + ) + max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock) + nbytes = scratch_nbytes_per_block * max_blocks + scratch_base: Union[str, TritonCSEVariable] + scratch_base, _, offset = self.args.workspace(nelem=nbytes, zero_fill=True) + if offset != 0: + scratch_base = cse_load( + f"{scratch_base} + {self.index_to_str(offset)}", shape=() + ) + runtime_rblocks = cse_load( + f"tl.num_programs({self.range_trees[-1].index})", shape=() + ) + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}", + shape=(), + ) + + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + value = cse_compute( + f"{value}.to({compute_type})", + dtype=dtype, + shape=value.shape, + ) + value = cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", + dtype=dtype, + shape=self.dense_size_list(), + ) + + combine_helper_fn = self._lift_helper(combine_fn, (value,), (dtype,)) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" + shape = list(self.dense_size_list()) + del shape[dim] + + block_sum = cse_compute( + f"tl.reduce({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + shape=shape, + ) + exclusive_prefix = self.cse.newvar( + dtype=dtype, + shape=shape, + ) + if element_nbits == 64: + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + ) + """, + strip=True, + ) + + else: + assert element_nbits <= 32 + value_as_uint_dtype = f"tl.uint{element_nbits}" + + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + DTYPE_VALUE_AS_UINT={value_as_uint_dtype}, + DTYPE_PACK={scratch_type}, + ) + """, + strip=True, + ) + # Compute final cumsum + block_scan = cse_compute( + f"tl.associative_scan({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + shape=shape, + ) + combined_result = cse_compute( + f"{combine_helper_fn}({exclusive_prefix}, {block_scan})", + dtype=dtype, + shape=shape, + ) + return ( + cse_compute( + f"tl.where(roffset == 0, {block_scan}, {combined_result})", + dtype=dtype, + shape=block_scan.shape, + ), + ) + + def _get_heuristic(self): + return "split_scan" + + def _get_grid_type(self) -> type[SplitScanGrid]: + return SplitScanGrid diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..75a34813c876b2e8fa11cb14cac60b761636973e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py @@ -0,0 +1,265 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import sympy + +import torch +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. import config +from ..runtime.hints import AttrsDescriptorWrapper +from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict +from ..virtualized import V +from .common import ( + ArgName, + ConstexprArg, + KernelArgType, + SizeArg, + TensorArg, + TMADescriptorArg, + WorkspaceArg, +) + + +def should_unwrap_unspec_arg(name: str): + if V.graph.is_unspec_arg(name): + # Unwrap on all devices except CPU + if V.graph.get_current_device_or_throw().type != "cpu": + return True + # Only unwrap on CPU if the input is not used as an output + if name not in V.graph.mutated_buffers: + return True + return False + + +def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: + if isinstance(arg, TensorArg): + # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. + # Related PR: https://github.com/triton-lang/triton/pull/2279/ + if arg.dtype == torch.float8_e4m3fn: + typ = "*fp8e4nv" + elif arg.dtype == torch.float8_e5m2: + typ = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + typ = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + typ = "*fp8e5b16" + else: + typ = _type_of(arg.dtype) + if should_unwrap_unspec_arg(arg.buffer): + # had unwrapped 0d tensor as scalar + new_typ = typ.lstrip("*") + if new_typ in ["fp16", "bf16"]: + return "fp32" + else: + return new_typ + else: + return typ + if isinstance(arg, SizeArg): + if arg.expr is None: + if triton_version_uses_attrs_dict(): + # In newer versions of Triton, the signature includes "None" args + # and their type is marked as "constexpr" + return "constexpr" + else: + # In older versions of Triton... + # From triton/runtime/jit.py + # `None` is nullptr. Implicitly convert to *i8. + return "*i8" + elif _arg_equals_1(arg) and triton_version_uses_attrs_dict(): + # In new versions of Triton, if we have an equal-to-1 arg that's marked as a constant, + # it should be marked as "constexpr" in the signature. + return "constexpr" + elif isinstance(arg.expr, (float, sympy.Float)): + return "fp32" + elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type( + arg.expr, (SymT.UNBACKED_FLOAT) + ): + return "fp32" + elif isinstance(arg.expr, bool): + return "i1" + + # if this is a integer + if size_dtype == "tl.int32": + return "i32" + elif size_dtype == "tl.int64": + return "i64" + elif size_dtype is None: + # no hint: we'll see if we know that this is a 32-bit int, and guard if possible. + int_max = torch.iinfo(torch.int32).max + if expr_fits_within_32bit(arg.expr): + V.graph.sizevars.check_leq(arg.expr, int_max) + return "i32" + else: + return "i64" + else: + raise NotImplementedError(f"unhandled size_dtype {size_dtype}") + if isinstance(arg, WorkspaceArg): + return _type_of(arg.dtype) + if isinstance(arg, TMADescriptorArg): + if arg.api_type == "experimental": + return "nvTmaDesc" + else: + # https://github.com/triton-lang/triton/blob/9695baed9b46cf957e08b157bb4133f4a4b331c5/python/triton/runtime/jit.py#L360-L363 + assert arg.api_type == "stable" + assert arg.block_shape is not None + assert arg.dtype is not None + inner = _type_of(arg.dtype)[1:] # strip the `*`: *fp32 -> fp32 + return f"tensordesc<{inner}{list(arg.block_shape)}>" + if isinstance(arg, ConstexprArg): + return "constexpr" + raise NotImplementedError(f"unhandled {type(arg)}: {arg}") + + +def non_constexpr_signature(signature): + new_signature = [] + for arg in signature: + if not isinstance(arg, ConstexprArg): + new_signature.append(arg) + + return new_signature + + +def signature_to_meta( + signature: list[KernelArgType], + *, + size_dtype: Optional[str], + argdefs: list[ArgName], + indices: Optional[list[int]] = None, + is_template: bool = False, +) -> dict[str, str]: + if indices is None: + indices = list(range(len(signature))) + + def _decide_tl_dtype(arg): + # Even if the ks0 symbol itself is within tl.int32 range, it's + # risky to use tl.int32 dtype since we may have ks0*ks1 later + # for kernels like torch.mean when dynamic shape is enabled. + # + # Check config.triton.use_block_ptr, since Triton block pointer + # does not support 64bit indexing: + # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6 + # + # If the triton metadata is for a template, don't use tl.int64 index. + # Templates like flex attention/decoding uses block pointers which + # does not support 64 bit indexing. + if ( + not config.triton.use_block_ptr + and not is_template + and isinstance(arg, SizeArg) + and arg.name.startswith("ks") + ): + return "tl.int64" + return size_dtype + + return { + argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg)) + for i, arg in zip(indices, signature) + } + + +def is_unaligned_buffer(arg: TensorArg): + buf_name = arg.buffer + if buf_name in V.graph.unaligned_buffers: + return True + + if buf_name in V.graph.graph_inputs: + # See Note: [Input Alignment handling in Inductor] + # For graph inputs that is not recorded in V.graph.unaligned_buffers, + # we know for sure the tensor is aligned. + return False + + if buf_name in V.graph.constants: + # all constants are assumed to be aligned + return False + + if V.graph.scheduler: + layout = V.graph.scheduler.get_buffer_layout(buf_name) + else: + buffer = V.graph.try_get_buffer(buf_name) + # output arg + if not buffer: + assert buf_name == V.kernel.output_node.name + layout = V.kernel.output_node.layout + else: + layout = buffer.get_layout() + + if isinstance(layout, torch._inductor.ir.NonOwningLayout): + return not layout.maybe_guard_aligned() + else: + return False + + +def _arg_equals_1(arg: KernelArgType) -> bool: + return ( + isinstance(arg, SizeArg) + and isinstance(arg.expr, (int, sympy.Integer)) + and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] + ) + + +def equal_1_arg_indices( + args: list[KernelArgType], + *, + indices: Optional[list[int]] = None, +) -> tuple[int, ...]: + if indices is None: + indices = list(range(len(args))) + + equal_to_1 = tuple(i for i, arg in zip(indices, args) if _arg_equals_1(arg)) + + return equal_to_1 + + +def config_of( + args: list[KernelArgType], + *, + indices: Optional[list[int]] = None, +) -> Any: + if indices is None: + indices = list(range(len(args))) + + def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: + """ + Roughly follow triton code here: + https://github.com/triton-lang/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222 + """ + if isinstance(x, TensorArg): + if include_tensor: + offset_aligned = V.graph.sizevars.statically_known_multiple_of( + x.offset * x.dtype.itemsize, + alignment, # type: ignore[arg-type] + ) + return offset_aligned and not is_unaligned_buffer(x) + else: + return False + if isinstance(x, SizeArg): + # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with + # _maybe_evaluate_static... + if x.name.startswith("load_seed_offset"): + return False + if x.expr is None: + return False + if isinstance(x.expr, float): + return False + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] + if isinstance(x, WorkspaceArg): + # We allocate the workspace ourselves, so it is always aligned + return True + if isinstance(x, (TMADescriptorArg, ConstexprArg)): + return False + raise NotImplementedError(f"unhandled {type(x)}: {x}") + + if config.triton.divisible_by_16: + divisible_by_16 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=16, include_tensor=True) + ) + else: + divisible_by_16 = () + + equal_to_1 = equal_1_arg_indices(args, indices=indices) + + # pyrefly: ignore [bad-argument-type] + return AttrsDescriptorWrapper(divisible_by_16, equal_to_1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b62bbee97c2f1d81fa7bf6a12599930a920662 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py @@ -0,0 +1,3950 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import dis +import functools +import inspect +import logging +import operator +import random +import re +import tempfile +from collections.abc import Callable +from itertools import chain, count +from typing import Any, Optional, TYPE_CHECKING, Union + +import sympy +from sympy import Expr + +import torch +import torch._ops +import torch.utils._pytree as pytree +from torch import dtype as torch_dtype +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codegen.debug_utils import DebugPrinterManager +from torch._inductor.codegen.multi_kernel import MultiKernelState +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._library.opaque_object import is_opaque_value_type +from torch._logging import trace_structured +from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + ConvertIntKey, + DivideByKey, + resolve_unbacked_bindings, + SymTypes, +) +from torch.fx.node import _get_qualified_name +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. import async_compile, config, ir +from ..codecache import output_code_log +from ..ir import IRNode, ReinterpretView +from ..runtime import triton_heuristics +from ..runtime.hints import DeviceProperties +from ..utils import ( + cache_on_self, + DelayReplaceLine, + get_benchmark_name, + get_dtype_size, + IndentedBuffer, + is_codegen_graph_partition_subgraph, + is_using_cudagraph_partition, + LineContext, + sympy_product, + sympy_str, + sympy_subs, + triton_version_uses_attrs_dict, +) +from ..virtualized import V +from .common import ( + ArgName, + CodeGen, + DeferredLine, + PythonPrinter, + WorkspaceArg, + WorkspaceZeroMode, +) +from .cpp_utils import cexpr +from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta + + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + import triton + + from ..graph import GraphLowering + from ..ir import ExternKernel + from ..scheduler import BaseSchedulerNode + from .wrapper_fxir import FxConverter + + +log = logging.getLogger(__name__) + +pexpr = PythonPrinter().doprint + + +ReuseKey = tuple[torch.device, torch.dtype, str, bool] +BufferLike = Union[ir.Buffer, WorkspaceArg] +FxConversionFunc = Callable[["WrapperLine"], None] + + +def buffer_reuse_key(node: BufferLike) -> ReuseKey: + storage_size = V.graph.get_allocation_storage_size(node) + alignment = node.get_name() not in V.graph.unaligned_buffers + return ( + node.get_device_or_error(), + node.get_dtype(), + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(V.graph.sizevars.simplify(storage_size)), + alignment, + ) + + +def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike): + # Return True if input_buf can be re-inplaced for output_buf. + # This differs from `buffer_reuse_key` for general buffer reuse. + if input_buf.get_device_or_error() != output_buf.get_device_or_error(): + return False + + if input_buf.get_dtype() != output_buf.get_dtype(): + return False + + input_size = V.graph.sizevars.simplify( + V.graph.get_allocation_storage_size(input_buf) + ) + output_size = V.graph.sizevars.simplify( + V.graph.get_allocation_storage_size(output_buf) + ) + + if ( + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(input_size) == sympy_str(output_size) + ) or ( + # statically known that 0.95 * input_size <= output_size <= input_size + V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size) + and V.graph.sizevars.statically_known_leq(output_size, input_size) + ): + return True + + return False + + +def codegen_reinterpret_view_helper(data): + """ + Collapse a chain of ReinterpretView <- StorageBox + <- ReinterpretView <- StorageBox.... <- buffer wrappers if every layer + has the same offset as the innermost (base) buffer. + + Returns: + (size, stride, offset, dtype, collapsible: bool) + """ + if isinstance(data, ir.Buffer): + lay = data.get_layout() + return lay.size, lay.stride, lay.offset, lay.dtype, True + + layouts: list[Any] = [] + cur = data + while isinstance(cur, (ir.TensorBox, ir.StorageBox, ir.ReinterpretView)): + lay = cur.get_layout() + if lay is None: + return None, None, None, None, False + layouts.append(lay) + cur = cur.data # unwrap + + if not isinstance(cur, ir.Buffer): + return None, None, None, None, False + + # All wrapper offsets must match base offset to be collapsible + for lay in layouts: + if lay.offset != cur.get_layout().offset: + return None, None, None, None, False + + base_lay = cur.get_layout() + return base_lay.size, base_lay.stride, base_lay.offset, base_lay.dtype, True + + +# TODO: Move to a well known place +TritonMetaParams = dict[str, int] +TritonGrid = Union[ + tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], tuple[int, ...]] +] + + +def user_defined_kernel_grid_fn_code( + name: str, + configs: list[triton.Config], # type: ignore[name-defined] + grids: list[TritonGrid], + wrapper: Optional[PythonWrapperCodegen] = None, + original_fxnode_name: Optional[str] = None, +) -> tuple[str, str]: + output = IndentedBuffer() + + def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr: + return item if isinstance(item, sympy.Expr) else sympy.Integer(item) + + def determine_grid( + grid: TritonGrid, + example_grid: Optional[TritonGrid] = None, + ): + """ + This function return a tuple of two values: the first one is for the real grid + which is used in the generated code; the second one is an example grid with + concreate values which is used in the autotune block to run the generated + kernels at compile time. + """ + if wrapper is None or callable(grid): + # return as-is when used in eager mode or when grid is callable + return grid, grid + # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen + sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) + if not example_grid: + example_grid = sympy_grid + return ( + wrapper.codegen_python_shape_tuple(sympy_grid), + ( + wrapper.codegen_python_shape_tuple( + tuple( + wrapper.generate_example_arg_value(g, type(g)) + for g in example_grid # type: ignore[union-attr] + ) + ) + if config.triton.autotune_at_compile_time + else None + ), + ) + + def writeline(line: str, example_grid: Optional[str] = None): + output.writeline(line) + if ( + wrapper + and config.triton.autotune_at_compile_time + and name not in wrapper.kernel_autotune_names + ): + wrapper.kernel_autotune_calls.writeline(example_grid or line) + + fn_name = f"grid_wrapper_for_{name}" + writeline(f"def {fn_name}(meta):") + kernel_autotune_calls_indent = ( + wrapper.kernel_autotune_calls.indent() + if wrapper and config.triton.autotune_at_compile_time + else contextlib.nullcontext() + ) + with output.indent(), kernel_autotune_calls_indent: + if ( + config.triton.autotune_at_compile_time + and original_fxnode_name + and V.graph.autotuning_grids + and original_fxnode_name in V.graph.autotuning_grids + ): + example_grids = V.graph.autotuning_grids[original_fxnode_name] + else: + example_grids = [None] * len(grids) + if len(grids) == 1: + grid, example_grid = determine_grid(grids[0], example_grids[0]) + writeline(f"return {grid}", f"return {example_grid}") + else: + assert len(grids) > 1 + assert len(grids) == len(configs) + seen: OrderedSet[str] = OrderedSet() + # sort the configs from the largest # of kwargs to the smallest to + # emit the grids in the order of (approximately) decreasing specificity + # TODO(aakhundov): the sorting below is generally not sufficient, so + # maybe we'll need to restrict the supported cases to identical kwarg + # names in all autotuning configs. + for grid, c, example_grid in sorted( + zip(grids, configs, example_grids), + key=lambda x: len(x[1].kwargs), + reverse=True, + ): + guardslist = [] + if c.kwargs: + # Remove AMD specific kwargs. + for kwarg in c.kwargs: + if kwarg not in [ + "matrix_instr_nonkdim", + "waves_per_eu", + "kpack", + ]: + guardslist.append(f"meta['{kwarg}'] == {c.kwargs[kwarg]}") + if guardslist: + guards = " and ".join(guardslist) + else: + guards = "True" # for configs with empty kwargs + grid, example_grid = determine_grid(grid, example_grid) + statement = f"if {guards}: return {grid}" + if statement in seen: + continue + seen.add(statement) + writeline(statement, f"if {guards}: return {example_grid}") + + return fn_name, output.getvalue() + + +def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str: + """ + Given a triton kernel function pointer collect the transitive closure of + its dependencies + """ + compile_wrapper = IndentedBuffer() + compile_wrapper.splice(kernel.src, strip=True) + + # Also include any possible kernel being called indirectly + import triton + from triton import JITFunction # type: ignore[name-defined, attr-defined] + from triton.language import constexpr # type: ignore[name-defined] + + # global constexpr vars handled above + symbols_included = OrderedSet([kernel.__name__]) + + def traverse(cur_kernel): + # here we extract the unqualified names (i.e., not attributes and + # without prepended module name) loaded in the kernel code, which + # are matched with the co_names and __globals__ below to codegen + # the respective imports necessary for the kernel compilation + unqualified_loads = OrderedSet( + inst.argval + for inst in dis.Bytecode(cur_kernel.fn) + if inst.opname == "LOAD_GLOBAL" + ) + global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) + for symbol_name in cur_kernel.fn.__code__.co_names: + if symbol_name in symbols_included: + continue + if symbol_name in cur_kernel.fn.__globals__: + symbol = cur_kernel.fn.__globals__[symbol_name] + if isinstance(symbol, JITFunction): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.jit") + # pyrefly: ignore # missing-attribute + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif hasattr(triton, "constexpr_function") and isinstance( + # pyrefly: ignore # missing-attribute + symbol, + # pyrefly: ignore # missing-attribute + triton.runtime.jit.ConstexprFunction, + ): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.constexpr_function") + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif isinstance(symbol, (int, str, bool, constexpr)): + compile_wrapper.newline() + if isinstance(symbol, constexpr): + symbol_str = f"tl.constexpr({symbol.value!r})" + else: + symbol_str = f"{symbol!r}" + if annotation := global_annotations.get(symbol_name): + if isinstance(annotation, type): + annotation_code = ( + f": {annotation.__module__}.{annotation.__name__}" + ) + else: + annotation_code = f": {annotation!r}" + compile_wrapper.writeline( + f"{symbol_name}{annotation_code} = {symbol_str}" + ) + else: + compile_wrapper.writeline(f"{symbol_name} = {symbol_str}") + symbols_included.add(symbol_name) + elif ( + symbol_name in unqualified_loads + and symbol_name != "tl" # already imported + and hasattr(symbol, "__module__") + # only codegen imports from triton; JITFunctions + # imported from other modules will be codegened + # in the separate branch above + and symbol.__module__.startswith("triton") + ): + # a global symbol imported from triton is referenced + # without module qualification (i.e., `store` instead + # of `tl.store`): need to codegen an import + compile_wrapper.writeline( + f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" + ) + symbols_included.add(symbol_name) + + traverse(kernel) + return compile_wrapper.getvalue() + + +@dataclasses.dataclass +class SymbolicCallArg: + inner: sympy.Symbol + # the original symbolic expression represented by inner + inner_expr: sympy.Expr + + def __str__(self): + return str(self.inner) + + +class MemoryPlanningState: + def __init__(self): + super().__init__() + self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = ( + collections.defaultdict(list) + ) + self.total_allocated_buffer_size: int = 0 + + def __contains__(self, key: ReuseKey) -> bool: + return bool(self.reuse_pool.get(key, None)) + + def pop(self, key: ReuseKey) -> FreeIfNotReusedLine: + item = self.reuse_pool[key].pop() + assert not item.is_reused + return item + + def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None: + assert not item.is_reused + self.reuse_pool[key].append(item) + + +class WrapperLine: + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + raise NotImplementedError(f"FX codegen not yet supported for type {type(self)}") + + +@dataclasses.dataclass +class EnterSubgraphLine(WrapperLine): + wrapper: PythonWrapperCodegen + graph: GraphLowering + + def __post_init__(self) -> None: + self.wrapper.push_computed_sizes(self.wrapper.computed_sizes) + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.push_codegened_graph(self.graph) + code.do_indent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_subgraph + + +@dataclasses.dataclass +class ConditionalLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.Conditional + + def codegen(self, code: IndentedBuffer) -> None: + raise NotImplementedError("Only supports FX codegen") + + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_conditional + + +@dataclasses.dataclass +class CommentLine(WrapperLine): + line: LineContext + + def codegen(self, code: IndentedBuffer) -> None: + code.writeline(self.line) + + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_comment + + +@dataclasses.dataclass +class DynamicScalarLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.DynamicScalar + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._codegen_dynamic_scalar(self.node) + + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_dynamic_scalar + + +@dataclasses.dataclass +class ExitSubgraphLine(WrapperLine): + wrapper: PythonWrapperCodegen + + def __post_init__(self) -> None: + self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes() + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.pop_codegened_graph() + code.do_unindent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_subgraph + + +@dataclasses.dataclass +class EnterDeviceContextManagerLine(WrapperLine): + device_idx: int + last_seen_device_guard_index: Optional[int] + + def codegen(self, code: IndentedBuffer) -> None: + if V.graph.cpp_wrapper: + code.writeline("\n") + if V.graph.aot_mode: + # In AOT mode, we have a stream provided as a param. A stream is + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. + if self.last_seen_device_guard_index is None: + code.writeline( + f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" + ) + else: + assert self.last_seen_device_guard_index == self.device_idx, ( + "AOTInductor only supports running on one CUDA device" + ) + else: + if self.last_seen_device_guard_index is None: + code.writeline( + f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});" + ) + else: + code.writeline(f"device_guard.set_index({self.device_idx});") + else: + # Note _DeviceGuard has less overhead than device, but only accepts + # integers + code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") + code.do_indent() + code.writeline(V.graph.device_ops.set_device(self.device_idx)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_device_context_manager + + +class ExitDeviceContextManagerLine(WrapperLine): + def codegen(self, code: IndentedBuffer) -> None: + if not V.graph.cpp_wrapper: + code.do_unindent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_device_context_manager + + +@dataclasses.dataclass +class ExternKernelAllocLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ExternKernelAlloc + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + args = [*node.codegen_args(), *node.codegen_kwargs()] + self.wrapper._generate_extern_kernel_alloc_helper(self.node, args) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_alloc + + +@dataclasses.dataclass +class ExternKernelOutLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ExternKernelOut + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + args = [*node.codegen_args(), *node.codegen_kwargs(skip_out=True)] + kernel_name = node.get_kernel_name() + if ( + V.graph.cpp_wrapper + and node.cpp_kernel_name == "torch::inductor::_mm_plus_mm" + ): + # For https://github.com/pytorch/pytorch/issues/128474 + kernel_name = "aoti_torch__mm_plus_mm_out" + else: + kernel_name = node.get_kernel_name() + device = d.type if (d := node.get_device()) else V.graph.device_type + self.wrapper._generate_extern_kernel_out_helper( + kernel_name, + node.codegen_reference(), + node.output_view.codegen_reference() if node.output_view else None, + args, + device, + self.node.get_stack_traces(), + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_out + + +@dataclasses.dataclass +class FreeLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: Union[BufferLike, ir.TorchBindObject] + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + code.writeline(self.wrapper.make_buffer_free(self.node)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free + + +@dataclasses.dataclass +class KernelCallLine(WrapperLine): + wrapper: PythonWrapperCodegen + kernel_name: str + call_args: tuple[Any, ...] + raw_keys: tuple[Any, ...] + raw_args: tuple[Any, ...] + arg_types: list[str] + triton: bool + triton_meta: dict[str, Any] + device: torch.device + graph_name: str + original_fxnode_name: str + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_kernel_call_helper( + self.kernel_name, + self.call_args, + triton=self.triton, + arg_types=self.arg_types, + raw_keys=self.raw_keys, + raw_args=self.raw_args, + triton_meta=self.triton_meta, + device=self.device, + graph_name=self.graph_name, + original_fxnode_name=self.original_fxnode_name, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_call + + +@dataclasses.dataclass +class KernelDefinitionLine(WrapperLine): + wrapper: PythonWrapperCodegen + kernel_name: str + kernel_body: str + metadata: Optional[str] = None + gpu: bool = True + cpp_definition: Optional[str] = None + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._define_kernel_helper( + self.kernel_name, + self.kernel_body, + metadata=self.metadata, + gpu=self.gpu, + cpp_definition=self.cpp_definition, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_definition + + +@dataclasses.dataclass +class MemoryPlanningLine(WrapperLine): + wrapper: PythonWrapperCodegen + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + """First pass to find reuse""" + return self + + def codegen(self, code: IndentedBuffer) -> None: + """Second pass to output code""" + + def __str__(self) -> str: + """ + Emits a string representation that fits on one line. + """ + args: list[str] = [] + for field in dataclasses.fields(self): + if field.name == "wrapper": + continue + val = getattr(self, field.name) + args.append( + f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" + ) + return f"{type(self).__name__}({', '.join(args)})" + + +class EfficientPeakEstimate: + def __init__(self): + from ..memory import estimate_peak_memory, get_freeable_input_buf + + scheduler_nodes = V.graph.scheduler.nodes + graph_inputs = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs = OrderedSet(V.graph.get_output_names()) + names_to_freeable_bufs = get_freeable_input_buf(scheduler_nodes, graph_inputs) + self.overall_peak_memory, peak_by_scheduler_node = estimate_peak_memory( + scheduler_nodes, + names_to_freeable_bufs, + graph_outputs, + ) + + from .segmented_tree import SegmentedTree + + self.segmented_tree = SegmentedTree( + peak_by_scheduler_node, operator.add, max, 0 + ) + + def _get_size(self, node: BufferLike) -> int: + return V.graph.sizevars.size_hint( + V.graph.get_allocation_storage_size(node), fallback=0 + ) * get_dtype_size(node.get_dtype()) + + def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine): + return self.segmented_tree.summarize_range( + line_a.scheduler_node_index + 1, line_b.scheduler_node_index - 1 + ) + + def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine): + if line_a.scheduler_node_index + 1 == line_b.scheduler_node_index: + return + self.segmented_tree.update_range( + line_a.scheduler_node_index + 1, + line_b.scheduler_node_index - 1, + self._get_size(line_b.node), + ) + + +@dataclasses.dataclass +class AllocateLine(MemoryPlanningLine): + node: BufferLike + + def __post_init__(self): + assert V.graph.scheduler.current_node is not None + self.scheduler_node_index = V.graph.scheduler.nodes.index( + V.graph.scheduler.current_node + ) + + def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool: + if free_line.scheduler_node_index + 1 == self.scheduler_node_index: + return True + overall_peak_memory = self.wrapper.estimate_peak.overall_peak_memory + peak_memory_in_range = self.wrapper.estimate_peak.peak_between(free_line, self) + new_peak_memory = size + peak_memory_in_range + return new_peak_memory <= overall_peak_memory + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + if config.allow_buffer_reuse and key in state: + free_line = state.pop(key) + size = V.graph.sizevars.size_hint( + V.graph.get_allocation_storage_size(self.node), fallback=0 + ) * get_dtype_size(self.node.get_dtype()) + if self.should_reuse_buffer(free_line, size): + free_line.is_reused = True + self.wrapper.estimate_peak.update_peak_between(free_line, self) + return ReuseLine(self.wrapper, free_line.node, self.node) + else: + state.push(key, free_line) + return self + + if self.node.get_device_or_error().type == "cpu": + static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) + if static_shape is not None: + state.total_allocated_buffer_size += int( + functools.reduce(operator.mul, static_shape, 1) + ) + + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + line = self.wrapper.make_buffer_allocation(self.node) + code.writeline(line) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_allocate + + +@dataclasses.dataclass +class FreeIfNotReusedLine(MemoryPlanningLine): + node: BufferLike + is_reused: bool = False + + def __post_init__(self): + assert V.graph.scheduler.current_node is not None + self.scheduler_node_index = V.graph.scheduler.nodes.index( + V.graph.scheduler.current_node + ) + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if len(self.node.get_inputs_that_alias_output()) > 0: + return self + if isinstance(self.node.layout, ir.MultiOutputLayout): + return self + assert not self.is_reused + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + if config.allow_buffer_reuse: + state.push(buffer_reuse_key(self.node), self) + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(self.wrapper.make_buffer_free(self.node)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free_if_not_reused + + +@dataclasses.dataclass +class ReinterpretLine(MemoryPlanningLine): + node: BufferLike + reused_as: BufferLike + layout: ir.Layout + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert isinstance(self.layout, ir.NonOwningLayout) + assert isinstance(self.layout.view, ir.ReinterpretView) + self.wrapper.codegen_deferred_allocation( + self.reused_as.get_name(), self.layout.view + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reinterpret + + +@dataclasses.dataclass +class ReuseLine(MemoryPlanningLine): + node: BufferLike + reused_as: BufferLike + delete_old: bool = True + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + assert self.reused_as.get_name() in V.graph.removed_buffers + return NullLine(self.wrapper) + assert self.reused_as.get_name() not in V.graph.removed_buffers + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reuse + + +class NullLine(MemoryPlanningLine): + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_null + + +@dataclasses.dataclass +class CommBufferLine(WrapperLine): + wrapper: PythonWrapperCodegen # type: ignore[name-defined] # noqa: F821 + node: ir.Buffer + + @property + def size(self) -> int: + from torch._inductor.utils import is_symbolic + + numel = self.node.get_numel() + dtype = self.node.get_dtype() + if is_symbolic(numel): + raise AssertionError( + f"The size of a comm buffer can't be symbolic: {self.node}" + ) + return int(numel) * dtype.itemsize + + @property + def comm_buffer_type(self) -> ir.CommBufferType: + layout = self.node.get_output_spec() + assert isinstance(layout, ir.CommBufferLayout) + return layout.comm_buffer_type + + @property + def group_name(self) -> str: + layout = self.node.get_output_spec() + assert isinstance(layout, ir.CommBufferLayout) + return layout.group_name + + +@dataclasses.dataclass +class CommBufferAllocateLine(CommBufferLine): + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + name = self.node.get_name() + device = self.node.get_device() + dtype = self.node.get_dtype() + shape = tuple(self.node.get_size()) + stride = tuple(self.node.get_stride()) + code.writeline( + self.make_allocation_line( + self.comm_buffer_type, + self.group_name, + self.wrapper, + name, + device, + dtype, + shape, + stride, + ) + ) + + @staticmethod + def make_allocation_line( + comm_buffer_type, group_name, wrapper, name, device, dtype, shape, stride + ): + if comm_buffer_type == ir.CommBufferType.SYMM_MEM: + return ( + f"{name} = empty_strided_p2p(" + f"{wrapper.codegen_shape_tuple(shape)}, " + f"{wrapper.codegen_shape_tuple(stride)}, " + f"{dtype}, " + f'torch.device("cuda:{device.index}"), ' + f'group_name="{group_name}", ' + f"alloc_id={random.randint(0, 2**64 - 1)})" + ) + else: + raise NotImplementedError( + f"Unsupported comm buffer type: {comm_buffer_type}" + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_allocate + + +@dataclasses.dataclass +class CommBufferFreeLine(CommBufferLine): + def codegen(self, code: IndentedBuffer) -> None: + line = self.wrapper.make_buffer_free(self.node) + code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free") + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_free + + +@dataclasses.dataclass +class MultiOutputLine(WrapperLine): + """ + Given a MultiOutputLayout buffer, indexes actual buffer(s) from the result. + """ + + wrapper: PythonWrapperCodegen + result_name: str + arg_name: str + indices: Sequence[Any] + + def codegen(self, code: IndentedBuffer) -> None: + def codegen_list_tuple_access(basename, indices): # type: ignore[no-untyped-def] + if len(indices) > 0: + itype, i = indices[0] + if issubclass(itype, list): + return codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif issubclass(itype, tuple): + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = self.wrapper.codegen_tuple_access( + basename, self.result_name, str(i) + ) + return codegen_list_tuple_access(tuple_access, indices[1:]) + elif issubclass(itype, dict): + return codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type: ", itype) + else: + return basename + + value = codegen_list_tuple_access(self.arg_name, self.indices) + code.writeline( + f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}" + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_multi_output + + +@dataclasses.dataclass +class IndexPutFallbackLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.IndexPutFallback + indices: list[Optional[ir.IRNode]] + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + assert ir.is_node_sequence(node.inputs) + (x, values) = (t.codegen_reference() for t in node.inputs[:2]) + indices = [ + idx.codegen_reference() if idx else self.wrapper.none_str + for idx in self.indices + ] + + self.wrapper._generate_index_put_fallback( + node.get_kernel_name(), x, indices, values, *node.codegen_const_args() + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_index_put_fallback + + +@dataclasses.dataclass +class ScatterFallbackLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ScatterFallback + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + assert ir.is_node_sequence(node.inputs) + if node.src_is_tensor: + (x, index, src) = (t.codegen_reference() for t in node.inputs) + else: + (x, index) = (t.codegen_reference() for t in node.inputs) + src = node.constant_args[1] + device = d.type if (d := node.get_device()) else V.graph.device_type + self.wrapper._generate_scatter_fallback( + x, + [x, node.constant_args[0], index, src], + node.cpp_kernel_name, + node.python_kernel_name, + node.src_is_tensor, + node.kwargs["reduce"], + node.codegen_kwargs(), + device, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_scatter_fallback + + +@dataclasses.dataclass +class SymbolicCallArgLine(WrapperLine): + wrapper: PythonWrapperCodegen + arg: SymbolicCallArg + graph: GraphLowering + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_symbolic_call_arg + + +@dataclasses.dataclass +class UnbackedSymbolDefsLine(WrapperLine): + wrapper: PythonWrapperCodegen + output_name: str + outputs: Any + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._codegen_unbacked_symbol_defs_for_outputs( + self.output_name, self.outputs, self.unbacked_bindings + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_unbacked_symbol_defs + + +BufferName = str +Line = Union[MemoryPlanningLine, LineContext] + + +class PythonWrapperCodegen(CodeGen): + """ + Generate outer wrapper in Python that calls the kernels. + """ + + supports_caching = True # Whether the output code is cacheable. + + def __init__(self): + super().__init__() + self._names_iter: Iterator[int] = count() + self.args_to_buffers: dict[ + str, Union[None, ir.TensorBox, ir.Buffer, ir.TorchBindObject] + ] = {} + self.imports = IndentedBuffer() + self.header = IndentedBuffer() + self.prefix = IndentedBuffer() + self.suffix = IndentedBuffer() + self.kernel_declarations = IndentedBuffer() + self.wrapper_call = IndentedBuffer() + self.kernel_autotune_defs = IndentedBuffer() + self.kernel_autotune_calls = IndentedBuffer() + self.subgraph_definitions = IndentedBuffer() + self.kernel_autotune_names: OrderedSet[str] = OrderedSet() + # Map key is the kernel argument name; value is a tuple of the resulting example + # tensor name with the kernel where that tensor was most recently used. + self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {} + self.kernel_autotune_tmp_arg_idx: int = 0 + # If the generated source code is exactly the same, reuse the + # pre-existing kernel for it + self.src_to_kernel: dict[str, str] = {} + self.kernel_numel_expr: OrderedSet[tuple[str, GraphLowering]] = OrderedSet() + self.lines: list[Line] = [] + self.declare = "" + self.declare_maybe_reference = "" + self.ending = "" + self.comment = "#" + self.none_str = "None" + self.move_begin = "std::move(" if V.graph.cpp_wrapper else "" + self.move_end = ")" if V.graph.cpp_wrapper else "" + self.last_seen_device_guard_index: Optional[int] = None + self.supports_intermediate_hooks = True + self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {} + self.unbacked_symbol_decls: OrderedSet[str] = ( + OrderedSet() + ) # str of sympy.Symbol + self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet() + self.launcher_fn_name = None + # This function can be overridden to change the launcher name + self.set_launcher_fn_name() + + # this is used for tracking which GraphLowering instance---parent graph + # or (nested) subgraph---is currently codegened; the primary use case is + # including the graph instance into a cache key to avoid cross-graph + # caching during lowering of nested subgraphs + self.codegened_graph_stack = [] + self.computed_sizes_stack = [] + + self.write_header() + + if not is_codegen_graph_partition_subgraph(self): + # See [Note: Removed Graph Partition Arguments] + self.write_prefix() + + self.write_kernel_autotune_defs_header() + + if not V.graph.aot_mode: + for name, hashed in V.graph.constant_reprs.items(): + # include a hash so our code cache puts different constants into different files + self.write_constant(name, hashed) + + self.allocated = OrderedSet[BufferName]() + self.freed = OrderedSet[BufferName]() + + # maps from reusing buffer to reused buffer + self.reuses: dict[BufferName, BufferName] = {} + + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @functools.cache + def add_import_once(line: str) -> None: + self.imports.writeline(line) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(line) + + self.add_import_once = add_import_once + self._metas: dict[str, str] = {} + self._meta_vars: OrderedSet[str] = OrderedSet() + self.multi_kernel_state = MultiKernelState() + self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet() + self.allocated_workspaces: dict[str, Any] = {} + + # intermediate tensor value printing utility + self.debug_printer = DebugPrinterManager( + debug_printer_level=config.aot_inductor.debug_intermediate_value_printer, + use_array_ref=config.aot_inductor.allow_stack_allocation, + ) + + # Additional files that are dependent to the wrapper (ex. cubin files) + self.additional_files = [] + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None + assert parent_wrapper is not None + return SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return PythonWrapperCodegen() + + def set_launcher_fn_name(self) -> None: + # pyrefly: ignore [bad-assignment] + self.launcher_fn_name = "call" + + def write_constant(self, name: str, hashed: str) -> None: + self.header.writeline(f"{name} = None # {hashed}") + + def write_header(self) -> None: + context = torch._guards.TracingContext.try_get() + aot_config_comment = "" + if context is not None and context.aot_graph_name is not None: + aot_config_comment = f"# AOT ID: {context.aot_graph_name}" + inductor_debug_utils = "" + if int(config.aot_inductor.debug_intermediate_value_printer) > 0: + inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + elif torch._inductor.config.test_configs.track_memory_lifecycle: + inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n" + + self.imports.splice( + f""" + {aot_config_comment} + from ctypes import c_void_p, c_long, c_int + import torch + import math + import random + import os + import tempfile + from math import inf, nan + from cmath import nanj + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + from torch import device, empty_strided + from {async_compile.__name__} import AsyncCompile + from torch._inductor.select_algorithm import extern_kernels + {inductor_debug_utils} + """, + strip=True, + ) + self.header.splice( + """ + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + _quantized = torch.ops._quantized + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + assert_alignment = torch._C._dynamo.guards.assert_alignment + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor + alloc_from_pool = torch.ops.inductor._alloc_from_pool + async_compile = AsyncCompile() + """, + strip=True, + ) + try: + # Only add empty_strided_p2p() if distributed and SymmetricMemory + # is available + from torch._C._distributed_c10d import _SymmetricMemory # noqa: F401 + + self.header.splice( + """ + empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + """, + strip=True, + ) + except (AttributeError, ImportError): + pass + if config.annotate_training: + self.header.writeline("from torch.cuda import nvtx") + + def include_extra_header(self, header: str): + pass + + def write_kernel_autotune_defs_header(self) -> None: + self.kernel_autotune_defs.splice( + f""" + import torch + from torch._dynamo.testing import rand_strided + from torch._dynamo.utils import preserve_rng_state + from torch._inductor.select_algorithm import AlgorithmSelectorCache + from {async_compile.__name__} import AsyncCompile + + async_compile = AsyncCompile() + generate_example_value = AlgorithmSelectorCache.generate_example_value + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + """ + ) + + try: + from torch._C import _cuda_getCurrentRawStream # noqa: F401 + + self.kernel_autotune_defs.splice( + """ + get_raw_stream = torch._C._cuda_getCurrentRawStream + """, + strip=True, + ) + except (ImportError, AttributeError): + pass + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import triton.language as tl + from {triton_heuristics.__name__} import start_graph, end_graph + """ + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + def write_get_raw_stream_header(self) -> None: + import_get_raw_stream_str = V.graph.device_ops.import_get_raw_stream_as( + "get_raw_stream" + ) + if config.triton.autotune_at_compile_time: + if not self.kernel_autotune_calls.contains(import_get_raw_stream_str): + self.kernel_autotune_calls.writeline(import_get_raw_stream_str) + if not V.graph.cpp_wrapper: + if not self.imports.contains(import_get_raw_stream_str): + self.imports.writeline(import_get_raw_stream_str) + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + self.write_get_raw_stream_header() + + def add_meta_once(self, meta: TritonMetaParams) -> str: + # pyrefly: ignore [bad-assignment] + meta = repr(meta) + if meta not in self._metas: + var = f"meta{len(self._metas)}" + # pyrefly: ignore [unsupported-operation] + self._metas[meta] = var + self.header.writeline(f"{var} = {meta}") + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(f"{var} = {meta}") + self._meta_vars.add(var) + # pyrefly: ignore [index-error] + return self._metas[meta] + + @cache_on_self + def get_output_refs(self) -> list[str]: + return [ + x.codegen_reference(self.wrapper_call) for x in self.get_graph_outputs() + ] + + def mark_output_type(self) -> None: + return + + def get_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]: + return V.graph.graph_inputs + + def get_graph_outputs(self) -> list[IRNode]: + return V.graph.graph_outputs + + def codegen_input_size_asserts(self) -> None: + for name, buf in self.get_graph_inputs().items(): + if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + continue + + # a graph partition may take an IRNode output from a previous partition + if name not in V.graph.graph_input_names or isinstance( + buf, ir.GeneratorState + ): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + size = self.codegen_python_shape_tuple(buf.get_size()) + stride = self.codegen_python_shape_tuple(buf.get_stride()) + self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + + def codegen_input_nan_asserts(self) -> None: + self.prefix.writeline("# make sure graph inputs are not nan/inf") + for name, buf in self.get_graph_inputs().items(): + if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + continue + + line = f"assert not {name}.isnan().any().item()" + self.prefix.writeline(line) + line = f"assert not {name}.isinf().any().item()" + self.prefix.writeline(line) + + def write_async_compile_wait(self) -> None: + self.prefix.splice( + """ + + async_compile.wait(globals()) + del async_compile + """ + ) + + def write_args(self, input_names: list[str]): + lhs = ", ".join(input_names) + if len(input_names) == 1: + lhs += "," + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + def write_launcher_fn_call_get_indent(self) -> int: + if config.graph_partition: + self.prefix.splice( + """ + class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + """ + ) + prefix_indent = 2 + else: + self.prefix.splice( + f""" + def {self.launcher_fn_name}(args): + """ + ) + prefix_indent = 1 + + return prefix_indent + + def get_graph_input_names(self) -> list[str]: + return V.graph.graph_input_names + + def write_prefix(self) -> None: + assert self.launcher_fn_name is not None + self.write_async_compile_wait() + prefix_indent = self.write_launcher_fn_call_get_indent() + + with self.prefix.indent(prefix_indent): + if config.triton.debug_sync_graph: + self.prefix.writeline(V.graph.device_ops.synchronize()) + phase = V.graph.get_training_phase() + if config.annotate_training: + self.prefix.writeline( + f"training_annotation = nvtx._device_range_start('{phase}')" + ) + + if graph_input_names := self.get_graph_input_names(): + self.write_args(graph_input_names) + + self.codegen_inputs() + + # avoid duplicating asserts for both partition functions and + # the call function when using cudagraph partition + if not ( + is_using_cudagraph_partition() + and (not is_codegen_graph_partition_subgraph(self)) + ): + self.codegen_input_size_and_nan_asserts() + + def codegen_input_size_and_nan_asserts(self) -> None: + if config.size_asserts: + self.codegen_input_size_asserts() + if config.nan_asserts: + self.codegen_input_nan_asserts() + + # this function (and below) takes the graph name as input so + # that stream caching happens per graph instance. this + # is important for nested subgraph codegening. + def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: + self.write_get_raw_stream_header() + name = f"stream{device_idx}" + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + f"{name} = get_raw_stream({device_idx})" + ) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return name + self.writeline(f"{name} = get_raw_stream({device_idx})") + return name + + def get_codegened_graph(self): + return self.codegened_graph_stack[-1] + + def push_codegened_graph(self, graph): + self.codegened_graph_stack.append(graph) + + def pop_codegened_graph(self): + return self.codegened_graph_stack.pop() + + def push_computed_sizes(self, computed_sizes): + from copy import deepcopy + + return self.computed_sizes_stack.append(deepcopy(computed_sizes)) + + def pop_computed_sizes(self): + return self.computed_sizes_stack.pop() + + def next_kernel_suffix(self) -> str: + return f"{next(self._names_iter)}" + + def codegen_device_guard_enter(self, device_idx: int) -> None: + self.writeline( + EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) + ) + if config.triton.autotune_at_compile_time: + # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block + self.write_triton_header_once() + self.kernel_autotune_calls.writeline( + f"with {V.graph.device_ops.device_guard(device_idx)}:" + ) + self.kernel_autotune_calls.do_indent() + if is_codegen_graph_partition_subgraph(self): + # Need get_raw_stream for subgraph + self.write_get_raw_stream_header() + self.kernel_autotune_calls.writeline( + f"stream{device_idx} = get_raw_stream({device_idx})" + ) + self.last_seen_device_guard_index = device_idx + + def codegen_device_guard_exit(self) -> None: + self.writeline(ExitDeviceContextManagerLine()) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.do_unindent() + + def generate_return(self, output_refs: list[str]) -> None: + if output_refs: + if config.nan_asserts: + self.wrapper_call.writeline( + "return_vars = (" + ", ".join(output_refs) + ", )" + ) + self.wrapper_call.writeline("for var in return_vars:") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("if isinstance(var, torch.Tensor):") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("assert not var.isnan().any().item()") + self.wrapper_call.writeline("assert not var.isinf().any().item()") + self.wrapper_call.do_unindent(2) + + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_before_suffix(self, result: IndentedBuffer) -> None: + return + + def generate_after_suffix(self, result: IndentedBuffer) -> None: + if config.graph_partition: + all_partition_name_list = ", ".join(self.all_partition_names) + ( + "," if len(self.all_partition_names) == 1 else "" + ) + + result.splice( + f""" + runner = Runner(partitions=[{all_partition_name_list}]) + call = runner.call + recursively_apply_fns = runner.recursively_apply_fns + """ + ) + + def generate_end(self, result: IndentedBuffer) -> None: + return + + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: + self.writeline(ExternKernelAllocLine(self, node)) + + def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc): + node.codegen_comment(self) + self.writeline(ExternKernelAllocLine(self, node)) + if isinstance(node.layout, ir.Layout): + node.codegen_size_asserts(self) + + def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): + # If it's a NoneLayout then the extern_kernel should essentially be + # treated as if it doesn't return anything + no_return = isinstance(extern_kernel.layout, ir.NoneLayout) + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. + ending = f".clone(){ending}" + + if no_return: + self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}") + else: + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def generate_extern_kernel_out( + self, + node: ir.ExternKernelOut, + ) -> None: + node.codegen_comment(self) + self.writeline(ExternKernelOutLine(self, node)) + + def _generate_extern_kernel_out_helper( + self, + kernel: str, + out: str, + out_view: Optional[str], + args: list[str], + device: str, + stack_traces: Optional[OrderedSet[str]] = None, + ) -> None: + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") + args.append(f"out={out_view if out_view else out}") + with debug_printer_manager: + self.writeline(f"{kernel}({', '.join(args)})") + + def _generate_tma_descriptor_call_experimental(self, desc, apply_size_hints=False): + dims = desc.dims + block_dims = desc.block_dims + if apply_size_hints: + dims = tuple(V.graph.sizevars.atomically_apply_size_hint(d) for d in dims) + block_dims = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_dims + ) + + ptr = f"{desc.tensor.codegen_reference()}.data_ptr()" + # Explicitly call the Python version of val_to_arg_str + dims = ", ".join(PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in dims) + block_dims = ", ".join( + PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in block_dims + ) + element_size = PythonWrapperCodegen.val_to_arg_str(self, desc.element_size) + prefix = "triton.tools.experimental_descriptor" + fn = f"{prefix}.create_{desc.rank}d_tma_descriptor" + args = f"{ptr}, {dims}, {block_dims}, {element_size}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call_stable(self, desc, apply_size_hints=False): + block_shape = desc.block_shape + if apply_size_hints: + block_shape = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_shape + ) + + prefix = "triton.tools.tensor_descriptor.TensorDescriptor" + fn = f"{prefix}.from_tensor" + args = f"{desc.tensor.codegen_reference()}, {block_shape}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call(self, desc, apply_size_hints=False): + if isinstance(desc, ir.TMADescriptorExperimental): + return self._generate_tma_descriptor_call_experimental( + desc, apply_size_hints + ) + else: + assert isinstance(desc, ir.TMADescriptorStable) + return self._generate_tma_descriptor_call_stable(desc, apply_size_hints) + + def generate_tma_descriptor(self, desc): + call = self._generate_tma_descriptor_call(desc) + line = f"{desc.name} = {call}{self.ending}" + self.writeline(line) + + def generate_scatter_fallback(self, node: ir.ScatterFallback): + self.writeline(ScatterFallbackLine(self, node)) + + def _generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + device, + ): + line = f"{python_kernel_name}({','.join(map(str, inputs))}" + if python_kernel_name.startswith("aten.scatter_reduce"): + line += ", ".join([""] + kwargs) + else: + if reduce: + line += f", reduce={repr(reduce)}" + line += ")" + self.writeline(line) + + def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None: + # Collect index tensors into a list. + indices: list[Optional[ir.IRNode]] = [] + valid_indices = node.inputs[2:] + iter_valid_indices = iter(valid_indices) + for i, _ in enumerate(node.indices): + if node.indices[i] is not None: + index = next(iter_valid_indices) + assert isinstance(index, ir.IRNode) + indices.append(index) + else: + indices.append(None) + + self.writeline(IndexPutFallbackLine(self, node, indices)) + + def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + indices_str = f"[{', '.join(indices)}]" + args = [x, indices_str, values, accumulate] + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})") + + def generate(self, is_inference): + with dynamo_timed("PythonWrapperCodegen.generate"): + return self._generate(is_inference) + + def get_wrapper_call_indent(self) -> int: + if config.graph_partition: + return 2 + else: + return 1 + + @contextlib.contextmanager + def set_writeline(self, new: Callable[..., None]) -> Iterator[Callable[..., None]]: + old = self.writeline + try: + self.writeline = new # type: ignore[method-assign] + yield new + finally: + self.writeline = old # type: ignore[method-assign] + + def _write_multi_kernel_defs(self) -> None: + kernel_defs = self.multi_kernel_state.kernel_defs + if config.triton.autotune_at_compile_time: + self.kernel_autotune_defs.splice(kernel_defs) + else: + self.header.splice(kernel_defs) + + def _generate(self, is_inference): + if config.profile_bandwidth: + self.write_triton_header_once() + + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + if config.profiler_mark_wrapper_call: + self.generate_profiler_mark_wrapper_call(stack) + if config.profile_bandwidth: + self.generate_start_graph() + + self.run_wrapper_ir_passes(is_inference) + + if config.triton.store_cubin and not config.triton.autotune_at_compile_time: + self.generate_reset_kernel_saved_flags() + + # At this point, we shouldn't generate any new memory planning lines. + # Override writeline to point at the wrapper call, in case it gets called. + with self.set_writeline(self.wrapper_call.writeline): + for line in self.lines: + if isinstance(line, WrapperLine): + # pyrefly: ignore [missing-attribute] + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + + self._write_multi_kernel_defs() + + output_refs = self.get_output_refs() + self.mark_output_type() + if config.triton.debug_sync_graph: + self.wrapper_call.writeline(V.graph.device_ops.synchronize()) + + if config.profile_bandwidth: + self.generate_end_graph() + + if config.triton.store_cubin and not config.triton.autotune_at_compile_time: + self.generate_save_uncompiled_kernels() + + if config.triton.autotune_at_compile_time: + self.generate_and_run_autotune_block() + + # cpp_wrapper currently doesn't support nvtx + if config.annotate_training and not config.cpp_wrapper: + self.wrapper_call.writeline( + "nvtx._device_range_end(training_annotation)" + ) + self.generate_return(output_refs) + + # Assemble the final code from sections. + result = IndentedBuffer() + result.splice(self.imports) + result.writeline("") + result.splice(self.header) + # We do not want the cpp header for intermediate const graph. Headers would be + # rendered by the main module instead. + if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph: + result = IndentedBuffer() + + # Add subgraph definitions to the result + result.splice(self.subgraph_definitions) + self.finalize_prefix() + result.splice(self.prefix) + + wrapper_call_indent = self.get_wrapper_call_indent() + + with result.indent(wrapper_call_indent): + result.splice(self.wrapper_call) + + self.generate_before_suffix(result) + result.splice(self.suffix) + self.generate_after_suffix(result) + + self.generate_end(result) + + self.add_benchmark_harness(result) + + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) + + def generate_and_run_autotune_block(self): + """ + Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of + code and execute it to trigger Triton kernel compilation and auto-tuning + """ + self.kernel_autotune_defs.splice( + """ + async_compile.wait(globals()) + del async_compile + """ + ) + scope = {} # type: ignore[var-annotated] + if config.triton.autotune_at_compile_time and V.graph.autotuning_inputs: + scope = { + self.get_autotuning_input_name(idx): v # type: ignore[attr-defined] + for idx, v in enumerate(V.graph.autotuning_inputs) + } + tuning_code = ( + self.kernel_autotune_defs.getvalue() + + "\n" + + self.kernel_autotune_calls.getvalue() + ) + if output_code_log.level == logging.DEBUG: + # Save the autotuning code block into a file + # Create a temporary file + with tempfile.NamedTemporaryFile( + dir=cache_dir(), suffix=".py", delete=False + ) as f: + f.write(tuning_code.encode("utf-8")) + file_path = f.name + output_code_log.debug( + "Auto-tuning code written to %s", + file_path, + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_autotune_at_compile_time_code", + "encoding": "string", + }, + payload_fn=lambda: tuning_code, + ) + # Execute the code to autotune kernels + try: + exec(tuning_code, scope) + except Exception as e: + raise RuntimeError(f"Failed to run autotuning code block: {e}") from e + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + + def memory_plan_reuse(self): + outputs = self.get_graph_outputs() + out_names = V.graph._get_output_names(outputs) + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + # FIXME(rec): not used + _total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + def run_wrapper_ir_passes(self, is_inference: bool): + # We disable planning during training because it presently increases peak memory consumption. + if is_inference and config.memory_planning: + self.memory_plan() + else: + if config.allow_buffer_reuse: + self.estimate_peak = EfficientPeakEstimate() + self.memory_plan_reuse() + + def codegen_input_symbol_assignment( + self, + name: str, + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + code = self.prefix + + @functools.cache + def sizeof(name): + code.writeline(f"{name}_size = {name}.size()") + return f"{name}_size" + + @functools.cache + def strideof(name): + code.writeline(f"{name}_stride = {name}.stride()") + return f"{name}_stride" + + if isinstance(value, sympy.Expr): + if not isinstance(value, sympy.Symbol) or value in bound_vars: + return + code.writeline(f"{value} = {name}") + bound_vars.add(value) + elif isinstance(value, ir.TensorBox): + for dim, size in enumerate(value.get_size()): + if isinstance(size, sympy.Symbol) and size not in bound_vars: + code.writeline(f"{size} = {sizeof(name)}[{dim}]") + bound_vars.add(size) + for dim, stride in enumerate(value.get_stride()): + if isinstance(stride, sympy.Symbol) and stride not in bound_vars: + code.writeline(f"{stride} = {strideof(name)}[{dim}]") + bound_vars.add(stride) + elif isinstance(value, ir.TorchBindObject): + return + elif isinstance(value, ir.GeneratorState): + return + else: + if torch._inductor.config.graph_partition: + pass + else: + raise AssertionError(f"Unknown value type: {type(value)}") + + def codegen_inputs(self): + """Assign all symbolic shapes to locals""" + bound_vars = OrderedSet[sympy.Symbol]() + # There is a subtle case in the cpp wrapper codegen which requires generating + # symbol inputs first followed by non-symbol ones. + # + # When a dynamic size constraint specified at the Export time is an expression, + # we need to solve that expression to proper define a symbol in cpp. Thus we + # are enforcing this iterating order here to make sure all plain size symbols + # are defined first. + graph_inputs = self.get_graph_inputs() + inputs = [ + (k, v) for k, v in graph_inputs.items() if isinstance(v, sympy.Symbol) + ] + [(k, v) for k, v in graph_inputs.items() if not isinstance(v, sympy.Symbol)] + for name, value in inputs: + self.codegen_input_symbol_assignment(name, value, bound_vars) + + def _verify_input_symbol_assignment( + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + for expr in chain.from_iterable([value.get_size(), value.get_stride()]): + if not isinstance(expr, Expr) or isinstance(expr, sympy.Symbol): + continue + + undefined_symbols = [ + sym for sym in expr.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) > 0: + raise AssertionError( + f"For {expr}, expected {undefined_symbols} to have been codegen-ed." + ) + + # For inputs with size/strides which contain sympy expressions, we can + # encounter symbols that weren't defined yet. Now, let's check each + # symbol is defined. + for _, value in inputs: + if not isinstance(value, ir.TensorBox): + continue + _verify_input_symbol_assignment(value, bound_vars) + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + arg = SymbolicCallArg(sym, expr) + self.writeline(SymbolicCallArgLine(self, arg, V.graph)) + + def finalize_prefix(self): + pass + + def codegen_cpp_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + raise RuntimeError("codegen_cpp_sizevar is only implemented for cpp_wrapper!") + + def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + return pexpr(x, simplify=simplify) + + def codegen_sizevar(self, x: Expr) -> str: + return self.codegen_python_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + return f"{basename}[{index}]" + + def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: + parts = [*map(self.codegen_python_sizevar, shape)] + if len(parts) == 0: + return "()" + if len(parts) == 1: + return f"({parts[0]}, )" + return f"({', '.join(parts)})" + + def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: + return self.codegen_python_shape_tuple(shape) + + def codegen_alloc_from_pool( + self, name, offset, dtype, shape, stride + ) -> tuple[str, list[str]]: + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + str(dtype), + self.codegen_python_shape_tuple(shape), + self.codegen_python_shape_tuple(stride), + ] + ) + ), [] + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + # Get the innermost buffer's layout info to help reinterpret view. + # Consider a chain of (ReinterpretView <- TensorBox| StorageBox)... <- buffer + # If we only use x.data to determine the reinterpret, we may get wrong layout. + # For example: + # x = ReinterpretView( + # Storage( + # ReinterpretView( + # storage( + # Buffer(name='buf0', layout=(size=(2, 5, 10), ...) + # ), + # layout=(10, 10), + # ), + # ), + # layout=(10, 10), + # ) + # In this case, x.data.layout == x.layout is (10, 10), the reinterpret view will return buf0, + # but buf0 need to be viewed from (2, 5, 10) to (10, 10). + # So we need to dig into the chain to find the innermost buffer's layout. + d_size, d_stride, d_offset, d_dtype, collapsible = ( + codegen_reinterpret_view_helper(data) + ) + + def apply_reinterpret( + name, tgt_size, tgt_stride, tgt_offset, cast_dtype, base_dtype + ): + s = self.codegen_python_shape_tuple(tgt_size) + st = self.codegen_python_shape_tuple(tgt_stride) + off = self.codegen_sizevar(tgt_offset) + expr = f"reinterpret_tensor({name}, {s}, {st}, {off})" + if cast_dtype is not None and cast_dtype != base_dtype: + return f"aten.view.dtype({expr}, {cast_dtype})" + return expr + + name = data.get_name() + collapsed = collapsible and offset == d_offset + if collapsed: + same_layout = size == d_size and stride == d_stride + base_dtype = d_dtype + else: + same_layout = ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + ) + base_dtype = data.dtype + + if same_layout: + if dtype is not None and dtype != base_dtype: + return f"aten.view.dtype({name}, {dtype})" + return f"{name}" + + return apply_reinterpret(name, size, stride, offset, dtype, base_dtype) + + def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): + self.writeline(f"{dst}.copy_({src}, {non_blocking})") + + def codegen_multi_output(self, node: ir.MultiOutput): + result_name = node.get_name() + arg_name = node.input_name(0) + self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) + + def codegen_dynamic_select_index(self, node, clamp): + index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}" + if clamp: + index_str = f"max(0, min({node.size}, {index_str}))" + self.writeline( + f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})" + ) + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) + + def codegen_dynamic_slice_size(self, node): + def clamp_index(x): + pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size))) + neg = self.codegen_sizevar( + sympy.Max(0, sympy.Min(x + node.size, node.size)) + ) + x_cond = self.codegen_sizevar(x) + return f"{pos} if {x_cond} >= 0 else {neg}" + + def codegen_with_step(start_var, end_var, step): + if step == 1: + return f"{end_var} - {start_var}" + step_ = self.codegen_sizevar(step) + return f"({end_var} - {start_var} + {step_} - 1) // {step_}" + + # codegen start, end + sym = node.unbacked_size_symbol + start = clamp_index(node.start) + end = clamp_index(node.end) + self.writeline(f"{sym}_start = {start}") + self.writeline(f"{sym}_end = {end}") + with_step = codegen_with_step(f"{sym}_start", f"{sym}_end", node.step) + self.writeline(f"{sym} = max(0, {with_step})") + self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol)) + + def codegen_dynamic_scalar(self, node): + self.writeline(DynamicScalarLine(self, node)) + + def _codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + if len(node.keypath) == 0: + self.writeline(f"{node.sym} = {data}.item()") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"{node.sym} = 1 if {data}.item() else 0") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + self.writeline(f"{node.sym}_undivided = {data}.item()") + self.writeline( + f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, " + f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'" + ) + self.writeline( + f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + self.writeline(f"{node.get_name()} = None") + + def benchmark_compiled_module(self, output): + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def add_expr_input(name, val): + output.writeline(f"{name} = {val}") + + def add_torchbind_input(name, value): + if value is None: + output.writeline(f"{name} = None") + return + + import pickle + + assert isinstance(value, torch.ScriptObject) + + output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})") + + output.writelines( + ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + + for name, value in V.graph.constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_fake_input( + name, value.size(), value.stride(), value.device, value.dtype + ) + + if len(V.graph.torchbind_constants) > 0: + output.writeline("import pickle") + for name, torchbind_obj in V.graph.torchbind_constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_torchbind_input(name, torchbind_obj) + + for name, value in V.graph.graph_inputs.items(): + if isinstance(value, sympy.Symbol) and isinstance( + V.graph.sizevars.var_to_val.get(value, None), SingletonInt + ): + # Inductor should only work with dense -> dense graph, and + # SingletonInts belong to metadata that should only live on + # the subclass. + continue + if isinstance(value, ir.TorchBindObject): + if len(V.graph.torchbind_constants) == 0: + # otherwise we have already imported the pickle package + output.writeline("import pickle") + output.writeline(f"global {name}") + add_torchbind_input(name, value.get_real_obj()) + elif isinstance(value, sympy.Expr): # Don't need to add symbolic + # TODO: this fallback and those below actually will generate possibly + # invalid benchmark code, because it's not guaranteed 42 + # is actually a valid value for the kernel in question. + # See https://github.com/pytorch/pytorch/issues/124686 + add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + elif isinstance(value, ir.GeneratorState): + add_expr_input( + name, + f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()", + ) + else: + shape = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_size() + ] + stride = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_stride() + ] + add_fake_input( + name, + shape, + stride, + value.get_device(), + value.get_dtype(), + ) + + call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return print_performance(fn, times=times, repeat=repeat)") + + def add_benchmark_harness(self, output): + """ + Append a benchmark harness to generated code for debugging + """ + if not config.benchmark_harness: + return + + self.benchmark_compiled_module(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + output.writelines( + [ + "from torch._inductor.wrapper_benchmark import compiled_module_main", + f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)", + ] + ) + + def define_kernel( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + self.writeline( + KernelDefinitionLine( + self, + kernel_name, + kernel_body, + metadata=metadata, + gpu=gpu, + cpp_definition=cpp_definition, + ) + ) + + @staticmethod + def _format_kernel_definition( + kernel_name: str, kernel_body: str, metadata: Optional[str] = None + ): + if config.triton.autotune_at_compile_time and metadata: + # Generating autotune block + # Need to replace C++ comment starter with Python comment starter + metadata = re.sub(r"^// ", "# ", metadata, flags=re.MULTILINE) + metadata_comment = f"{metadata}\n" if metadata else "" + body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}" + return body + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + if config.triton.autotune_at_compile_time and gpu: + body = self._format_kernel_definition( + kernel_name, kernel_body, metadata=metadata + ) + self.kernel_autotune_defs.splice(body) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return + + body = self._format_kernel_definition( + kernel_name, kernel_body, metadata=metadata + ) + self.header.splice(body) + + def define_subgraph_launcher_fn(self, name: str, subgraph_code): + self.subgraph_definitions.splice(subgraph_code.value) + + def define_user_defined_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + grids: list[list[Union[int, sympy.Expr]]], + ): + from ..runtime.triton_heuristics import ( + config_to_dict, + FixedGrid, + PrecomputedGrid, + ) + from .common import ( + ConstexprArg, + KernelArgType, + SizeArg, + TensorArg, + TMADescriptorArg, + ) + from .triton import gen_common_triton_imports, TritonKernel + + original_name = kernel.__name__ + signature: list[KernelArgType] = [] + constants: dict[str, Any] = {} + arg_indices: list[int] = [] + equal_to_1_args: list[str] = [] + + def add_to_signature(idx, arg): + signature.append(arg) + arg_indices.append(idx) + + def add_arg(idx, arg, is_constexpr=False, equals_1=False, equals_none=False): + if is_constexpr: + if triton_version_uses_attrs_dict(): + # tl.constexpr args appear in the signature in new versions of triton, + # but not in old versions of triton. + add_to_signature(idx, arg) + + if arg.name in kwargs: + # the arg may not appear in kwargs if it is an autotuned arg. + # in this case, it will be added in triton_heuristics after autotuning. + constants[arg.name] = kwargs[arg.name] + + else: + # the only case where arg name isn't in kwargs, should be + # when the arg is a constexpr. + assert arg.name in kwargs + + if equals_1: + if triton_version_uses_attrs_dict(): + # new versions of triton: add the equal-to-1 arg in the signature (labeled as "constexpr"), + # and add the arg as a constant. + # new versions of triton: add the equal-to-1 arg in the signature (labeled as, e.g., "i32"), + # and add the arg as a constant. + add_to_signature(idx, ConstexprArg(name=arg.name)) + else: + add_to_signature(idx, arg) + constants[arg.name] = 1 + elif equals_none: + if triton_version_uses_attrs_dict(): + # new versions of triton: add the none arg in the signature (as a constexpr arg) and as a constant + # old versions of triton: include the none arg as a constant (but not in the signature) + add_to_signature(idx, ConstexprArg(name=arg.name)) + constants[arg.name] = None + else: + add_to_signature(idx, arg) + + arg_names = [p.name for p in kernel.params] + constexprs = [p.num for p in kernel.params if p.is_constexpr] + for idx, key in enumerate(arg_names): + if idx in constexprs: + add_arg(idx, ConstexprArg(name=key), is_constexpr=True) + continue + + if key not in kwargs: + continue + + arg = kwargs[key] + + if kwargs[key] is None: + add_arg(idx, ConstexprArg(name=key), equals_none=True) + else: + if isinstance(arg, ir.TMADescriptor): + api_type, block_shape, dtype = ( + ("stable", arg.block_shape, arg.tensor.get_dtype()) + if isinstance(arg, ir.TMADescriptorStable) + else ("experimental", None, None) + ) + add_arg( + idx, + TMADescriptorArg( + name=key, + api_type=api_type, + block_shape=block_shape, + dtype=dtype, + ), + ) + elif isinstance(arg, ir.Buffer): + add_arg( + idx, + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), + ), + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + add_arg( + idx, + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ), + ) + else: + equals_1 = isinstance( + arg, (int, sympy.Integer) + ) and V.graph.sizevars.statically_known_equals( + arg, + 1, # type: ignore[arg-type] + ) + add_arg(idx, SizeArg(key, arg), equals_1=equals_1) + + triton_signature = signature_to_meta( + signature, + size_dtype=None, # try to infer based on symints + indices=arg_indices, + argdefs=[ArgName(x) for x in kernel.arg_names], + ) + triton_meta: dict[str, Any] = { + "signature": triton_signature, + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # TODO(aakhundov): add None args to constants, too. currently, this + # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. + # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + "constants": { + **constants, + **dict.fromkeys(equal_to_1_args, 1), + }, + "configs": [ + config_of( + signature, + indices=arg_indices, + ) + ], + } + + if restore_value_args: + triton_meta["restore_value"] = tuple(restore_value_args) + + if reset_to_zero_args: + triton_meta["reset_to_zero"] = tuple(reset_to_zero_args) + + if len(grids) == 1: + # compute the grid in the wrapper and pass it in as an arg + inductor_meta: dict[str, Any] = FixedGrid.setup_grid_as_args() + extra_launcher_call_args = [*map(sympy.sympify, grids[0])] + else: + + def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: + if isinstance(expr, sympy.Expr): + symbols = [*expr.free_symbols] + if not symbols: + return expr + symbols.sort(key=str) + for sym in symbols: + if sym in extra_launcher_args: + continue + extra_launcher_args[sym] = sympy.Symbol( + f"_launcher_s{len(extra_launcher_args)}" + ) + return sympy_subs(expr, extra_launcher_args) + assert isinstance(expr, int) + return sympy.Integer(expr) + + extra_launcher_args: dict[sympy.Symbol, sympy.Symbol] = {} + grids = [[*map(rename_sizes_for_launcher, grid)] for grid in grids] + + assert grids and len(grids) == len(configs) + precomputed_grids = [] + for grid, cfg in sorted( + zip(grids, configs), key=lambda x: len(x[1].kwargs), reverse=True + ): + precomputed_grids.append( + { + "config": config_to_dict(cfg), + "python": [*map(pexpr, grid)], + "cpp": [*map(cexpr, grid)], + "python_slow": [*map(pexpr, grid)], + } + ) + inductor_meta = { + "grid_type": PrecomputedGrid.__name__, + "precomputed_grids": precomputed_grids, + "extra_launcher_args": [*map(str, extra_launcher_args.values())], + } + extra_launcher_call_args = [*extra_launcher_args.keys()] + + # Distinguish between different functions using function id + cache_key: Any = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key.extend(str(inductor_meta)) + cache_key = tuple(cache_key) + if cache_key in self.user_defined_kernel_cache: + return ( + *self.user_defined_kernel_cache[cache_key], + extra_launcher_call_args, + ) + + name = f"{original_name}_{len(self.user_defined_kernel_cache)}" + + compile_wrapper = IndentedBuffer() + if config.triton.unique_user_kernel_names: + compile_wrapper.writeline(f"async_compile.triton({name!r}, '''") + else: + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + inductor_meta["kernel_name"] = name + inductor_meta.update(TritonKernel.inductor_meta_common()) + + compile_wrapper.splice(gen_common_triton_imports()) + compile_wrapper.splice( + f""" + @triton_heuristics.user_autotune( + configs={[*map(config_to_dict, configs)]!r}, + inductor_meta={inductor_meta!r}, + triton_meta={triton_meta!r}, + filename=__file__, + custom_kernel=True, + ) + @triton.jit + """ + ) + kernel_src = user_defined_triton_kernel_transitive_closure_source_code(kernel) + if config.triton.unique_user_kernel_names: + # We replace the original_name with the unique name. + kernel_src = kernel_src.replace(f"def {original_name}(", f"def {name}(") + kernel_src = kernel_src.replace("'''", "\\'\\'\\'") + compile_wrapper.splice(kernel_src) + + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + self.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + # Add to the cache for the next use + self.user_defined_kernel_cache[cache_key] = (name, triton_meta) + return name, triton_meta, extra_launcher_call_args + + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): + sym_name = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + sym_name += f"_{suffix}" + sym = sympy.Symbol(sym_name, is_integer=True, is_positive=True) + + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. + arg = SymbolicCallArg(sym, tree.numel) + + is_benchmark_kernel = kernel_name == "" + if not is_benchmark_kernel: + self.writeline(SymbolicCallArgLine(self, arg, V.graph)) + + return arg + + def _generate_symbolic_call_arg_helper( + self, arg: SymbolicCallArg, graph: GraphLowering + ) -> None: + self.writeline(f"{arg.inner} = {pexpr(arg.inner_expr)}") + + def generate_workspace_allocation(self, ws: WorkspaceArg): + name = ws.get_name() + line = AllocateLine(self, ws) + if ws.zero_mode == WorkspaceZeroMode.UNINITIALIZED: + self.writeline(line) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_PER_GRAPH: + prior = self.allocated_workspaces.get(name) + if prior: + assert isinstance(prior, AllocateLine) and isinstance( + prior.node, WorkspaceArg + ) + # expand existing allocation + prior.node = WorkspaceArg.maximum(prior.node, ws) + else: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + self.allocated_workspaces[name] = line + else: + raise AssertionError(ws.zero_mode) + + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + PythonWrapperCodegen.make_allocation( + self, + name, + ws.device, + ws.dtype, + shape=(V.graph.sizevars.size_hint(ws.count),), + stride=(1,), + ) + ) + if ws.zero_mode != WorkspaceZeroMode.UNINITIALIZED: + self.kernel_autotune_calls.writeline( + PythonWrapperCodegen.make_zero_buffer(self, name) + ) + + def generate_workspace_deallocation(self, ws: WorkspaceArg): + if ws.zero_mode != WorkspaceZeroMode.ZERO_PER_GRAPH: + self.writeline(FreeIfNotReusedLine(self, ws)) + + def make_zero_buffer(self, name): + return f"{name}.zero_(){self.ending}" + + def wrap_kernel_call(self, name, call_args): + return f"{name}({', '.join(call_args)}){self.ending}" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline("from torch.profiler import record_function") + self.wrapper_call.writeline( + f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):" + ) + stack.enter_context(self.wrapper_call.indent()) + + def generate_start_graph(self): + self.wrapper_call.writeline("start_graph()") + + def generate_end_graph(self): + self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})") + + def generate_reset_kernel_saved_flags(self): + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + kernel.cuda_kernel_saved = False + """ + ) + + def generate_save_uncompiled_kernels(self): + """ + Precompile and save the CUBINs of the Triton kernels that haven't + been precompiled and saved as a side effect of running the generated + JIT model (Python wrapper). This can happen when the model contains + control flow: only one pass through the control flow operators covers + the kernels that are saved, the remaining kernels are not launched, + hence not saved. The main purpose of this codegen is to compile and + save the Triton kernels outside the active control flow path for + subsequent AOTInductor code generation and compilation. + """ + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + if not kernel.cuda_kernel_saved: + if len(kernel.launchers) == 0: + kernel.precompile() + kernel.save_gpu_kernel( + stream="stream", # use dummy stream + launcher=kernel.launchers[0], + ) + """ + ) + + def prepare_triton_kernel_call(self, call_args): + def wrap_arg(arg): + if isinstance(arg, str): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg + elif isinstance(arg, (int, float, bool, SymbolicCallArg)): + return str(arg) + else: + return pexpr(V.graph.sizevars.simplify(arg)) + + return [wrap_arg(arg) for arg in call_args] + + def generate_example_arg_value(self, arg, arg_type, raw_arg=None): + if isinstance(arg_type, torch_dtype): + if isinstance(raw_arg, ir.TMADescriptor): + # first we generate the underlying buffer + buf_name = raw_arg.get_tensor().get_name() + buf = self.args_to_buffers[arg] + elif self.args_to_buffers.get(arg): + buf_name = arg + buf = self.args_to_buffers[arg] + else: + assert raw_arg is not None, ( + "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" + ) + buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}" + buf = raw_arg + self.kernel_autotune_tmp_arg_idx += 1 + + assert buf is not None, f"Failed to find a buffer for arg {arg}" + size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_size() + ) + allocation_size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in V.graph.get_allocation_size(buf) + ) + stride = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_stride() + ) + device = buf.get_device() + dtype = buf.get_dtype() + offset = V.graph.sizevars.size_hint( + buf.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ) + value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset}, {allocation_size})" + self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + + if isinstance(raw_arg, ir.TMADescriptor): + # generate another line initializing a host-side TMA + # descriptor from the underlying buffer created above + value = self._generate_tma_descriptor_call( + desc=raw_arg, + apply_size_hints=True, + ) + buf_name = arg + self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + + return buf_name + elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg): + # arg is a symbol or symbolic expression + if isinstance(arg, str): + if arg in self._meta_vars: + return arg + if raw_arg is None: + return "None" + arg = raw_arg + if isinstance(arg, SymbolicCallArg): + arg = arg.inner_expr + if arg in V.graph.sizevars.inv_precomputed_replacements: + arg = V.graph.sizevars.inv_precomputed_replacements[arg] + + return str( + V.graph.sizevars.atomically_apply_size_hint( + arg, fallback=config.unbacked_symint_fallback + ) + ) + + elif isinstance(arg, (str, int, float, bool)): + return str(arg) + elif isinstance(arg, list): + return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" + else: + raise NotImplementedError(f"Unsupported type {type(arg)}") + + def _grid_dim_str(self, grid_per_dim): + if isinstance(grid_per_dim, list): + return ( + "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" + ) + else: + return pexpr(grid_per_dim) + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True, + and C++ when gpu=False. + """ + + # Store buffers corresponding to each call arg. + # This is used to generate example args for autotuning later on. + self.args_to_buffers.update( + { + arg: V.graph.try_get_buffer(arg) + for arg in call_args + if isinstance(arg, str) + } + ) + + device = device or V.graph.get_current_device_or_throw() + self.writeline( + KernelCallLine( + self, + kernel_name=kernel_name, + call_args=call_args, + # pyrefly: ignore [bad-argument-type] + raw_keys=raw_keys, + # pyrefly: ignore [bad-argument-type] + raw_args=raw_args, + # pyrefly: ignore [bad-argument-type] + arg_types=arg_types, + triton=triton, + # pyrefly: ignore [bad-argument-type] + triton_meta=triton_meta, + device=device, + graph_name=V.graph.name, + # pyrefly: ignore [bad-argument-type] + original_fxnode_name=original_fxnode_name, + ) + ) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + if not triton and device.type != "cuda": + if device.type == "cpu": + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + elif device.type == "mps": + # TODO: Fix me, MPS does not expose streams now + self.writeline( + self.wrap_kernel_call(f"{kernel_name}.generated_kernel", call_args) + ) + else: + raise RuntimeError(f"device {device.type} nyi") + return + + call_args_str = self.prepare_triton_kernel_call(call_args) + call_args_str = ", ".join(call_args_str) + stream_name = PythonWrapperCodegen.write_get_raw_stream( + self, device.index, graph_name + ) + if not triton: + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" + ) + return + + self.write_triton_header_once() + + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Create example args for autotune in a separate epilogue + assert arg_types is not None and len(call_args) == len(arg_types), ( + "call_args and arg_types do not match" + ) + + autotune_args = None + if original_fxnode_name and V.graph.autotuning_mapping: + autotune_args = V.graph.autotuning_mapping.get( + original_fxnode_name, None + ) + + def get_autotune_deletion_call() -> str: + """After all the autotune kernel calls have been written (i.e. + self.kernel_autotune_example_args is complete), returns a deletion call + for all autotune example tensors that are unnecessary after kernel_name + is called.""" + tensors_to_delete = [ + tensor + for tensor, kn in self.kernel_autotune_example_args.values() + if kn == kernel_name + ] + if tensors_to_delete: + return f"del {', '.join(tensors_to_delete)}\n" + return "" + + def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): + """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args. + This is particularly useful for jagged cases, where the dimension is often + being passed in as an input.""" + + target_arg = raw_args[idx] + if target_arg in reused_args: + return True + + for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)): + if i == idx or not isinstance(raw_arg, IRNode): + continue + + triton_input = "" + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + if triton_input == "": + continue + + try: + layout = raw_arg.get_layout() + for dim, s in enumerate(layout.size): + if s == target_arg: + reused_args[target_arg] = f"{triton_input}.shape[{dim}]" + return True + except NotImplementedError: + # If layout for this IRNode is not implemented, we could just skip. + # Only raise for other Error cases. + continue + return False + + all_args = [] + if raw_args is None: + # create a dummy raw_args for uniform behavior in the following loop + assert raw_keys is None, "keys are not None but args are" + raw_keys = [None] * len(call_args) + raw_args = [None] * len(call_args) + else: + assert len(raw_args) == len(call_args), ( + "call_args and raw_args do not match" + ) + + reused_args = {} + for i, (arg, arg_type, raw_key, raw_arg) in enumerate( + # pyrefly: ignore [no-matching-overload] + zip(call_args, arg_types, raw_keys, raw_args) + ): + key = None + if isinstance(arg, str) and "=" in str(arg): + # arg may be passed in a kwarg style, and then we need to extract its value + key, arg = arg.split("=") + + triton_input: Optional[str] = None + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + + if triton_input: + arg_str = triton_input + if not isinstance(arg_type, torch_dtype) and ( + issubclass(arg_type, sympy.Basic) + or isinstance(arg, SymbolicCallArg) + ): + reused_args[raw_arg] = arg_str + elif raw_key == "" and infer_arg_by_inputs( + raw_keys, raw_args, i, reused_args + ): + # Empty raw_key means this is a arg that's not native to the triton kernel, + # and is being added by inductor. + arg_str = reused_args[raw_arg] + elif isinstance(arg_type, torch_dtype): + # workspace allocation is already generated by `generate_workspace_allocation()` + # in `TritonKernel.call_kernel()`. + if re.match(r"^(workspace|semaphore)", arg): + arg_str = arg + elif arg not in self.kernel_autotune_example_args: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg + ) + else: + arg_str = self.kernel_autotune_example_args[arg][0] + self.kernel_autotune_example_args[arg] = (arg_str, kernel_name) + else: + arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) + all_args.append(arg_str if key is None else f"{key}={arg_str}") + + # Make sure kernel launch under a device guard because models don't always run on device 0 + self.kernel_autotune_calls.writeline( + f"with {V.graph.device_ops.device_guard(device.index)}:" + ) + self.kernel_autotune_calls.do_indent() + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})" + ) + self.kernel_autotune_calls.do_unindent() + + self.kernel_autotune_calls.writeline( + DelayReplaceLine("", get_autotune_deletion_call, "") + ) + self.kernel_autotune_names.add(kernel_name) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return + + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})") + self.write_triton_header_once() + + def writeline(self, line): + self.lines.append(line) + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def enter_context(self, ctx): + self.lines.append(LineContext(ctx)) + + def val_to_arg_str(self, s, type_=None): + from torch.utils._triton import has_triton_package + + if has_triton_package(): + import triton + + if isinstance(s, SymTypes): + return pexpr(s.node.expr) + elif isinstance(s, sympy.Expr): + return pexpr(s) + elif isinstance(s, (tuple, list)): + + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + # Explicitly call the Python version of val_to_arg_str + return repr( + type(s)(Shim(PythonWrapperCodegen.val_to_arg_str(self, a)) for a in s) + ) + elif isinstance(s, torch._ops.OpOverload): + return _get_qualified_name(s) + elif isinstance(s, (ir.Buffer, ir.MutableBox, ReinterpretView)): + return s.codegen_reference() + elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] + return repr(s) + elif isinstance(s, ir.GeneratorState): + return s.codegen_reference() + elif is_opaque_value_type(type(s)): + opaque_type = type(s) + V.graph.opaque_value_type_classes[opaque_type.__name__] = opaque_type + return repr(s) + else: + return repr(s) + + # The following methods are for memory management + def make_buffer_allocation(self, buffer: BufferLike): + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = tuple(buffer.get_size()) + allocation_shape = tuple(V.graph.get_allocation_size(buffer)) + stride = tuple(buffer.get_stride()) + is_pinned = buffer.get_is_pinned() + return self.make_allocation( + buffer.get_name(), device, dtype, shape, stride, allocation_shape, is_pinned + ) + + @cache_on_self + def write_memory_track_allocation_once(self): + import_str = """ + from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor + """ + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False + ): + if allocation_shape is None: + allocation_shape = shape + + codegen_shape_tuple = self.codegen_python_shape_tuple(shape) + codegen_allocation_shape_tuple = self.codegen_python_shape_tuple( + allocation_shape + ) + codegen_stride_tuple = self.codegen_python_shape_tuple(stride) + if torch._inductor.config.test_configs.track_memory_lifecycle: + out = ( + f"{name} = tracked_empty_strided(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"dtype={dtype}, " + f"device='{device.type}', " + f"name='{name}')" + ) + elif device.type == "cpu" and is_pinned: + out = ( + f"{name} = empty_strided_cpu_pinned(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"{dtype})" + ) + elif device.type in ("cpu", "cuda", "xpu", "mtia"): + # optimized path for faster allocations, saving ~2us versus the stuff below + out = ( + f"{name} = empty_strided_{device.type}(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"{dtype})" + ) + # all other devices: + else: + out = ( + f"{name} = empty_strided(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"device='{device.type}', dtype={dtype})" + ) + if codegen_shape_tuple != codegen_allocation_shape_tuple: + # need an extra as_strided call + out = out + f".as_strided({codegen_shape_tuple}, {codegen_stride_tuple})" + return out + + def make_comment(self, line): + self.writeline(CommentLine(line)) + + def make_tensor_alias(self, new_name, old_name, comment=""): + return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}" + + def make_buffer_free(self, buffer: Union[BufferLike, ir.TorchBindObject]): + return f"del {buffer.get_name()}" + + def make_free_by_names(self, names_to_del: list[str]): + return f"del {', '.join(name for name in names_to_del)}" + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" + + def write_provenance_debug_handle( + self, + kernel_name, + debug_handle: Optional[int] = None, + ): + if debug_handle is not None: + self.writeline( + f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}" + ) + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline + ) + return f"{self.declare}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse" + + def codegen_deferred_allocation(self, name: str, view: ir.ReinterpretView) -> None: + self.writeline( + DeferredLine( + name, + f"{self.declare}{name} = {view.codegen_reference()}{self.ending} {self.comment} alias", + ) + ) + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + + if ( + name in V.graph.removed_buffers + or name in self.allocated + or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer, ir.InputBuffer)) + ): + return + self.allocated.add(name) + if ( + isinstance( + buffer.get_defining_op(), + (ir.ExternKernelAlloc, ir.MultiOutput), + ) + and not buffer.should_allocate() + ): + return + + layout = buffer.get_output_spec() + if isinstance(layout, ir.MutationLayoutSHOULDREMOVE): + return + if isinstance(layout, ir.NoneLayout): + return + if isinstance(layout, ir.NonOwningLayout): + assert isinstance(layout.view, ir.ReinterpretView), ( + f"unexpected {type(layout.view)}: {layout.view}" + ) + box = layout.view.data + assert isinstance(box, ir.StorageBox), type(box) + input_buffer = box.data + assert isinstance(input_buffer, (ir.Buffer, ir.ReinterpretView)), type( + input_buffer + ) + if isinstance(input_buffer, ir.ReinterpretView): + + def unwrap_views(target) -> ir.Buffer: + if isinstance(target, ir.BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, ir.MutableBox): + return unwrap_views(target.data) + assert isinstance(target, ir.Buffer), type(target) + return target + + input_buffer = unwrap_views(input_buffer) + self.codegen_allocation(input_buffer) + self.writeline(ReinterpretLine(self, input_buffer, buffer, layout)) + return + + if isinstance(layout, ir.CommBufferLayout): + self.writeline(CommBufferAllocateLine(self, buffer)) + return + + self.writeline(AllocateLine(self, buffer)) + + def codegen_free(self, buffer): + name = buffer.get_name() + + # can be freed but not reused + if isinstance(buffer, (ir.InputBuffer, ir.TorchBindObject)): + self.writeline(FreeLine(self, buffer)) + return + + if isinstance(buffer.get_output_spec(), ir.CommBufferLayout): + # Comm buffers are not eligible for in-place reuse. Their reuse is + # achieved exclusively via buffer planning. + self.writeline(CommBufferFreeLine(self, buffer)) + return + + if not self.can_reuse(buffer): + return + self.freed.add(name) + + self.writeline(FreeIfNotReusedLine(self, buffer)) + + def can_reuse(self, input_buffer, output_buffer=None): + name = input_buffer.get_name() + return not ( + name in V.graph.removed_buffers + or ( + name in V.graph.graph_inputs + and not isinstance( + V.graph.graph_inputs_original[name], ir.DonatedBuffer + ) + ) + or name in V.graph.constants + or name in V.graph.torchbind_constants + or name in V.graph.never_reuse_buffers + or name in self.freed + ) + + def did_reuse(self, buffer, reused_buffer): + # Check whether a given buffer was reused by a possible reuser in the wrapper codegen + # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed + return ( + buffer.get_name() in self.reuses + and self.reuses[buffer.get_name()] == reused_buffer.get_name() + ) + + def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer): + assert can_match_buffer_size(input_buffer, output_buffer) + self.codegen_allocation(input_buffer) + self.freed.add(input_buffer.get_name()) + self.allocated.add(output_buffer.get_name()) + self.reuses[output_buffer.get_name()] = input_buffer.get_name() + self.writeline(ReuseLine(self, input_buffer, output_buffer)) + + def codegen_unbacked_symbol_decl(self, symbol): + name = str(symbol) + if name in self.unbacked_symbol_decls: + return name + else: + # When in CppWrapperCpu, we should only generate the declaration once + self.unbacked_symbol_decls.add(name) + return self.declare + name + + def codegen_unbacked_symbol_defs_for_outputs( + self, + output_name: str, + outputs: Any, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + ) -> None: + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + self.writeline( + UnbackedSymbolDefsLine(self, output_name, outputs, unbacked_bindings) + ) + + def _codegen_unbacked_symbol_defs_for_outputs( + self, + output_name: str, + outputs: Any, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + ) -> None: + if not unbacked_bindings: + return + + # This code is designed to generate code expressions from symbolic paths (keypaths) + # associated with certain symbols (unbacked bindings). These keypaths describe how + # to access the unbacked symbol in a structured way. + # For example, we might want to generate "u0 = outs[0].stride(1)"", where s = u0, and the keypath + # describes the structure of "outs[0].stride(1)", like [SequenceKey(0), CallMethodKey("stride"), SequenceKey[1]]. + for s, keypath in unbacked_bindings.items(): + # `go` recursively constructs a code expression by processing each element of + # the keypath and construct the expression incrementally. + # For example, given output name outs and keypath [SequenceKey(0), CallMethodKey("stride", 1)], + # it generates "outs[0]" based on SequenceKey(0), then recursively go("outs[0]", [CallMethodKey("stride"), ...]) + def go(expr: str, keypath: pytree.KeyPath): + if keypath == (): + return expr + + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + return go( + f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:] + ) + elif isinstance(keypath[0], CallMethodKey): + return go(f"{expr}.{keypath[0].name}()", keypath[1:]) + elif isinstance(keypath[0], pytree.SequenceKey): + return ( + go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:]) + if V.graph.cpp_wrapper + else go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + # TODO: this is invalid C++ codegen + return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:]) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + # `go_outer` manages the top-level logic for generating the final expression. + # It handles special cases for C++ code generation and adjusts + # the keypath based on the context (e.g., single vs. multiple outputs). + def go_outer(): # type: ignore[no-untyped-def] + if V.graph.cpp_wrapper: + # Special handling for the top level buffer access, + # because self.get_name() is actually never bound; the + # individual output arguments are bound by + # generate_c_shim_fallback_kernel + if len(outputs) == 1: + out = outputs[0] + # When fallback kernel returns a list consisting of a single tensor, + # the output is represented as a MultiOutput with non empty indices. + # In this case, we strip the first key path away. + return go( + outputs[0].get_name(), + keypath[1:] + if isinstance(out, ir.MultiOutput) and len(out.indices) != 0 + else keypath, + ) + else: + assert isinstance(keypath[0], pytree.SequenceKey) + return go(outputs[keypath[0].idx].get_name(), keypath[1:]) + else: + return go(output_name, keypath) + + self.writeline( + f"{self.codegen_unbacked_symbol_decl(s)} = {go_outer()}{self.ending}" + ) + + def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # `codegen_subgraph` function. + # + # However this does not work with cpp wrapper. With cpp wrapper, we make + # two passes and the kernels are shared from the first pass to the next. + # Therefore, both the Python and CppWrapper need to share the some + # codegen infra. For now, CppWrapperCpu has not been updated to lift the + # subgraph as functions. Therefore for cpp_wrapper first pass with + # PythonWrapper, we still fallback to the old way of inlining subgraphs + # in the output code. Once we update CppWrapperCpu, we can remove this + # function. + def _codegen_subgraph_prefix(): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + for inner_input, outer_input in zip( + subgraph.graph.graph_inputs, outer_inputs + ): + self.writeline( + f"{self.declare}{inner_input} = {outer_input}{self.ending}" + ) + + def _codegen_subgraph_suffix(): + assert len(subgraph.graph.graph_outputs) == len(outer_outputs) + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) + + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + _codegen_subgraph_prefix() + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + _codegen_subgraph_suffix() + finally: + self.pop_codegened_graph() + + def codegen_partition_call( + self, + partition_id: int, + partition_signatures: ir.GraphPartitionSignature, + ): + """Generate code to call a graph partition""" + input_deallocation = partition_signatures.input_deallocation + output_nodes = partition_signatures.output_nodes + + input_names = list(input_deallocation.keys()) + [ + symbol_input.name for symbol_input in partition_signatures.symbol_inputs + ] + + inputs = ", ".join(input_names) + ("," if len(input_names) == 1 else "") + + output_names = [node.get_name() for node in output_nodes] + outputs = ", ".join(output_names) + ("," if len(output_nodes) == 1 else "") + + # Create a list of inputs for the subgraph call + self.writeline(f"partition{partition_id}_args = [{inputs}]") + + names_to_del = [ + name for name, deallocate in input_deallocation.items() if deallocate + ] + if names_to_del: + self.writeline(f"del {', '.join(names_to_del)}") + + # Call the subgraph launcher function + self.writeline( + f"({outputs}) = self.partitions[{partition_id}](partition{partition_id}_args)" + ) + self.writeline(f"del partition{partition_id}_args") + + def set_all_partition_names(self, num_partitions: int): + self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)] + + def codegen_subgraph_call_with_flattened_outputs( + self, subgraph, outer_inputs, outer_flattened_outputs + ): + # Get the input and output names of the subgraph + outer_output_names = ", ".join(outer_flattened_outputs) + ( + "," if len(outer_flattened_outputs) == 1 else "" + ) + outer_input_names = ", ".join(outer_inputs) + ( + "," if len(outer_inputs) == 1 else "" + ) + + self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]") + + # Call the subgraph launcher function + self.writeline( + f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name): + # Get the input and output names of the subgraph + outer_input_names = ", ".join(outer_inputs) + ( + "," if len(outer_inputs) == 1 else "" + ) + + self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]") + + # Since the buffers are already put into the args list, we can free the + # buffers here. + V.graph.scheduler.free_buffers() + + # Call the subgraph launcher function + self.writeline( + f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph_common(self, subgraph): + self.push_codegened_graph(subgraph.graph) + self.make_comment("") + self.make_comment(f"{self.comment} subgraph: {subgraph.name}") + + parent_graph = V.graph + subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper + subgraph.graph.fx_wrapper = parent_graph.fx_wrapper + + if subgraph.graph.name not in self.already_codegened_subgraphs: + # If it is already codegened, the parent wrapper already has + # subgraph fn by name subgraph.graph.name + with V.set_graph_handler(subgraph.graph): + # do not graph partition for subgraph + with config.patch("graph_partition", False): + # Call the codegen of subgraph recursively + subgraph_code, _ = subgraph.graph.codegen() + subgraph_name = subgraph.graph.name + self.already_codegened_subgraphs.add(subgraph_name) + self.define_subgraph_launcher_fn(subgraph_name, subgraph_code) + + def codegen_subgraph_with_flattened_outputs( + self, subgraph, outer_inputs, outer_flattened_outputs + ): + self.codegen_subgraph_common(subgraph) + self.codegen_subgraph_call_with_flattened_outputs( + subgraph, outer_inputs, outer_flattened_outputs + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name): + # Codegen subgraph by recursively calling the codegen for the subgraph. + # This lifts the subgraph as a function in the output code. + self.codegen_subgraph_common(subgraph) + self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name) + + def codegen_invoke_subgraph(self, invoke_subgraph): + name = invoke_subgraph.get_name() + + self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}") + outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs] + + if V.graph.aot_mode: + outer_outputs = [ + f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs)) + ] + self.codegen_subgraph_by_inlining( + invoke_subgraph.subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name) + + def codegen_conditional(self, conditional) -> None: + name = conditional.get_name() + + outer_inputs = [buf.codegen_reference() for buf in conditional.operands] + + predicate = conditional.predicate.codegen_reference() + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # move the Tensor predicate to host + predicate = f"{predicate}.item()" + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") + self.writeline(f"if {predicate}:") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + if V.graph.aot_mode: + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.codegen_subgraph_by_inlining( + conditional.true_subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name) + + self.writeline(ExitSubgraphLine(self)) + self.writeline("else:") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + if V.graph.aot_mode: + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.codegen_subgraph_by_inlining( + conditional.false_subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name) + self.writeline(ExitSubgraphLine(self)) + + def codegen_while_loop(self, while_loop, stack_output): + """while_loop is codegened as a host side while_loop""" + + def codegen_subgraph(subgraph, outer_inputs, outer_outputs): + """Helper method to deduplicate subgraph codegen logic""" + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining(subgraph, outer_inputs, outer_outputs) + else: + self.codegen_subgraph_with_flattened_outputs( + subgraph, outer_inputs, outer_outputs + ) + + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + + ckp_offset = len(outer_carried_inputs) + self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}") + if stack_output: + self.writeline( + f"{name}.extend([[] for _ in range({len(outer_carried_inputs)})])" + ) + + for i, inp in enumerate(outer_carried_inputs): + # set the initial state before the loop + self.writeline(f"{name}[{i}] = {inp}") + + cond_outer_inputs = [ + *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))], + *outer_additional_inputs, + ] + cond_outer_outputs = [f"{name}_cond_result"] + body_outer_inputs = list( + cond_outer_inputs + ) # same inputs for cond_fn and body_fn + # Carry over the state from body_fn. Note: We only carry over + # the carried_inputs part of the inputs, the additional ones + # are passed in as they're before. + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + # Check condition at the beginning and set up flag + codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + self.writeline(f"should_loop = {cond_outer_outputs[0]}") + self.writeline("if not should_loop:") + if stack_output: + # Handle the case when loop never executes + for i, carried_input in enumerate(outer_carried_inputs): + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()") + self.writeline(ExitSubgraphLine(self)) + else: + for i, carried_input in enumerate(outer_carried_inputs): + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.writeline(f"{name}[{i}] = {carried_input}.clone()") + self.writeline(ExitSubgraphLine(self)) + + self.writeline("while should_loop:") + # Body execution + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + codegen_subgraph( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + + # Collect outputs if enabled + if stack_output: + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + for i in range(len(outer_carried_inputs)): + self.writeline(f"{name}[{i + ckp_offset}].append({name}[{i}])") + self.writeline(ExitSubgraphLine(self)) + + # Condition check at end of loop + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + self.writeline(f" should_loop = {cond_outer_outputs[0]}") + + # Stack outputs after loop completion + if stack_output: + self.writeline("# Stack outputs after loop completion") + for i in range(len(outer_carried_inputs)): + self.writeline(f"if len({name}[{i + ckp_offset}]) > 0:") + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.writeline( + f"{name}[{i}] = torch.stack({name}[{i + ckp_offset}], dim=0)" + ) + self.writeline(ExitSubgraphLine(self)) + + @staticmethod + def statically_known_int_or_none(x): + try: + if getattr(x, "free_symbols", None): + # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but + # the actual codegen will still generate the full expression here. + return None + if isinstance(x, int): + return x + val = V.graph._shape_env._maybe_evaluate_static(x) + if val is None: + return val + return int(val) # type: ignore[call-overload] + except Exception: + return None + + @staticmethod + def statically_known_list_of_ints_or_none(lst): + result = [] + for x in lst: + num = PythonWrapperCodegen.statically_known_int_or_none(x) + if num is None: + return None + result.append(num) + return result + + @staticmethod + def is_statically_known_list_of_ints(lst): + return ( + PythonWrapperCodegen.statically_known_list_of_ints_or_none(lst) is not None + ) + + @staticmethod + def static_shape_for_buffer_or_none(buffer): + return PythonWrapperCodegen.statically_known_list_of_ints_or_none( + buffer.get_size() + ) + + @staticmethod + def can_prove_buffer_has_static_shape(buffer): + return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None + + def write_kernel_context_guard( + self, + kernel_name: str, + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ): + return + + def write_kernel_context_guard_begin( + self, + ): + """ + Mark the beginning of kernel context guard + """ + return + + def write_kernel_context_guard_end( + self, + ): + """ + Mark the end of kernel context guard + """ + return + + +class SubgraphPythonWrapperCodegen(PythonWrapperCodegen): + """ + A wrapper codegen that generates code for a subgraph. For most of the + methods, we rely on the implementation in the PythonWrapperCodegen. But we + override a few functions to produce cleaner code (like avoiding writing + imports twice in the output code) + """ + + def __init__( + self, + subgraph_name: str, + parent_wrapper: PythonWrapperCodegen, + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # It is necessary to set the subgraph_name before calling super __init__ + # because __init__ calls set_launcher_fn_name + self.subgraph_name = subgraph_name + self.parent_wrapper = parent_wrapper + self.partition_signatures = partition_signatures + + super().__init__() + + root = self.get_root_graph() + # Only generate auto-tuning block in the main graph + self.kernel_autotune_defs = root.kernel_autotune_defs + self.kernel_autotune_calls = root.kernel_autotune_calls + # Only store kernel src to name mapping in the main graph + self.src_to_kernel = root.src_to_kernel + # Same here, only define user-defined Triton kernels in the main graph + self.user_defined_kernel_cache = root.user_defined_kernel_cache + + def set_launcher_fn_name(self) -> None: + # This sets up the name of the function containing the launcher code of + # the subgraph. + # pyrefly: ignore [bad-assignment] + self.launcher_fn_name = self.subgraph_name + + def write_header(self) -> None: + pass + + def add_benchmark_harness(self, output): + pass + + def benchmark_compiled_module(self, output): + pass + + def write_async_compile_wait(self): + pass + + def next_kernel_suffix(self) -> str: + # Ensures that subgraphs kernels do not clash with each other + return self.parent_wrapper.next_kernel_suffix() + + def generate_after_suffix(self, result: IndentedBuffer) -> None: + return + + def write_launcher_fn_call_get_indent(self) -> int: + self.prefix.splice( + f""" + def {self.launcher_fn_name}(args): + """ + ) + prefix_indent = 1 + return prefix_indent + + def get_wrapper_call_indent(self) -> int: + return 1 + + def get_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]: + if signature := self.partition_signatures: + inputs = signature.input_nodes | { + str(s): s for s in signature.symbol_inputs + } + else: + inputs = V.graph.graph_inputs + return inputs + + def get_graph_input_names(self) -> list[str]: + if signature := self.partition_signatures: + names = list(signature.input_nodes.keys()) + [ + symbol_input.name for symbol_input in signature.symbol_inputs + ] + else: + names = V.graph.graph_input_names + return names + + def get_graph_outputs(self) -> list[IRNode]: + if signature := self.partition_signatures: + outputs = signature.output_nodes + else: + outputs = V.graph.graph_outputs + return outputs + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + if (signature := self.partition_signatures) and name in signature.input_nodes: + # skip allocation if buffer is a subgraph input. + # This allows reusing an input buffer in graph partition, + # although this is not allowed in general. + return + + super().codegen_allocation(buffer) + + @cache_on_self + def write_triton_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. + # if config.triton.autotune_at_compile_time: + # import_str = self.triton_header_str() + # self.kernel_autotune_calls.splice(import_str) + self.parent_wrapper.write_triton_header_once() + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. + # if config.triton.autotune_at_compile_time: + # self.kernel_autotune_calls.writeline( + # V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + # ) + self.parent_wrapper.write_get_raw_stream_header_once() + + @cache_on_self + def get_root_graph(self) -> PythonWrapperCodegen: + root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self + while isinstance(root, SubgraphPythonWrapperCodegen): + root = root.parent_wrapper + + assert isinstance(root, PythonWrapperCodegen) + return root + + def generate_and_run_autotune_block(self): + # Only execute auto-tuning block in the main graph + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py new file mode 100644 index 0000000000000000000000000000000000000000..02c498d6debce64609751edae5c4e9287797fa6a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py @@ -0,0 +1,1213 @@ +import dataclasses +import functools +import logging +import operator +import textwrap +from collections import Counter +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union + +import sympy + +import torch +from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, +) +from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + tracing_triton_hopifier_singleton, + triton_kernel_wrapper_mutation, +) +from torch._inductor.codecache import LambdaFuture, PyCodeCache +from torch._inductor.runtime.triton_heuristics import CachingAutotuner +from torch._inductor.select_algorithm import extern_kernels # noqa: F401 +from torch._inductor.utils import convert_to_symint +from torch._inductor.virtualized import V +from torch._library.triton import wrap_triton +from torch.fx import GraphModule +from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + ConvertIntKey, + DivideByKey, + free_unbacked_symbols, +) +from torch.utils import _pytree as pytree +from torch.utils._sympy.functions import FloorDiv +from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp +from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis +from torch.utils._sympy.solve import try_solve + +from .. import config, ir +from ..runtime.triton_compat import Config +from ..utils import cache_property_on_self, LineContext, ValueWithLineMap +from .common import ( + CodegenSymbol, + FileBackedGraphModule, + WorkspaceArg, + WorkspaceZeroMode, +) +from .wrapper import ( + AllocateLine, + BufferLike, + CommBufferAllocateLine, + CommBufferFreeLine, + CommentLine, + ConditionalLine, + DynamicScalarLine, + EnterDeviceContextManagerLine, + EnterSubgraphLine, + ExitDeviceContextManagerLine, + ExitSubgraphLine, + ExternKernelAllocLine, + ExternKernelOutLine, + FreeIfNotReusedLine, + FreeLine, + IndexPutFallbackLine, + KernelCallLine, + KernelDefinitionLine, + Line, + MultiOutputLine, + NullLine, + PythonWrapperCodegen, + ReinterpretLine, + ReuseLine, + ScatterFallbackLine, + SubgraphPythonWrapperCodegen, + SymbolicCallArg, + SymbolicCallArgLine, + UnbackedSymbolDefsLine, + WrapperLine, +) + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class SymbolBuffer(CodegenSymbol): + """ + Represents a sympy.Symbol graph input. + """ + + symbol: sympy.Symbol + + def get_name(self) -> str: + return str(self.symbol) + + def get_example(self) -> Union[torch.Tensor, torch.SymInt]: + sym_int = convert_to_symint(self.symbol) + assert isinstance(sym_int, torch.SymInt) + return sym_int + + +CodegenBuffer = Union[BufferLike, SymbolBuffer] + + +@dataclasses.dataclass +class TritonKernel: + """ + Stores metadata about Triton kernels for use in FX. + """ + + tuner: CachingAutotuner + wrapped: TraceableTritonKernelWrapper + + +def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: + """ + Replace sympy.floor with FloorDiv. + """ + + def replace(expr: sympy.Expr) -> sympy.Expr: + expr = sympy.together(expr) + + # Division is represented as a Mul with a Rational factor or a Pow with negative + # exponent. We convert floor(Mul(...)) to FloorDiv(numerator, denominator) by + # partitioning factors into the numerator and denominator. + (numerator, denominator) = (sympy.S.One,) * 2 + for arg in sympy.Mul.make_args(expr): + if isinstance(arg, sympy.Rational): + numerator *= arg.numerator + denominator *= arg.denominator + elif isinstance(arg, sympy.Pow) and arg.exp.is_negative: + denominator *= arg.base**-arg.exp + else: + numerator *= arg + + return FloorDiv(numerator, denominator) + + return expr.replace(sympy.floor, replace) + + +class WrapperFxCodegen(PythonWrapperCodegen): + """ + Backend to generate wrapper code as an FX IR graph. + """ + + supports_caching = False + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.subgms: dict[str, torch.fx.GraphModule] = {} + + def codegen_inputs(self) -> None: + """ + This would generate code for symbolic input shapes, strides, etc. + Since the FX converter handles this, do nothing here. + """ + + def codegen_conditional(self, conditional: ir.Conditional) -> None: + """ + Conditional codegen normally emits a number of different wrapper lines. + Instead, FX conversion uses a dedicated line for the whole conditional. + """ + self.writeline(ConditionalLine(self, conditional)) + for subgraph in (conditional.true_subgraph, conditional.false_subgraph): + self.codegen_subgraph_common(subgraph) + + def define_subgraph_launcher_fn( + self, name: str, subgraph_code: Union[ValueWithLineMap, FileBackedGraphModule] + ) -> None: + """ + Record subgms as they're generated. + """ + assert isinstance(subgraph_code, FileBackedGraphModule) + self.subgms[name] = subgraph_code.gm + + @property + @cache_property_on_self + def is_subgraph(self) -> bool: + return isinstance(self, SubgraphPythonWrapperCodegen) + + def get_fx_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]: + """ + Get the input nodes corresponding to FX graph placeholders. + """ + # pyrefly: ignore [missing-argument] + if V.aot_compilation and not self.is_subgraph: + # AOT graphs must match the signature of the input module. + return { + node.name: V.graph.graph_inputs.get(node.name) + for node in V.graph.module.graph.find_nodes(op="placeholder") # type: ignore[operator, union-attr] + } + + return self.get_graph_inputs() + + def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]: + self.run_wrapper_ir_passes(is_inference) + + prologue = "\n".join( + [ + self.imports.getvalue(), + self.header.getvalue(), + ] + ) + gm = FxConverter( + lines=self.lines, + prologue=prologue, + graph_inputs=self.get_fx_graph_inputs(), + graph_outputs=self.get_graph_outputs(), + subgms=self.subgms, + # pyrefly: ignore [missing-argument] + is_subgraph=self.is_subgraph, + ).generate() + + compiled_fn = self.compile_graph(gm) + + return FileBackedGraphModule(gm, compiled_fn), None + + def compile_graph(self, gm: GraphModule) -> Callable[..., Any]: + """ + Converts the graph module into a runnable function. The default implementation + is simply an interpreter calling kernels in eager mode. Derived backends can + override this to do further compilation. + """ + return gm.forward + + def write_header(self) -> None: + """ + Python subgraphs normally lack headers. + Override this behavior to generate prologues for FX subgraphs. + """ + PythonWrapperCodegen.write_header(self) + + @classmethod + def create( + cls: type["WrapperFxCodegen"], + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ) -> "WrapperFxCodegen": + if is_subgraph: + assert subgraph_name is not None + assert parent_wrapper is not None + + # Subgraphs override some methods of PythonWrapperCodegen. + # Apply these overrides to the user-provided class, with priority given to + # user-provided methods. + class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen): # type: ignore[misc,valid-type] + def compile_graph(self, gm: GraphModule) -> Callable[..., Any]: + """ + Skip graph compilation for subgraphs. + """ + + def crash_if_run(*args: Any) -> None: + raise NotImplementedError("Cannot run a subgraph in isolation!") + + return crash_if_run + + return SubgraphFxWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + + return cls() + + +@dataclasses.dataclass +class FxConverter: + """ + Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the + input and output code are stored as attributes. + """ + + lines: list[Line] + prologue: str + graph_inputs: dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]] + graph_outputs: list[ir.IRNode] + subgms: dict[str, torch.fx.GraphModule] + is_subgraph: bool + + def __post_init__(self) -> None: + graph = torch.fx.Graph() + self.gm = GraphModule({}, graph) # Wrapper FX IR. + self.buffer_to_node: dict[ + Optional[str], torch.fx.Node + ] = {} # Symbol table for codegen. + self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels. + self._unique_symbol_ids: Counter[str] = Counter() + self.tracer = torch.fx.proxy.GraphAppendingTracer(graph) + self.expr_to_proxy: dict[sympy.Expr, torch.fx.Proxy] = {} + + def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: + """ + Imports a kernel from source, possibly autotuning block parameters. + """ + module_code = "\n".join([self.prologue, code]) + mod = PyCodeCache.load(module_code) + kernel = getattr(mod, kernel_name) + + if isinstance(kernel, LambdaFuture): + kernel = kernel.result() + + if not isinstance(kernel, CachingAutotuner): + raise NotImplementedError( + textwrap.dedent(f""" + Unsupported type for kernel {kernel_name}: {type(kernel)}. + FX conversion only supports Triton kernels. + """) + ) + + return kernel + + def _create_as_strided( + self, + input_node: torch.fx.Node, + size: tuple[Any, ...], + stride: tuple[Any, ...], + offset: Union[int, sympy.Expr], + ) -> torch.fx.Node: + return self.gm.graph.call_function( + torch.as_strided, + args=( + input_node, + self._generate_sym_nodes(size), + self._generate_sym_nodes(stride), + self._generate_sym_node(offset), + ), + ) + + def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None: + """ + Updates the symbol table to record that an Inductor buffer maps to the result of + an FX node. + """ + assert node not in self.buffer_to_node + self.buffer_to_node[buffer.get_name()] = node + + def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None: + """ + Removes the buffer from the symbol table. + """ + name = buffer.get_name() + del self.buffer_to_node[name] + + def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + """ + Maps call args back to FX nodes. + """ + return tuple( + self.buffer_to_node[arg] + if isinstance(arg, str) + else arg.inner_expr + if isinstance(arg, SymbolicCallArg) + else arg + for arg in args + ) + + def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer: + """ + Extract buffer data from an IR node. + """ + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, (ir.BaseView, ir.MutableBox)): + return self._get_buffer(node.data) + elif isinstance(node, sympy.Symbol): + return SymbolBuffer(node) + else: + raise NotImplementedError(f"Unable to extract buffer from node: {node}") + + def _generate_size_proxy( + self, node: torch.fx.Node, expr: sympy.Expr + ) -> torch.fx.Proxy: + proxy = torch.fx.Proxy(node, tracer=self.tracer) + self.expr_to_proxy[expr] = proxy + return proxy + + def _generate_graph_inputs(self) -> None: + """ + Converts graph inputs to FX placeholders. + """ + + for name, ir_node in self.graph_inputs.items(): + if ir_node is None: + # Create dummy input nodes to match the input signature + self.gm.graph.placeholder(name) + continue + + # Introduce a new symbol for constant inputs. + is_constant = isinstance(ir_node, (int, float, sympy.Integer, sympy.Float)) + buffer = ( + SymbolBuffer(sympy.Symbol(name, is_integer=True)) + if is_constant + else self._get_buffer(ir_node) + ) + placeholder_node = self.gm.graph.placeholder(buffer.get_name()) + placeholder_node.meta["val"] = ( + ir_node if is_constant else buffer.get_example() + ) + self._record_allocation(buffer, placeholder_node) + + # Record symbol definitions for dynamic shapes. + if isinstance(ir_node, sympy.Symbol): + self._generate_size_proxy(placeholder_node, ir_node) + + def _generate_graph_input_shapes(self) -> None: + """ + Generate nodes creating symints that are part of graph input + shape/strides. + """ + + def _codegen_symbol( + sym_or_exp: Union[sympy.Symbol, sympy.Expr], + base_node: torch.fx.Node, + target: torch._ops.OpOverload, + dim: int, + ) -> None: + def codegen_proxy() -> torch.fx.Proxy: + size_node = self.gm.graph.call_function(target, (base_node, dim)) + size_proxy = self._generate_size_proxy(size_node, sym_or_exp) + return size_proxy + + if isinstance(sym_or_exp, sympy.Symbol): + if sym_or_exp in self.expr_to_proxy: + return + codegen_proxy() + + elif isinstance(sym_or_exp, sympy.Integer): + return + + elif isinstance(sym_or_exp, sympy.Expr): + # Check if we need to solve for an undefined symbol. + undefined_symbols = [ + sym + for sym in sym_or_exp.free_symbols + if sym not in self.expr_to_proxy + ] + if len(undefined_symbols) == 0: + self._sympy_interp(sym_or_exp) + return + elif len(undefined_symbols) > 1: + raise ValueError(f"Underdetermined input expression: {sym_or_exp}") + + # Define a new symbol for the input size. + size_proxy = codegen_proxy() + size_symbol = sympy.Symbol( + size_proxy.node.name, integer=True, nonnegative=True + ) + self.expr_to_proxy[size_symbol] = size_proxy + + # Solve for the undefined symbol. + undefined_symbol = undefined_symbols[0] + solution = try_solve( + sympy.Eq(sym_or_exp, size_symbol), undefined_symbol + ) + if solution is None: + raise ValueError(f"Cannot solve input expression: {sym_or_exp}") + + # Since the symbol is a size, it must be an integer. + # Therefore, we can convert division to FloorDiv. + undefined_symbol_expr = solution[1] + if undefined_symbol.is_integer: + undefined_symbol_expr = replace_floor_div( + sympy.floor(undefined_symbol_expr) + ) + + # Generate FX for the symbol. + self._sympy_interp(undefined_symbol_expr) + self.expr_to_proxy[undefined_symbol] = self.expr_to_proxy[ + undefined_symbol_expr + ] + + for ir_node in self.graph_inputs.values(): + if isinstance(ir_node, ir.TensorBox): + buffer = self._get_buffer(ir_node) + placeholder_node = self.buffer_to_node[buffer.get_name()] + + for dim, size in enumerate(ir_node.get_size()): + _codegen_symbol( + size, placeholder_node, torch.ops.aten.sym_size.int, dim + ) + for dim, stride in enumerate(ir_node.get_stride()): + _codegen_symbol( + stride, placeholder_node, torch.ops.aten.sym_stride.int, dim + ) + + def _generate_graph_constants(self) -> None: + for name, value in V.graph.constants.items(): + node = self.gm.graph.get_attr(name) + node.meta["val"] = value + setattr(self.gm, name, value) + self.buffer_to_node[name] = node + + def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]: + """ + Generates FX IR for transformations on a buffer, such as ReinterpretView. + Does nothing if no such transformations are present. + """ + + if isinstance(node, ir.ShapeAsConstantBuffer): + # Generate FX nodes to compute the shape expression. + return self._sympy_interp(node.expr).node + + def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]: + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, ir.NoneAsConstantBuffer): + return None + elif isinstance(node, ir.MutableBox): + return generate_to_buffer(node.data) + elif isinstance(node, ir.ReinterpretView): + # We need to introduce a new symbol if the output is a ReinterpretView. + # Use a WorkspaceArg for this. + buffer = self._get_buffer(node.data) + assert isinstance(buffer, (ir.Buffer, WorkspaceArg)) + unique_name = self.gm.graph._graph_namespace.create_name( + f"{buffer.get_name()}_view", None + ) + device = buffer.get_device() + assert device + reused_as = WorkspaceArg( + count=buffer.get_size(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + device=device, + outer_name=unique_name, + dtype=buffer.get_dtype(), + ) + + # Generate FX IR for the view. + self._generate_reinterpret_helper(buffer, reused_as, node.layout) + + return reused_as + else: + raise NotImplementedError(f"Unrecognized buffer/view node: {node}") + + buffer = generate_to_buffer(node) + return self.buffer_to_node[buffer.get_name()] if buffer is not None else None + + def _generate_outputs( + self, + ) -> Union[Optional[torch.fx.Node], list[Optional[torch.fx.Node]]]: + """ + Generate FX IR for graph outputs. + """ + output_nodes = [ + self._generate_buffer(node) for idx, node in enumerate(self.graph_outputs) + ] + + # Parent graphs with single return elements don't use a tuple. + output_value = ( + output_nodes[0] + if len(output_nodes) == 1 and not self.is_subgraph + else output_nodes + ) + + return output_value + + def _generate_subgm_getattrs(self) -> None: + """ + Generate getattr nodes for subgms. + """ + + def generate_getattr(name: str, subgm: torch.fx.GraphModule) -> torch.fx.Node: + self.gm.add_submodule(name, subgm) + node = self.gm.graph.get_attr(name) + node.meta["val"] = subgm + return node + + self.subgm_getattrs = { + name: generate_getattr(name, subgm) for name, subgm in self.subgms.items() + } + + def _get_subgm_attr(self, subgraph: ir.Subgraph) -> torch.fx.Node: + """ + Look up the getattr node for a subgraph. + """ + graph = subgraph.graph + assert graph is not None + return self.subgm_getattrs[graph.name] + + def generate(self) -> torch.fx.GraphModule: + """ + Main entrypoint for FX codegen. + """ + self._generate_graph_inputs() + self._generate_graph_constants() + self._generate_subgm_getattrs() + + with _set_node_metadata_hook( + self.gm, + functools.partial(_node_metadata_hook, fake_mode=V.fake_mode), + ): + self._generate_graph_input_shapes() + + # Generate FX IR from Wrapper IR lines. + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen_fx(self)(line) + elif isinstance(line, LineContext): + # Ignore line context in FX IR. + pass + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Found line of unrecognized type '{type(line)}': + '{line}' + + FX conversion only supports Wrapper IR lines. + """ + ) + ) + + output = self._generate_outputs() + + self.gm.graph.output(output) + self.gm.recompile() + return self.gm + + def _sympy_interp(self, expr: sympy.Expr) -> torch.fx.Proxy: + # hash cons + if expr in self.expr_to_proxy: + return self.expr_to_proxy[expr] + # base cases, don't cache + if isinstance( + expr, + ( + sympy.Integer, + sympy.Number, + sympy.Symbol, + sympy.logic.boolalg.BooleanAtom, + ), + ): + return sympy_interp( + OptimizedPythonReferenceAnalysis, self.expr_to_proxy, expr + ) + + # hash cons on arguments, run expr handler + self.expr_to_proxy[expr] = _run_sympy_handler( + OptimizedPythonReferenceAnalysis, + [self._sympy_interp(arg) for arg in expr.args], + expr, + ) + return self.expr_to_proxy[expr] + + def _generate_sym_node( + self, s: Union[int, sympy.Expr] + ) -> Union[int, torch.fx.Node]: + if isinstance(s, (int, sympy.Integer)): + return int(s) + elif isinstance(s, sympy.Symbol): + assert s in self.expr_to_proxy, ( + f"Could not find a node corresponding to the symbol {s}" + ) + return self.expr_to_proxy[s].node + elif isinstance(s, sympy.Expr): + return self._sympy_interp(s).node + + elif isinstance(s, torch.fx.Node): + return s + + else: + raise ValueError(f"{s} of type {type(s)} is not a valid input") + + def _generate_sym_nodes( + self, shape: Sequence[sympy.Expr] + ) -> list[Union[int, torch.fx.Node]]: + return [self._generate_sym_node(s) for s in shape] + + def _generate_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, AllocateLine) + buffer = line.node + name = buffer.get_name() + assert name not in V.graph.removed_buffers + + device = buffer.get_device() + assert device + dtype = buffer.get_dtype() + shape = self._generate_sym_nodes(buffer.get_size()) + stride = self._generate_sym_nodes(buffer.get_stride()) + + node = self.gm.graph.call_function( + torch.empty_strided, + args=(shape, stride), + kwargs={"dtype": dtype, "device": device.type}, + ) + assert name + node.name = name + self._record_allocation(buffer, node) + + def _generate_conditional(self, line: WrapperLine) -> None: + assert isinstance(line, ConditionalLine) + + def get_subgm_attr(subgraph: Optional[ir.Subgraph]) -> torch.fx.Node: + assert subgraph is not None + return self._get_subgm_attr(subgraph) + + # Access the subgraphs as getattrs. + ir_node = line.node + (true_subgm, false_subgm) = [ + get_subgm_attr(subgraph) + for subgraph in (ir_node.true_subgraph, ir_node.false_subgraph) + ] + + def generate_buffer(node: Optional[ir.IRNode]) -> Optional[torch.fx.Node]: + assert node is not None + return self._generate_buffer(node) + + predicate = generate_buffer(ir_node.predicate) + assert ir_node.operands is not None + operands = tuple(generate_buffer(arg) for arg in ir_node.operands) + fx_node = self.gm.graph.call_function( + torch.ops.higher_order.cond, + args=(predicate, true_subgm, false_subgm, operands), + ) + self._record_allocation(ir_node, fx_node) + + def _generate_comment(self, line: WrapperLine) -> None: + assert isinstance(line, CommentLine) + # We ignore comments in FX IR. + + def _generate_dynamic_scalar(self, line: WrapperLine) -> None: + assert isinstance(line, DynamicScalarLine) + + ir_node = line.node + (input_ir_node,) = ir_node.inputs + assert isinstance(input_ir_node, ir.IRNode) + input_fx_node = self._generate_buffer(input_ir_node) + keypath = ir_node.keypath + graph = self.gm.graph + + def generate_item(x: Optional[torch.fx.Node]) -> torch.fx.Node: + assert x is not None + return graph.call_function( + aten.item.default, + args=(x,), + ) + + if len(keypath) == 0: + result_fx_node = generate_item(input_fx_node) + elif len(keypath) == 1 and isinstance(keypath[0], ConvertIntKey): + where_fx_node = graph.call_function( + aten.where.Scalar, + args=(input_fx_node, 1, 0), + ) + result_fx_node = generate_item(where_fx_node) + else: + raise NotImplementedError(f"Unsupported keypath: {keypath}") + + result_symbol = ir_node.sym + result_buffer = SymbolBuffer(result_symbol) + self._record_allocation(result_buffer, result_fx_node) + self._generate_size_proxy(result_fx_node, result_symbol) + + def _generate_enter_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, EnterDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_exit_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, ExitDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_enter_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, EnterSubgraphLine) + # We ignore memory planning lines in FX IR. + + def _generate_exit_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, ExitSubgraphLine) + # We ignore memory planning lines in FX IR. + + def _generate_free(self, line: WrapperLine) -> None: + assert isinstance(line, FreeLine) + + buf = line.node + + # No need to free placeholders. + if self.buffer_to_node[buf.get_name()].op == "placeholder": + return + + self._free(buf) + + def _generate_free_if_not_reused(self, line: WrapperLine) -> None: + assert isinstance(line, FreeIfNotReusedLine) + buf = line.node + assert buf.get_name() not in V.graph.removed_buffers + if not line.is_reused: + self._free(buf) + + def _generate_line_context(self, line: WrapperLine) -> None: + assert isinstance(line, LineContext) + # We ignore line context in FX IR. + + def _generate_reinterpret(self, line: WrapperLine) -> None: + assert isinstance(line, ReinterpretLine) + self._generate_reinterpret_helper(line.node, line.reused_as, line.layout) + + def _generate_reinterpret_helper( + self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout + ) -> None: + input_node = self.buffer_to_node[input_buffer.get_name()] + + # Look up output metadata. + name = result_buffer.get_name() + assert name + size = tuple(layout.size) + stride = tuple(layout.stride) + if isinstance(layout, ir.NonOwningLayout): + # Look up the view's layout. + view = layout.view + assert isinstance(view, ir.ReinterpretView), ( + f"unexpected type: {type(view)}" + ) + layout = view.layout + offset = input_buffer.get_offset() + layout.offset + + # Map ReinterpretView to as_strided. + result_node = self._create_as_strided(input_node, size, stride, offset) + result_node.name = name + self._record_allocation(result_buffer, result_node) + + def _generate_reuse(self, line: WrapperLine) -> None: + assert isinstance(line, ReuseLine) + old = line.node + new = line.reused_as + assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new)) + assert old.get_dtype() == new.get_dtype() + + old_node = self.buffer_to_node[old.get_name()] + result_node = old_node + + # Change shape and stride. + size = tuple(new.get_size()) + stride = tuple(new.get_stride()) + offset = new.get_offset() + if ( + tuple(old.get_size()) != size + or tuple(old.get_stride()) != stride + or old.get_offset() != offset + ): + result_node = self._create_as_strided(old_node, size, stride, offset) + + self._record_allocation(new, result_node) + + # Free the old buffer, if we allocated a new tensor. + if ( + old.get_name() not in V.graph.get_output_names() + and line.delete_old + and result_node is not old_node + ): + self._free(old) + + def _generate_multi_output(self, line: WrapperLine) -> None: + assert isinstance(line, MultiOutputLine) + + arg_node = self.buffer_to_node[line.arg_name] + + # For non-tuple / non-list outputs, map the + # output to the same node as the input. + if len(line.indices) == 0: + self.buffer_to_node[line.result_name] = arg_node + return + + # Extract the index for tuple access. + inds = line.indices[0][1:] + assert len(inds) == 1, f"Cannot convert {inds} to an index." + idx = inds[0] + + node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx)) + node.name = line.result_name + self.buffer_to_node[line.result_name] = node + + def _generate_fallback_call( + self, + ir_node: ir.ExternKernel, + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + ) -> None: + fx_node = self.gm.graph.call_function( + ir_node.op_overload, # type: ignore[arg-type] + args=args, + kwargs=kwargs, + ) + result_buffer = ir_node.codegen_reference() + self.buffer_to_node[result_buffer] = fx_node + + def _generate_index_put_fallback(self, line: WrapperLine) -> None: + assert isinstance(line, IndexPutFallbackLine) + ir_node = line.node + + def generate_buffer_or_none( + x: Union[ir.IRNode, Sequence[ir.IRNode], None], + ) -> Optional[torch.fx.Node]: + """ + Handles None before calling _generate_buffer. + """ + if x is None: + return None + + assert isinstance(x, ir.IRNode) + return self._generate_buffer(x) + + (x, values) = [generate_buffer_or_none(t) for t in ir_node.inputs[:2]] + indices = tuple(generate_buffer_or_none(t) for t in line.indices) + accumulate = ir_node.constant_args[0] + args = (x, indices, values, accumulate) + self._generate_fallback_call(ir_node, args) + + def _generate_scatter_fallback(self, line: WrapperLine) -> None: + assert isinstance(line, ScatterFallbackLine) + ir_node = line.node + assert ir.is_node_sequence(ir_node.inputs) + (x, index, src) = [self._generate_buffer(t) for t in ir_node.inputs] + ( + [] if ir_node.src_is_tensor else [ir_node.constant_args[1]] + ) + args = (x, ir_node.constant_args[0], index, src) + kwargs = {} + if reduce := ir_node.kwargs.get("reduce"): + kwargs["reduce"] = reduce + + self._generate_fallback_call(ir_node, args, kwargs) + + def _generate_null(self, line: WrapperLine) -> None: + assert isinstance(line, NullLine) + # Does nothing. + + def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferAllocateLine) + raise NotImplementedError("Comm buffer allocation is not yet supported") + + def _generate_comm_buffer_free(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferFreeLine) + self._free(line.node) + + def _generate_triton_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + + # Collect all kwargs, including autotuned block sizes. + call_args = self._lookup_args(line.call_args) + kernel = self.kernels[line.kernel_name] + tuner = kernel.tuner + + class UnbackedSymintsError(Exception): + pass + + def tune_kernel(tuner: CachingAutotuner, call_args: Sequence[Any]) -> None: + from triton.runtime import driver + + log.info("Autotuning Triton kernel %s at compile time.", kernel_name) + # pyrefly: ignore # missing-attribute + device = driver.active.get_current_device() + # pyrefly: ignore # missing-attribute + stream = driver.active.get_current_stream(device) + + def node_to_tuning_arg(arg: Any) -> Any: + """ + Create real tensors for autotuning arguments, substituting size hints + for dynamic shapes. + """ + + def to_size_hint(arg: Any) -> Any: + if len(free_unbacked_symbols(arg)) > 0: + # NYI: tuning args require backed symints. + raise UnbackedSymintsError + return pytree.tree_map(V.graph.sizevars.size_hint, arg) + + if not isinstance(arg, torch.fx.Node): + return to_size_hint(arg) + + fake = arg.meta["val"] + return torch.empty_strided( + to_size_hint(fake.shape), + to_size_hint(fake.stride()), + dtype=fake.dtype, + device=device, + ).zero_() + + arg_values = [node_to_tuning_arg(arg) for arg in call_args] + tuner.run(*arg_values, stream=stream) + + # Optionally autotune the kernels. + # The FX backend currently only supports compile-time tuning. + kernel_name = tuner.fn.__name__ + if config.triton.autotune_at_compile_time: + try: + tune_kernel(tuner, call_args) + except UnbackedSymintsError: + log.info( + "Detected unbacked symints. Skipping autotuning for kernel %s.", + kernel_name, + ) + else: + log.info( + "Skipping autotuning for kernel %s. Set config.triton.autotune_at_compile_time = True to enable.", + kernel_name, + ) + + triton_meta = tuner.triton_meta + signature = triton_meta["signature"] + + def add_constants_to_call_args( + call_args: Sequence[Any], cfg: Config + ) -> tuple[Any, ...]: + """ + Add constant kwargs to the arg list. + """ + # Add args from the proper Triton signature. + # Exclude constants and config kwargs, as those are tracked separately. + new_call_args = [] + constants = triton_meta["constants"] + call_kwargs = { + key: val + for key, val in zip(signature, call_args) + # pyrefly: ignore [missing-attribute] + if key not in constants and key not in cfg.kwargs + } + + # Add constants stored as Triton metadata, in signature order. + call_kwargs |= constants + new_call_args = [ + call_kwargs[key] + for key in signature + # pyrefly: ignore [missing-attribute] + if key not in cfg.kwargs + ] + + # Add Inductor's extra launcher args to the end. + if extra_launcher_args := tuner.inductor_meta.get("extra_launcher_args"): + new_call_args.extend( + call_args[len(call_args) - len(extra_launcher_args) :] + ) + + return tuple(new_call_args) + + kernel_config = tuner.compile_results[0].config + extra_options = getattr(kernel_config, "extra_options", None) + call_args = add_constants_to_call_args(call_args, kernel_config) + call_args, grid = tuner._interpret_args_grid(call_args, kernel_config) + call_kwargs = dict(zip(signature, call_args)) + # pyrefly: ignore [missing-attribute] + assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), ( + f"kwargs overlap config: {call_kwargs}" + ) + # pyrefly: ignore [missing-attribute] + call_kwargs.update(kernel_config.kwargs) + + # Replace sympy.floor with FloorDiv, to make the expression traceable. + grid = [replace_floor_div(x) if isinstance(x, sympy.Expr) else x for x in grid] + wrapper_grid = [tuple(self._generate_sym_nodes(grid))] + call_kwargs = { + name: self._generate_sym_node(val) for name, val in call_kwargs.items() + } + + # Store non-graphable kwargs in the side table. + ( + call_kwargs, + constant_args_idx, + ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs) + + triton_node = self.gm.graph.call_function( + triton_kernel_wrapper_mutation, + kwargs={ + "kernel_idx": kernel.wrapped.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": wrapper_grid, + "tma_descriptor_metadata": {}, + "kwargs": call_kwargs, + }, + ) + if extra_options: + triton_node.meta["extra_options"] = extra_options + + def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None: + assert isinstance(line, ExternKernelAllocLine) + node = line.node + self._generate_extern_kernel_common(node, node) + + def _generate_extern_kernel_out( + self, + line: WrapperLine, + ) -> None: + assert isinstance(line, ExternKernelOutLine) + node = line.node + out_node = node.output_view if node.output_view else node + self._generate_extern_kernel_common(node, out_node) + + def _generate_extern_kernel_common( + self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode + ) -> None: + """ + Generates FX IR from either ExternKernelAlloc or ExternKernelOut. + """ + + # Get FX nodes corresponding to the call args. + assert ir.is_node_sequence(kernel.inputs) + tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs) + if hasattr(kernel, "unflatten_args"): + args, _ = kernel.unflatten_args(tensor_nodes, kernel.constant_args) + else: + args = tensor_nodes + tuple(kernel.constant_args) + + # Get the result buffer. + # Some kernels write to a pre-existing output tensor via the "out" kwarg. + kwargs = kernel.kwargs.copy() + + result_buffer: Optional[str] = None + if isinstance(kernel, ir.ExternKernelOut): + kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()] + elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)): + result_buffer = kernel.get_name() + elif isinstance(kernel.layout, ir.NoneLayout): + pass + else: + raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}") + + fx_node = self.gm.graph.call_function( + kernel.op_overload, # type: ignore[arg-type] + args=args, + kwargs=kwargs, + ) + + # Assign the result to the given name. + if result_buffer: + assert "out" not in kwargs, ( + f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one." + ) + fx_node.name = result_buffer + self.buffer_to_node[result_buffer] = fx_node + + def _generate_kernel_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + if not line.triton: + raise NotImplementedError("FX conversion only supports Triton kernels.") + + self._generate_triton_call(line) + + def _generate_kernel_definition(self, line: WrapperLine) -> None: + assert isinstance(line, KernelDefinitionLine) + + # Generate code for the kernel. + kernel_code = PythonWrapperCodegen._format_kernel_definition( + line.kernel_name, line.kernel_body, metadata=line.metadata + ) + + # Import the module and store the JIT kernel. + tuner = self._import_kernel(kernel_code, line.kernel_name) + wrapped = wrap_triton(tuner.fn) + self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped) + + def _generate_symbolic_call_arg(self, line: WrapperLine) -> None: + assert isinstance(line, SymbolicCallArgLine) + # Store the arg: expr mapping for later use. + arg = line.arg + + inner_expr_proxy = self._sympy_interp(arg.inner_expr) + self.expr_to_proxy[arg.inner] = inner_expr_proxy + + def _generate_unbacked_symbol_defs(self, line: WrapperLine) -> None: + assert isinstance(line, UnbackedSymbolDefsLine) + graph = self.gm.graph + + def convert_key(node: torch.fx.Node, path: pytree.KeyPath) -> torch.fx.Node: + """ + Generate FX IR for each key entry. + """ + # Base case. + if len(path) == 0: + return node + + # Process the first entry and recurse. + entry = path[0] + if isinstance(entry, CallMethodKey): + target = { + "size": aten.sym_size.int, + "stride": aten.sym_stride.int, + "storage_offset": aten.sym_storage_offset, + }[entry.name] + assert callable(target) + node = graph.call_function( + target, + args=( + (node, path[1].idx) + if len(path) > 1 and isinstance(path[1], pytree.SequenceKey) + else (node,) + ), + ) + return convert_key(node, path[1 + len(node.args) :]) + elif isinstance(entry, pytree.SequenceKey): + node = graph.call_function(operator.getitem, args=(node, entry.idx)) + return convert_key(node, path[1:]) + elif isinstance(entry, DivideByKey): + node = graph.call_function( + operator.floordiv, args=(node, entry.divisor) + ) + return convert_key(node, path[1:]) + else: + raise NotImplementedError(f"Unrecognized entry type: {type(entry)}") + + root_node = self.buffer_to_node[line.output_name] + unbacked_bindings = line.unbacked_bindings + assert unbacked_bindings is not None + for s, keypath in unbacked_bindings.items(): + # Check if we already generated this symbol. + if s.name in self.buffer_to_node: + continue + + node = convert_key(root_node, keypath) + out_buffer = SymbolBuffer(s) + self._record_allocation(out_buffer, node) + self._generate_size_proxy(node, s) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..5d538ec20ca215b1dc5da23171a06999026c0eae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Optional + +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class XPUDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _xpu_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.xpu.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.xpu.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.xpu._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::DeviceGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTIXpuGuard" + + def cpp_stream_guard(self) -> str: + return "at::xpu::XPUStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTIXpuStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::xpu::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + """ + return source_codes + + def kernel_driver(self) -> str: + return "" + + def cpp_stream_type(self) -> str: + return "sycl::queue*" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_xpu_stream" + + def cpp_kernel_type(self) -> str: + return "std::unique_ptr" + + def cpp_device_ptr(self) -> str: + return "void *" + + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None + ) -> Optional[tuple[list[str], str]]: + return [f"void *global_scratch_{idx} = 0;"], f"global_scratch_{idx}" + + +register_device_op_overrides("xpu", XPUDeviceOpOverrides()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8dc65c08ec457c1cb87354a7a95afb4d15203d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py @@ -0,0 +1,774 @@ +# mypy: allow-untyped-defs +import functools +from collections import deque + +import torch +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map + +from ..._dynamo.utils import counters +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + ShapeAsConstantBuffer, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import lowerings +from ..pattern_matcher import ( + Arg, + CallFunction, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, + TritonTemplateCaller, +) +from ..utils import ceildiv + + +B2B_GEMM_PASS = PatternMatcherPass( + pass_name="b2b_gemm_pass", +) + + +@SymbolicGridFn +def b2b_gemm_grid(M, P, meta, *, cdiv): + return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(P, meta["BLOCK_SIZE_P"]), 1, 1) + + +b2b_gemm_left_template = TritonTemplate( + name="b2b_gemm_left", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_LEFT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (subgraph(A @ B) @ C) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for _ in range(num_o_block): + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for __ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + acc_ab += tl.dot(a, b, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_ab", + inner_mm="acc_ab" + ) | indent_except_first(2) }} + acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask", val_shape=("BLOCK_SIZE_M", "BLOCK_SIZE_P"))}} +""", +) + + +b2b_gemm_right_template = TritonTemplate( + name="b2b_gemm_right", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_RIGHT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop (two cases) + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (A @ subgraph(B @ C)) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for _ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for __ in range(num_o_block): + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_bc += tl.dot(b, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_bc", + inner_mm="acc_bc" + ) | indent_except_first(2) }} + acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask", val_shape=("BLOCK_SIZE_M", "BLOCK_SIZE_P"))}} +""", +) + + +# Note: load_ratio_left and load_ratio_right are only calculating numbers +# in the trivial subgraph case; i.e. (A @ (B @ C)) or ((A @ B) @ C) + + +def load_ratio_left( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o)) + | store | M * O + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = M * N + N * O + M * O + O * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(O, o) + * (o * p + ceildiv(N, n) * (m * n + n * o)) + ) + return base / gemm + + +def load_ratio_right( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p)) + | store | N * P + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = N * O + O * P + M * N + N * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(N, n) + * (m * n + ceildiv(O, o) * (n * o + o * p)) + ) + return base / gemm + + +# the block sizes are limited by hardware (the shared memory) +# intuitively, the optimization works when the intermediate matrix is large +# and we assign large block sizes to large dimensions +b2b_gemm_configs = [ + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, +] + + +def is_b2b_gemm_good_on( + is_left_assoc: bool, + A_node: torch.fx.Node, + B_node: torch.fx.Node, + C_node: torch.fx.Node, +) -> bool: + """ + checks whether the sizes are good for b2b_gemm + """ + # basic checks + if not all(["val" in A_node.meta, "val" in B_node.meta, "val" in C_node.meta]): + return False + fake_tensors = ( + A_node.meta["val"], + B_node.meta["val"], + C_node.meta["val"], + ) # torch._subclasses.fake_tensor.FakeTensor + + A, B, C = fake_tensors + + def check_all_attr_true(objects, attr): + return all(hasattr(obj, attr) and getattr(obj, attr) for obj in objects) + + if not check_all_attr_true(fake_tensors, "is_cuda") and not check_all_attr_true( + fake_tensors, "is_xpu" + ): + return False + if not all([len(A.shape) == 2, len(B.shape) == 2, len(C.shape) == 2]): + return False + if not ((A.shape[1] == B.shape[0]) and (B.shape[1] == C.shape[0])): + return False + # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1 + M, N = A.shape + O, P = C.shape + ratios = [] + if is_left_assoc: + for config in b2b_gemm_configs: + ratio = load_ratio_left( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + else: + for config in b2b_gemm_configs: + ratio = load_ratio_right( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + ratios.sort(reverse=True) + average_ratio = 1.0 + for r in ratios[:3]: # top 3 choices + average_ratio *= r + average_ratio = average_ratio ** (1 / 3) + return ( + average_ratio > 1 + ) # even if average_ratio is close to 1, the number of stores is always better + + +def unoptimized_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + *, + out: torch.Tensor, +) -> torch.Tensor: + """ + The unoptimized version is used as a fallback when the b2b_gemm kernel is not beneficial. + """ + if is_left_assoc: + torch.mm(subgraph.graph_module(torch.mm(A, B)), C, out=out) + else: + torch.mm(A, subgraph.graph_module(torch.mm(B, C)), out=out) + return out + + +unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm) + + +def build_subgraph_buffer( + args: list[TensorBox], + subgraph: Subgraph, +): + """ + This function is adapted from ../kernel/flex_attention.py. + The goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph + subgraph: The Subgraph ir for which to produce the output node + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + if node.op == "placeholder": + env[node] = args[cnt] + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + args, kwargs = tree_map(lambda x: env.get(x, x), (node.args, node.kwargs)) + env[node] = lowerings[node.target](*args, **kwargs) + elif node.op == "output": + + def convert_output_node_to_buffer(output): + if output is None: + return None + output_node = output + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + device = output_buffer.data.get_device() + assert device is not None + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=device, + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + # node.args[0] should be a single element representing the output of the subgraph + return tree_map(convert_output_node_to_buffer, node.args[0]) + + raise ValueError("B2B-GEMM was passed a subgraph with no output node!") + + +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox | ShapeAsConstantBuffer: + """ + Creates a placeholder input buffers for producing subgraph_output + """ + input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], [])) + return TensorBox.create(input_buffer) + + +def tuned_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch._inductor.ir.TensorBox, + B: torch._inductor.ir.TensorBox, + C: torch._inductor.ir.TensorBox, + *, + layout=None, +) -> torch._inductor.ir.TensorBox: + # call .realize() to get rid of Pointwise + A.realize() + B.realize() + C.realize() + layout = FixedLayout( + A.get_device_or_error(), + A.get_dtype(), + [A.shape[0], C.shape[1]], # type: ignore[index] + ) + placeholders = [ + create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error()) + ] + subgraph_buffer = build_subgraph_buffer( + placeholders, # type: ignore[arg-type, list-item] + subgraph, + ) + choices: list[TritonTemplateCaller] = [] + for config in b2b_gemm_configs: + if is_left_assoc: + b2b_gemm_left_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + else: + b2b_gemm_right_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + # add the unoptimized choice to mitigate performance degradation + choices.append( + unoptimized_choice.bind( + (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph + ) + ) + # autotune + return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout) + + +# match the inner mm of a potential b2b_gemm +@register_graph_pattern( + CallFunction(torch.ops.aten.mm, Arg(), Arg()), + # pyrefly: ignore [bad-argument-type] + pass_dict=B2B_GEMM_PASS, +) +def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None: + # match.args: list[torch.fx.Node] + + def is_pointwise_node(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and (torch.Tag.pointwise in node.target.tags) + ) + + def is_mm(node: torch.fx.Node) -> bool: + return node.target is torch.ops.aten.mm.default + + # the inner MM + inner_mm = match.nodes[-1] + + # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it + # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _). + outer_mm = None + node = inner_mm + while len(node.users) > 0: + node = next(iter(node.users)) + if is_mm(node): + outer_mm = node + break + elif is_pointwise_node(node): + continue + else: + break + if not outer_mm: + return + + # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C)) + # we call it the "f_node" + # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm + f_node = inner_mm + while next(iter(f_node.users)) is not outer_mm: + f_node = next(iter(f_node.users)) + + def all_reach_via_pointwise_with_no_other_inputs( + src: torch.fx.Node, + dst: torch.fx.Node, + ) -> tuple[bool, OrderedSet[torch.fx.Node]]: + """ + check whether every user path from src reaches dst via pointwise nodes, + with no other input nodes for the intermediates and dst; + return + (1) the Boolean value + (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True) + """ + visited = OrderedSet[torch.fx.Node]() + input_counter: dict[torch.fx.Node, int] = {} + + all_reachable = True + queue = deque([src]) + while queue: + node = queue.popleft() + if node not in visited: + if node is dst: + visited.add(node) + elif (node is src) or is_pointwise_node(node): + for user in node.users: + # for nodes other than dst, bookkeep their users' input counts + if user not in input_counter: + input_counter[user] = len(user.all_input_nodes) + input_counter[user] -= 1 + # continue BFS + queue.append(user) + visited.add(node) + else: + all_reachable = False + break + + return ( + all_reachable and all(count == 0 for count in input_counter.values()), + visited, + ) + + # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes + ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs( + inner_mm, f_node + ) + if not ok: + return + + # check inner_mm's inputs and f_node's outputs + if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1): + return + + # at this point, the nodes between inner_mm and f_node (both included) + # are all used internally inside (A @ subgraph(B @ C)) + # i.e. they neither have other users nor have other inputs + + # original graph and module + graph, module = inner_mm.graph, inner_mm.graph.owning_module + + # construct the new (sub)graph + subgraph_node_list: list[ + torch.fx.Node + ] = [] # ordered list of nodes used for node removal later + new_graph: torch.fx.Graph = torch.fx.Graph() + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + new_input_anchor: torch.fx.Node # inner_mm, to be changed to an input node + new_output_anchor: torch.fx.Node # f_node, to be used to construct an output node + new_input_node: torch.fx.Node + new_output_node: torch.fx.Node + for node in graph.nodes: # preserve the order of nodes + if node in subgraph_node_set: + subgraph_node_list.append(node) + new_node = new_graph.node_copy(node, lambda x: node_remapping.get(x, x)) + node_remapping[node] = new_node + if node is inner_mm: + new_input_anchor = new_node + if node is f_node: + new_output_anchor = new_node + # pyrefly: ignore [unbound-name] + if new_input_anchor is not new_output_anchor: # subgraph is non-trivial + # update the input node + # pyrefly: ignore [unbound-name] + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + # pyrefly: ignore [unbound-name] + new_input_node.meta.update(new_input_anchor.meta) + # pyrefly: ignore [unbound-name] + new_input_anchor.replace_all_uses_with(new_input_node) + # pyrefly: ignore [unbound-name] + new_graph.erase_node(new_input_anchor) + # add the output node + # pyrefly: ignore [unbound-name] + new_output_node = new_graph.output(new_output_anchor) + # pyrefly: ignore [unbound-name] + new_output_node.meta.update(new_output_anchor.meta) + else: # subgraph is trivial, e.g. (A @ (B @ C)) + # update the input node + # pyrefly: ignore [unbound-name] + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + # pyrefly: ignore [unbound-name] + new_input_node.meta.update(new_input_anchor.meta) + # pyrefly: ignore [unbound-name] + new_input_anchor.replace_all_uses_with(new_input_node) + # pyrefly: ignore [unbound-name] + new_graph.erase_node(new_input_anchor) + # update the output node (don't use new_output_anchor since it has been erased) + new_output_node = new_graph.output(new_input_node) + new_output_node.meta.update(new_input_node.meta) + new_graph.lint() + + # construct the subgraph + subgraph = Subgraph( + name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph) + ) + + # two cases + # (1) (subgraph(A @ B) @ C), called "left_assoc" + # (2) (A @ subgraph(B @ C)), called "right_assoc" + is_left_assoc = outer_mm.args[0] is f_node + + # find the nodes A, B, C and check the sizes + A: torch.fx.Node + B: torch.fx.Node + C: torch.fx.Node + if is_left_assoc: + A = inner_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[1] # type: ignore[assignment] + C = outer_mm.args[1] # type: ignore[assignment] + else: + A = outer_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[0] # type: ignore[assignment] + C = inner_mm.args[1] # type: ignore[assignment] + if not is_b2b_gemm_good_on(is_left_assoc, A, B, C): + return + + # finally update the original graph + counters["inductor"]["b2b_gemm"] += 1 + graph = match.graph + with graph.inserting_before(outer_mm): + function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph) + function.__name__ = tuned_b2b_gemm.__name__ # type: ignore[attr-defined] + function._inductor_lowering_function = True # type: ignore[attr-defined] + replacement: torch.fx.Node = graph.call_function( + function, + (A, B, C), + match.kwargs, + ) + replacement.meta.update(outer_mm.meta) + outer_mm.replace_all_uses_with(replacement) + # erase unnecessary nodes + graph.erase_node(outer_mm) + for node in reversed(subgraph_node_list): + graph.erase_node(node) + graph.lint() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9bce1a8a2d599da6e8fa1f9b5a9442d6cbb954 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py @@ -0,0 +1,503 @@ +# mypy: allow-untyped-defs +import functools +import itertools + +import torch + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import Arg, CallFunction, KeywordArg +from .freezing_patterns import register_binary_folding_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims + + +def mark_mixed_dtype(computation_node): + computation_node_dtype = computation_node.meta["val"].dtype + if computation_node_dtype not in (torch.float16, torch.bfloat16): + return + + if len(computation_node.users) != 1: + return + + computation_node_user = next(iter(computation_node.users.keys())) + if not isinstance(computation_node_user.meta["val"], torch.Tensor): + return + + if computation_node_user.meta["val"].dtype != torch.float32: + return + + while computation_node_user.target in _binary_ops: + if len(computation_node_user.users) != 1: + return + + computation_node_user = next(iter(computation_node_user.users.keys())) + + if computation_node_user.target != prims.convert_element_type.default: + return + + computation_node.meta["_allow_mixed_dtype_folding"] = computation_node_dtype + + +def mark_mixed_dtype_allowed_computation_ops(gm): + """ + Mark convolutions/linear which we will binary fold even with mixed precision constants. We constant fold in the higher precision + for better accuracy and then recover the original precision after. + """ + for target in [aten.convolution.default, aten.addmm.default, aten.mm.default]: + for node in gm.graph.find_nodes(op="call_function", target=target): + mark_mixed_dtype(node) + + +def recover_original_precision_folded_computation_ops(gm): + """ + After binary folding conv/linear weights and biases to a higher dtype, recover the original precision they were in. + """ + graph = gm.graph + for target, idx in ( + (aten.convolution.default, (1, 2)), + (aten.addmm.default, (0, 2)), + (aten.mm.default, (1,)), + ): + for node in graph.find_nodes(op="call_function", target=target): + orig_dtype = node.meta.get("_allow_mixed_dtype_folding", None) + if orig_dtype is None: + continue + + with graph.inserting_before(node): + for i in idx: + old_input = node.args[i] + if old_input is None: + continue + + new_input = graph.create_node( + "call_function", + prims.convert_element_type.default, + (old_input, orig_dtype), + ) + node.replace_input_with(old_input, new_input) + + +_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] + + +@functools.cache +def binary_folding_init(): + _conv_args = [Arg() for _ in range(9)] + _addmm_args = [Arg() for _ in range(3)] + _mm_args = [Arg() for _ in range(2)] + _computation_ops = [aten.convolution.default, aten.addmm.default, aten.mm.default] + _computation_calls = [ + CallFunction(aten.convolution.default, *_conv_args, _users=1), + CallFunction(aten.addmm.default, *_addmm_args, _users=1), + CallFunction( + aten.reshape.default, + CallFunction(aten.addmm.default, *_addmm_args, _users=1), + Arg(), + _users=1, + ), + CallFunction(aten.mm.default, *_mm_args, _users=1), + CallFunction( + aten.reshape.default, + CallFunction(aten.mm.default, *_mm_args, _users=1), + Arg(), + _users=1, + ), + ] + + """ + In order to fuse add/sub/mul/div with conv/linear, the dimensions of its + constant tensor must satisfy the following: + - with resizing, broadcast to w/ weight/bias tensor shape + - broadcast to the conv/linear output shape + It needs to have a shape that can resize to weight/bias + tensor shape because we need to run the op with the conv/linear + weights/bias without changing their sizes. + It needs to broadcast to the conv/linear output shape so that we do + accidentally change the shape of op output by pre-fusing it + compared to eager. + The only dimension value shared by weight, bias, and conv/linear output + is they all contain a dim with value = channels-out. In the + conv/linear output tensor, this is in the second dimension, + so the pointwise op tensor may have a second dimension of + value == channels-out, but all the other dimensions have to be 1 + """ + + def _op_not_broadcasting_with_conv(weight_tensor, other_tensor): + # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + if len(weight_shape) < len(other_shape): + return False + if len(weight_shape) == len(other_shape) + 1: + # weight shape is [o, i, *], other_shape is [o, 1...]. + for i in reversed(range(len(other_shape))): + if i == 0 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + else: + # weight shape is [o, i, *], other_shape is [1, i, *] + for i in reversed(range(len(other_shape))): + if i == 1 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + return True + + def _op_not_broadcasting_with_linear(weight_tensor, other_tensor, has_reshape): + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + other_shapes = [ + torch.Size( + [ + weight_shape[1], + ] + ), + torch.Size([1, weight_shape[1]]), + torch.Size( + [ + 1, + ] + ), + torch.Size([1, 1]), + ] + if has_reshape: + other_shapes.extend( + [ + torch.Size([1, 1, weight_shape[1]]), + torch.Size([1, 1, 1]), + ] + ) + return other_shape in other_shapes + + def _check_conv_and_broadcast_op(conv_node, other): + # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp. + # conv.weight + if conv_node.args[1].op != "get_attr": + return False + # conv.bias + if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if len(conv_node.args[1].users) != 1: + return False + + weight_meta_value = conv_node.args[1].meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + != weight_meta_value.dtype + ): + if not conv_node.meta.get("_allow_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value): + return False + elif not isinstance(other, float): + return False + + return True + + def _check_linear_and_broadcast_op(linear_node, other, has_reshape): + weight_node = ( + linear_node.args[2] + if linear_node.target is aten.addmm.default + else linear_node.args[1] + ) + bias_node = ( + linear_node.args[0] if linear_node.target is aten.addmm.default else None + ) + if weight_node.op != "get_attr": + return False + if bias_node is not None and bias_node.op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if len(weight_node.users) != 1: + return False + + weight_meta_value = weight_node.meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + != weight_meta_value.dtype + ): + if not linear_node.meta.get("_allow_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_linear( + weight_meta_value, other_meta_value, has_reshape + ): + return False + elif not isinstance(other, float): + return False + + return True + + def _is_foldable_pattern(match): + binary_node = match.output_node() + has_reshape = False + if binary_node.args[0].target in _computation_ops: + computation_node = binary_node.args[0] + other = binary_node.args[1] + elif binary_node.args[0].target is aten.reshape.default: + computation_node = binary_node.args[0].args[0] + other = binary_node.args[1] + has_reshape = True + elif binary_node.args[1].target in _computation_ops: + computation_node = binary_node.args[1] + other = binary_node.args[0] + else: + computation_node = binary_node.args[1].args[0] + other = binary_node.args[0] + has_reshape = False + if computation_node.target is aten.convolution.default: + return _check_conv_and_broadcast_op(computation_node, other) + elif computation_node.target in [aten.addmm.default, aten.mm.default]: + return ( + config.enable_linear_binary_folding + and _check_linear_and_broadcast_op(computation_node, other, has_reshape) + ) + + return False + + def resize_scalar_or_tensor_to_shape(graph, other, shape, weight): + if isinstance(other, float): + with torch.utils._python_dispatch._disable_current_modes(): + other_tensor = torch.tensor( + other, dtype=weight.dtype, device=weight.device + ) + graph.owning_module.register_buffer("other_tensor", other_tensor) + res = graph.create_node("get_attr", "other_tensor") + res = graph.create_node( + "call_function", + aten.reshape.default, + (res, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + elif other.meta.get("val").numel() == 1: + # expand errors if the shape input has less # dims than the tensor input + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + else: + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, shape), + ) + return res + + def _create_new_conv_node(graph, conv_node, binary_node, other): + assert conv_node.target is aten.convolution.default + conv_args = list(conv_node.args) + weight_meta_value = conv_node.args[1].meta.get("val") + bias = conv_args[2] + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(0),), + weight_meta_value, + ) + new_bias = graph.create_node( + "call_function", + binary_node.target, + (0 if bias is None else bias, other_reshape), + ) + conv_args[2] = new_bias + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))] + weight_broadcast_shape[0] = weight_meta_value.size(0) + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, + other, + tuple(weight_broadcast_shape), + weight_meta_value, + ) + new_weight = graph.create_node( + "call_function", binary_node.target, (conv_args[1], other_reshape1) + ) + new_weight.meta.update(conv_args[1].meta) + conv_args[1] = new_weight + if bias is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(0),), + weight_meta_value, + ) + new_bias = graph.create_node( + "call_function", binary_node.target, (bias, other_reshape) + ) + new_bias.meta.update(bias.meta) + conv_args[2] = new_bias + return graph.create_node("call_function", conv_node.target, tuple(conv_args)) + + def _create_new_linear_node(graph, linear_node, binary_node, other): + assert linear_node.target in [aten.addmm.default, aten.mm.default] + input_node = ( + linear_node.args[1] + if linear_node.target is aten.addmm.default + else linear_node.args[0] + ) + weight_node = ( + linear_node.args[2] + if linear_node.target is aten.addmm.default + else linear_node.args[1] + ) + bias_node = ( + linear_node.args[0] if linear_node.target is aten.addmm.default else None + ) + weight_meta_value = weight_node.meta.get("val") + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(1),), + weight_meta_value, + ) + new_bias_node = graph.create_node( + "call_function", + binary_node.target, + (0 if bias_node is None else bias_node, other_reshape), + ) + return graph.create_node( + "call_function", + aten.addmm.default, + (new_bias_node, input_node, weight_node), + ) + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1, weight_meta_value.size(1)] + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, + other, + tuple(weight_broadcast_shape), + weight_meta_value, + ) + new_weight_node = graph.create_node( + "call_function", binary_node.target, (weight_node, other_reshape1) + ) + new_weight_node.meta.update(weight_node.meta) + if bias_node is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(1),), + weight_meta_value, + ) + new_bias_node = graph.create_node( + "call_function", binary_node.target, (bias_node, other_reshape) + ) + new_bias_node.meta.update(bias_node.meta) + return graph.create_node( + "call_function", + linear_node.target, + (new_bias_node, input_node, new_weight_node), + ) + else: + return graph.create_node( + "call_function", linear_node.target, (input_node, new_weight_node) + ) + + for _computation_call, binary_op in itertools.product( + _computation_calls, _binary_ops + ): + + @register_binary_folding_pattern( + CallFunction(binary_op, _computation_call, KeywordArg("other")), + extra_check=_is_foldable_pattern, + ) + def folded_op(match, *args, **kwargs): + counters["inductor"]["binary_folding"] += 1 + other = kwargs.get("other") + binary_node = match.output_node() + reshape_node = None + if binary_node.args[0].target in _computation_ops: + computation_node = binary_node.args[0] + elif binary_node.args[0].target is aten.reshape.default: + computation_node = binary_node.args[0].args[0] + reshape_node = binary_node.args[0] + elif binary_node.args[1].target in _computation_ops: + computation_node = binary_node.args[1] + else: + computation_node = binary_node.args[1].args[0] + reshape_node = binary_node.args[1] + graph = match.graph + with graph.inserting_before(reshape_node if reshape_node else binary_node): + assert computation_node.target in _computation_ops + if computation_node.target is aten.convolution.default: + counters["inductor"]["binary_folding_conv"] += 1 + new_computation_node = _create_new_conv_node( + graph, computation_node, binary_node, other + ) + else: + new_computation_node = _create_new_linear_node( + graph, computation_node, binary_node, other + ) + new_computation_node.meta.update(computation_node.meta) + if reshape_node: + assert reshape_node.target is aten.reshape.default + computation_node.replace_all_uses_with(new_computation_node) + binary_node.replace_all_uses_with(reshape_node) + else: + binary_node.replace_all_uses_with(new_computation_node) + graph.erase_node(binary_node) + graph.erase_node(computation_node) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/bucketing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/bucketing.py new file mode 100644 index 0000000000000000000000000000000000000000..e72cdccddb44010f316cad92d8e10e1d13af6400 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/bucketing.py @@ -0,0 +1,1100 @@ +import collections +import logging +import operator +from collections import defaultdict +from collections.abc import Callable +from typing import Any, Literal, TypeAlias + +import torch +import torch.distributed as dist +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import detect_fake_mode +from torch._inductor.comm_analysis import ( + get_collective_type_from_kernel_name, + NCCL_COLL, +) +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._logging import trace_structured +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.traceback import NodeSource, NodeSourceAction +from torch.utils._ordered_set import OrderedSet + + +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + +BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"] + + +# Helper functions moved to top for better organization +def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]: # type: ignore[name-defined] + _, group_size, group_name = node.args + dtype = node.meta["val"].dtype + assert isinstance(group_name, str) + return (group_name, dtype) + + +def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]: + _, group_size, group_name = node.args + assert isinstance(group_name, str) + return (group_name,) + + +def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: # type: ignore[name-defined] + _, reduce_op, group_size, group_name = node.args + dtype = node.meta["val"].dtype + assert isinstance(group_name, str) + assert isinstance(reduce_op, str) + return (group_name, reduce_op, dtype) + + +def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: + _, reduce_op, group_name = node.args + dtype = node.meta["val"].dtype + assert isinstance(group_name, str) + assert isinstance(reduce_op, str) + return (group_name, reduce_op, dtype) + + +def _schedulable_wait_node(node: torch.fx.Node) -> bool: + """ + Add additional check on if the wait node is schedulable + We should not schedule a fx node that is: + 1. wait on a collective that is not callable + 2. wait on a non-NCCL communication node + """ + if not is_wait_tensor(node): + return False + assert isinstance(node.args[0], torch.fx.Node) + if not isinstance(node.args[0].target, Callable): + return False + is_callable: bool = node.args[0].op == "call_function" + coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name()) + is_collective: bool = coll != NCCL_COLL.UNSUPPORTED + return is_callable and is_collective + + +def _populate_node_meta( + bucket_nodes: list[torch.fx.Node], new_nodes: list[torch.fx.Node] +): + if bucket_nodes: + for n in new_nodes: + # For the following keys, we only store the information of the first node so + # gm.print_readable shows some information + # Full information are stored in "bucketing_{key}_sources" + for key, default in [ + ("nn_module_stack", ""), + ("fwd_nn_module_stack", ""), + ("stack_trace", ""), + ("custom", {}), + ]: + n.meta[key] = bucket_nodes[0].meta.get(key, default) + + # Collect sources from all bucket nodes for this metadata key, for debugging purposes only + bucketing_sources_key = f"bucketing_{key}_sources" + # Use set to remove duplicates + if key == "stack_trace": + sources = OrderedSet( + [ + node.meta.get(key, default) + for node in bucket_nodes + if node.meta.get(key, default) + ] + ) + else: + # type might not be hashable + sources = [ + node.meta.get(key, default) + for node in bucket_nodes + if node.meta.get(key, default) + ] + n.meta[bucketing_sources_key] = sources + + # used by inductor provenance tracking + n.meta["from_node"] = [ + NodeSource( + original_node, + "bucketing_pass", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + for original_node in bucket_nodes + ] + + +def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None: + if is_all_gather_into_tensor(node): + group_key_fn = ( + _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key + ) + return group_key_fn(node) + elif is_reduce_scatter_tensor(node): + return _rs_group_key(node) + elif is_all_reduce_tensor(node): + return _ar_group_key(node) + else: + return None + + +def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype: # type: ignore[name-defined] + assert len(dtypes) > 0 + return min(dtypes, key=operator.attrgetter("itemsize")) + + +def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float: + """ + Determine the size of a bucket based on its ID. + + Args: + bucket_id (int): The ID of the bucket. + + Returns: + float: The size of the bucket. + """ + return 2000.0 + + +def bucket_all_gather( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, + mode: BucketMode = "default", +) -> None: + if bucket_cap_mb_by_bucket_idx is None: + from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute + bucket_cap_mb_by_bucket_idx_default, + ) + + bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default + ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode) + if len(ag_buckets) == 0: + return + merge_all_gather(gm, ag_buckets, mode) + + +def bucket_reduce_scatter( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, + mode: BucketMode = "default", +) -> None: + if bucket_cap_mb_by_bucket_idx is None: + from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute + bucket_cap_mb_by_bucket_idx_default, + ) + + bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default + rs_buckets = bucket_reduce_scatter_by_mb( + gm, bucket_cap_mb_by_bucket_idx, None, mode + ) + if len(rs_buckets) == 0: + return + merge_reduce_scatter(gm, rs_buckets, mode) + + +def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type] + return node.op == "call_function" and ( + node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default + ) + + +def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default + ) + + +def is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target is torch.ops._c10d_functional.wait_tensor.default + ) + + +def is_all_reduce_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target is torch.ops._c10d_functional.all_reduce.default + ) + + +def is_all_to_all_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target is torch.ops._c10d_functional.all_to_all_single.default + ) + + +def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool: + return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type] + + +def collect_node_descendants( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: + """ + Collects the descendants of each node in the graph. + Args: + graph (torch.fx.Graph): The graph to collect descendants from. + Returns: + dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants. + """ + node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = ( + collections.defaultdict(OrderedSet) + ) + outdegree = collections.defaultdict(int) + queue = [] + + for node in graph.nodes: + n_outdegree = len(node.users) + if n_outdegree == 0: + queue.append(node) + else: + outdegree[node] = len(node.users) + + while queue: + node = queue.pop() + for input_node in node.all_input_nodes: + node_descendants[input_node] |= node_descendants[node] + node_descendants[input_node].add(node) + outdegree[input_node] -= 1 + + if outdegree[input_node] == 0: + queue.append(input_node) + + return node_descendants + + +def greedy_bucket_collective_by_mb( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float], + filter_node: Callable[[torch.fx.Node], bool], + node_group_key: Callable[[torch.fx.Node], Any], + filter_wait_node: Callable[[torch.fx.Node], bool] | None = None, +) -> list[list[torch.fx.Node]]: + """ + Bucketing adjacent collectives with equal node_group_key. + We can not bucket non adjacent collectives, + as this will effectively change the order of collectives. + Reordering can lead to different order on different ranks. + """ + g = gm.graph + found_candidates = False + for node in g.nodes: + if filter_node(node): + found_candidates = True + break + if not found_candidates: + return [] + + # TODO: pearce kelly algorithm for detecting cycles + node_descendents = collect_node_descendants(gm.graph) + + nodes_groups: list[list[torch.fx.Node]] = [] + cur_group: list[torch.fx.Node] = [] + cur_group_key = None + + for node in g.nodes: + if is_wait_tensor(node) and filter_node(node.args[0]): + if (filter_wait_node is None) or filter_wait_node(node): + coll_node = node.args[0] + group_key = node_group_key(coll_node) + if group_key == cur_group_key: + cur_group.append(coll_node) + else: + if len(cur_group) > 1: + nodes_groups.append(cur_group) + cur_group = [coll_node] + cur_group_key = group_key + + if len(cur_group) > 1: + nodes_groups.append(cur_group) + + buckets: list[list[torch.fx.Node]] = [] + for nodes in nodes_groups: + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + bucket_size_bytes = int( + bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024 + ) + for node in nodes: + if node in cur_bucket_descendents: + # if there is a path from node to the current bucket, we cannot horizontally fuse (bucket) + continue + assert "val" in node.meta + n_val = node.meta["val"] + out_size_bytes = n_val.numel() * n_val.element_size() + n_input_val = node.all_input_nodes[0].meta["val"] + in_size_bytes = n_input_val.numel() * n_input_val.element_size() + size_bytes = max(out_size_bytes, in_size_bytes) + if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket: + # Current bucket is full, create new bucket + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_descendents = OrderedSet() + cur_bucket_size_bytes += size_bytes + cur_bucket.append(node) + cur_bucket_descendents |= node_descendents[node] + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + return buckets + + +def bucket_all_gather_by_mb( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float], + filter_wait_node: Callable[[torch.fx.Node], bool] | None = None, + mode: BucketMode = "default", +) -> list[list[torch.fx.Node]]: + """ + Identifies all all_gather nodes and groups them into buckets, + based on size limit `bucket_cap_mb_by_bucket_idx`. + + Args: + gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers. + bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket + in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow + to specify different sizes of the buckets at the start, + as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx + is `bucket_cap_mb_by_bucket_idx_default` function that is default value for `bucket_cap_mb_by_bucket_idx`. + filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified, + only all_gather nodes with wait_node that satisfy `filter_wait_node` will be bucketed. + + Returns: + list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes. + """ + + group_key_fn = ( + _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key + ) + + return greedy_bucket_collective_by_mb( + gm, + bucket_cap_mb_by_bucket_idx, + is_all_gather_into_tensor, + group_key_fn, + filter_wait_node, + ) + + +def bucket_reduce_scatter_by_mb( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float], + filter_wait_node: Callable[[torch.fx.Node], bool] | None = None, + mode: BucketMode = "default", +) -> list[list[torch.fx.Node]]: + """ + Identifies all reduce_scatter nodes and groups them into buckets, + based on size limit `bucket_cap_mb_by_bucket_idx`. + + Args: + gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters. + bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket + in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow + to specify different sizes of the buckets. + filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified, + only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed. + + Returns: + list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes. + """ + + assert "multidtype" not in mode, ( + "reduce scatter bucketing does not support multidtype" + ) + + return greedy_bucket_collective_by_mb( + gm, + bucket_cap_mb_by_bucket_idx, + is_reduce_scatter_tensor, + _rs_group_key, + filter_wait_node, + ) + + +def bucket_all_reduce_by_mb( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float], + filter_wait_node: Callable[[torch.fx.Node], bool] | None = None, +) -> list[list[torch.fx.Node]]: + return greedy_bucket_collective_by_mb( + gm, + bucket_cap_mb_by_bucket_idx, + is_all_reduce_tensor, + _ar_group_key, + filter_wait_node, + ) + + +def bucket_all_reduce( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, + mode: str | None = None, +) -> None: + if bucket_cap_mb_by_bucket_idx is None: + from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute + bucket_cap_mb_by_bucket_idx_default, + ) + + bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default + ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx) + if len(ar_buckets) == 0: + return + for bucket in ar_buckets: + merge_all_reduce_bucket(gm.graph, bucket, mode) + + +@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={}) +def _pre_bucket_reduce_scatter( + rs_ins: list[torch.Tensor], + group_size: int, +) -> torch.Tensor: + rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins] + new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten() + return new_rs_in + + +def _pre_bucket_reduce_scatter_fake( + rs_ins: list[torch.Tensor], + group_size: int, +) -> torch.Tensor: + out_numel = sum(rs_in.numel() for rs_in in rs_ins) + return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype) + + +_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake) + + +def reduce_scatter_merge_fn_to_trace_custom_ops( + rs_ins: list[torch.Tensor], + group_size: int, + group_name: str, + reduce_op: str, + reduce_dtype: torch.dtype, # type: ignore[name-defined] + device: torch.device, # type: ignore[name-defined] +) -> list[torch.Tensor]: # type: ignore[no-untyped-def] + new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins] + new_out_numels = [x.numel() // group_size for x in rs_ins] + + new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size) + + # TODO - either use torch.cat or make sure inductor foreach codegen + # fires more reliably + new_rs_out = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.reduce_scatter_tensor.default( + new_rs_in, reduce_op, group_size, group_name + ) + ) + new_out_flat = new_rs_out.split(new_out_numels, 0) + new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)] + return new_outs + + +def reduce_scatter_merge_fn_to_trace( + rs_ins: list[torch.Tensor], + group_size: int, + group_name: str, + reduce_op: str, + reduce_dtype: torch.dtype, # type: ignore[name-defined] + device: torch.device, # type: ignore[name-defined] +) -> list[torch.Tensor]: # type: ignore[no-untyped-def] + rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins] + + new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins] + new_out_numels = [x.numel() // group_size for x in rs_ins] + + new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten() + + new_rs_out = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.reduce_scatter_tensor.default( + new_rs_in, reduce_op, group_size, group_name + ) + ) + new_out_flat = new_rs_out.split(new_out_numels, 0) + new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)] + return new_outs + + +def all_reduce_merge_fn_to_trace( + ar_ins: list[torch.Tensor], + group_name: str, + reduce_op: str, + reduce_dtype: torch.dtype, # type: ignore[name-defined] + device: torch.device, # type: ignore[name-defined] +) -> list[torch.Tensor]: # type: ignore[no-untyped-def] + ar_ins_flattened = [x.view(-1) for x in ar_ins] + new_ar_in = torch.cat(ar_ins_flattened) + new_ar_out = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name) + ) + split_sizes = [x.numel() for x in ar_ins] + new_outs_flat = new_ar_out.split(split_sizes) + new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)] + return new_outs + + +# List of all torch dtypes for serialization through custom ops +# TODO: custom ops support list[dtype] input +_ALL_DTYPES = tuple( + [ + getattr(torch, attr) + for attr in dir(torch) + if isinstance(getattr(torch, attr), torch.dtype) + ] +) + + +@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={}) +def _pre_bucket_all_gather( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[ + int + ], # dtype enum values, that inputs are converted to before all_gather + rank: int, +) -> torch.Tensor: + # Convert int indices back to torch.dtype + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] + bucket_dtype_size_bytes = dtype.itemsize + ins_split_sizes = [ + _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes + ] + ag_input_numel = sum(ins_split_sizes) + device = ag_ins[0].device + new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) + new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) + foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) + # View each destination slice as its output dtype, then copy + # The copy operation handles dtype conversion from input dtype to output dtype + foreach_copy_dsts_typed = [ + dst.view(out_dtype) + for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True) + ] + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + torch._foreach_copy_(foreach_copy_dsts_typed, ag_ins_flattened) + return new_ag_out + + +def _pre_bucket_all_gather_fake( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[int], + rank: int, +) -> torch.Tensor: + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] + bucket_dtype_size_bytes = dtype.itemsize + ins_split_sizes = [ + _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes + ] + ag_input_numel = sum(ins_split_sizes) + device = ag_ins[0].device + new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) + return new_ag_out + + +_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake) + + +def all_gather_merge_fn_to_trace_custom_ops( + _ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + out_dtypes: list[torch.dtype], # type: ignore[name-defined] + rank: int, +) -> list[torch.Tensor]: + # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion + # by viewing destination slices as output dtypes and letting copy do the conversion + ag_ins = _ag_ins + ins_sizes = [ag_in.shape for ag_in in ag_ins] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes) + ] + bucket_dtype_size_bytes = dtype.itemsize + ins_split_sizes = [ + _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes + ] + ag_input_numel = sum(ins_split_sizes) + + # Convert out_dtypes to indices for custom_op + # TODO: custom ops support list[dtype] input + out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes] + + new_ag_out = torch.ops.bucketing._pre_bucket_all_gather( + ag_ins, group_size, group_name, dtype, out_dtype_ints, rank + ) + new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) + wait_tensor = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + new_ag_in, group_size, group_name, out=new_ag_out + ) + ) + new_ag_out_reshaped = wait_tensor.reshape(group_size, -1) + outs_bucket_dtype = torch.split_with_sizes( + new_ag_out_reshaped, + ins_split_sizes, + dim=1, + ) + outs_reshaped = [ + o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:]) + for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes) + ] + return outs_reshaped + + +def all_gather_merge_fn_to_trace( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + out_dtypes: list[torch.dtype], # type: ignore[name-defined] + rank: int, +) -> list[torch.Tensor]: + ins_sizes = [ag_in.shape for ag_in in ag_ins] + ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ag_input_numel = sum(ins_split_sizes) + device = ag_ins[0].device + new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) + new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) + foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) + wait_tensor = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + new_ag_in, group_size, group_name, out=new_ag_out + ) + ) + new_ag_out_reshaped = wait_tensor.reshape(group_size, -1) + outs = torch.split_with_sizes( + new_ag_out_reshaped, + ins_split_sizes, + dim=1, + ) + outs_reshaped = [ + o.reshape((shape[0] * group_size,) + shape[1:]) + for o, shape in zip(outs, ins_sizes) + ] + return outs_reshaped + + +def all_gather_merge_fn_to_trace_functional( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + out_dtypes: list[torch.dtype], # type: ignore[name-defined] + rank: int, + use_fsdp_ag_copy_in: bool = False, +) -> list[torch.Tensor]: + # Implementation that is functional in graph, + # but uses custom op torch.ops.fsdp.all_gather_copy_in. + ins_sizes = [ag_in.shape for ag_in in ag_ins] + ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ag_input_numel = sum(ins_split_sizes) + device = ag_ins[0].device + new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + if use_fsdp_ag_copy_in: + new_ag_in, new_ag_out = torch.ops.fsdp.all_gather_copy_in( + ag_ins_flattened, new_ag_out, ins_split_sizes, ag_input_numel, rank + ) + else: + new_ag_in = torch.cat(ag_ins_flattened, dim=0) + wait_tensor = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + new_ag_in, group_size, group_name, out=new_ag_out + ) + ) + new_ag_out_reshaped = wait_tensor.reshape(group_size, -1) + outs = torch.split_with_sizes( + new_ag_out_reshaped, + ins_split_sizes, + dim=1, + ) + outs_reshaped = [ + o.reshape((shape[0] * group_size,) + shape[1:]) + for o, shape in zip(outs, ins_sizes) + ] + return outs_reshaped + + +def _trace(fn, inps) -> torch.fx.GraphModule: # type: ignore[no-untyped-def] + with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True): + fake_mode = detect_fake_mode(inps) + assert fake_mode is not None + with fake_mode, enable_python_dispatcher(): + out = make_fx(fn)(*inps) + for node in out.graph.find_nodes( + op="call_function", target=torch.ops.aten.detach.default + ): + node.replace_all_uses_with(node.args[0]) + out.graph.erase_node(node) + return out + + +def _insert_fn_trace_before_node( # type: ignore[no-untyped-def] + g: torch.fx.Graph, + fn_to_trace, + inps, + insert_before_node: torch.fx.Node, + g_fn_inps: list[torch.fx.Node], + g_fn_outs: list[torch.fx.Node], +) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]: # type: ignore[no-untyped-def] + """ + Helper function that traces :attr:`fn_to_trace` with inputs + :attr:`inps`. + The result function graph will be inserted before :attr:`insert_before_node`, + using :attr:`g_fn_inps` nodes of original graph as inputs of function graph, + function graph outputs will replace :attr:`g_fn_outs` in original graph. + + Returns: + (replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes + """ + with dynamo_timed( + "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True + ): + fn_gm = _trace( + fn_to_trace, + inps, + ) + fn_g = fn_gm.graph + fn_g_ins = fn_g.find_nodes(op="placeholder") + env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))} + g_fn_new_outs: list[torch.fx.Node] = [] + new_nodes: list[torch.fx.Node] = [] # Track all newly inserted nodes + + with g.inserting_before(insert_before_node): + for _n in fn_g.nodes: + if _n.op == "placeholder": + continue + _new_n = g.node_copy(_n, lambda x: env[x]) + env[_n] = _new_n + if _n.op == "output": + g_fn_new_outs = _new_n.args[0] # type: ignore[assignment] + g.erase_node(_new_n) + else: + new_nodes.append(_new_n) # Track non-output nodes + + replacements = { # noqa: C416 + orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs) + } + for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs): + orig_out.replace_all_uses_with(new_out) + + return replacements, new_nodes + + +def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool: + node_in = n.args[0] + return ( + is_all_gather_into_tensor(n) + and isinstance(node_in, torch.fx.Node) + and node_in.op == "call_function" + and ( + node_in.target is torch.ops.prims.convert_element_type.default + or node_in.target is torch.ops.aten._to_copy.default + ) + and len(node_in.users) == 1 + ) + + +def process_collective_bucket( + g: torch.fx.Graph, + bucket_nodes: list[torch.fx.Node], + fn_to_trace: Callable[..., list[torch.Tensor]], + trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]], + insert_before: torch.fx.Node | None = None, + wait_insertion_point: torch.fx.Node | None = None, +) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]: + """ + Process a single bucket of collective operation nodes with flexible insertion control. + + Args: + g: The graph to modify + bucket_nodes: Nodes in the current bucket to process + fn_to_trace: Function to trace and insert + trace_args_fn: Function to create trace arguments from inputs + insert_before: Where to insert the traced function (default: after last bucket node) + wait_insertion_point: If provided, move all nodes from wait() onwards to before this node + + Returns: + new_nodes: List of all newly inserted nodes + replacements: Dictionary mapping old wait nodes to new output nodes + """ + # Collect inputs and waits from current bucket + bucket_ins: list[torch.fx.Node] = [] + bucket_waits: list[torch.fx.Node] = [] + ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list) + + for n in bucket_nodes: + assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}" + wait_n = next(iter(n.users)) + + # Handle convert_element_type operations (for all_gather) + node_in = n.args[0] + if has_mergeable_all_gather_convert_dtype(n): + ag_node_to_pre_nodes[n].append(node_in) + node_in = node_in.args[0] + + assert isinstance(node_in, torch.fx.Node) # Ensure node_in is a Node + bucket_ins.append(node_in) + bucket_waits.append(wait_n) + + # Create trace arguments + trace_args = trace_args_fn(bucket_ins) + + # Determine insertion point + if insert_before is None: + insert_before = bucket_nodes[-1].next + + # Insert traced function and get replacements + new nodes + replacements, new_nodes = _insert_fn_trace_before_node( + g, + fn_to_trace, + trace_args, + insert_before, + bucket_ins, + bucket_waits, + ) + + # If requested, move wait nodes and everything after to specified location + if wait_insertion_point is not None: + # Find the first wait node in new_nodes + wait_start_idx = None + for i, node in enumerate(new_nodes): + if is_wait_tensor(node): + wait_start_idx = i + break + + # Move all nodes from wait onwards (including the wait) + if wait_start_idx is not None: + nodes_to_move = new_nodes[wait_start_idx:] + for node in nodes_to_move: + wait_insertion_point.prepend(node) + + # Preserve metadata from original collective nodes to new bucketed nodes + if bucket_nodes: + overlap_log.debug( + "Bucketing nodes: %s, New nodes: %s", + ",".join([n.name for n in bucket_nodes]), + ",".join([n.name for n in new_nodes]), + ) + _populate_node_meta(bucket_nodes, new_nodes) + + # Erase old nodes + for node, wait_n in zip(bucket_nodes, bucket_waits): + g.erase_node(wait_n) + g.erase_node(node) + # Erase any convert_element_type nodes we tracked + for pre_node in reversed(ag_node_to_pre_nodes[node]): + g.erase_node(pre_node) + + return new_nodes, replacements + + +def merge_reduce_scatter_bucket( + g: torch.fx.Graph, + rs_nodes: list[torch.fx.Node], + mode: BucketMode = "default", + insert_before: torch.fx.Node | None = None, + wait_insertion_point: torch.fx.Node | None = None, +) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]: + # Validate bucket consistency + rs0 = rs_nodes[0] + rs0_val = rs0.meta["val"] + _, reduce_op, group_size, group_name = rs0.args + reduce_dtype = rs0_val.dtype + device = rs0_val.device + + for n in rs_nodes: + rs_val = n.meta["val"] + assert ( + n.args[1] == reduce_op + and n.args[2] == group_size + and n.args[3] == group_name + and rs_val.device == device + and rs_val.dtype == reduce_dtype + ) + + # Choose merge function based on mode + rs_merge_fn = reduce_scatter_merge_fn_to_trace + if mode and "custom_ops" in mode: + rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops + + # Process bucket with lazy input collection + def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]: + return ( + pytree.tree_map(lambda node: node.meta["val"], bucket_ins), + group_size, + group_name, + reduce_op, + reduce_dtype, + device, + ) + + return process_collective_bucket( + g, + rs_nodes, + rs_merge_fn, + create_trace_args, + insert_before=insert_before, + wait_insertion_point=wait_insertion_point, + ) + + +def merge_all_reduce_bucket( + g: torch.fx.Graph, + ar_nodes: list[torch.fx.Node], + mode: str | None = None, + insert_before: torch.fx.Node | None = None, + wait_insertion_point: torch.fx.Node | None = None, +) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]: + ar0 = ar_nodes[0] + ar0_val = ar0.meta["val"] + _, reduce_op, group_name = ar0.args + reduce_dtype = ar0_val.dtype + device = ar0_val.device + + for n in ar_nodes: + ar_val = n.meta["val"] + assert ( + n.args[1] == reduce_op + and n.args[2] == group_name + and ar_val.device == device + and ar_val.dtype == reduce_dtype + ) + + ar_merge_fn = all_reduce_merge_fn_to_trace + + def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]: + return ( + pytree.tree_map(lambda node: node.meta["val"], bucket_ins), + group_name, + reduce_op, + reduce_dtype, + device, + ) + + return process_collective_bucket( + g, + ar_nodes, + ar_merge_fn, + create_trace_args, + insert_before=insert_before, + wait_insertion_point=wait_insertion_point, + ) + + +def merge_all_gather_bucket( + g: torch.fx.Graph, + ag_nodes: list[torch.fx.Node], + mode: BucketMode = "default", + insert_before: torch.fx.Node | None = None, + wait_insertion_point: torch.fx.Node | None = None, +) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]: + from torch.distributed.distributed_c10d import _resolve_process_group + + ag0 = ag_nodes[0] + _, group_size, group_name = ag0.args + assert isinstance(group_name, str) + _ag_dtypes: list[torch.dtype] = [] # type: ignore[name-defined] + + for n in ag_nodes: + assert n.args[1] == group_size and n.args[2] == group_name + _ag_dtypes.append(n.meta["val"].dtype) + + bucket_dtype = pick_bucket_dtype(_ag_dtypes) + + # Choose merge function based on mode + ag_merge_fn = all_gather_merge_fn_to_trace + if mode is not None and "custom_ops" in mode: + ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops # type: ignore[assignment] + + # Process bucket with lazy input collection + rank: int = dist.get_rank(_resolve_process_group(group_name)) + + def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]: + return ( + pytree.tree_map(lambda node: node.meta["val"], bucket_ins), + group_size, + group_name, + bucket_dtype, + _ag_dtypes, + rank, + ) + + return process_collective_bucket( + g, + ag_nodes, + ag_merge_fn, + create_trace_args, + wait_insertion_point=wait_insertion_point, + ) + + +def merge_reduce_scatter( + gm: torch.fx.GraphModule, + rs_buckets: list[list[torch.fx.Node]], + mode: BucketMode = "default", +) -> None: + """ + Merges specified buckets of reduce_scatter to joint reduce_scatter. + """ + with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_bucketing_passes_reduce_scatter_buckets", + "encoding": "string", + }, + payload_fn=lambda: str(rs_buckets), + ) + + g = gm.graph + + for rs_nodes in rs_buckets: + merge_reduce_scatter_bucket(g, rs_nodes, mode) + + +def merge_all_gather( + gm: torch.fx.GraphModule, + ag_buckets: list[list[torch.fx.Node]], + mode: BucketMode = "default", +) -> None: + """ + Merges specified buckets of all_gather to joint all_gather. + """ + with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_bucketing_passes_all_gather_buckets", + "encoding": "string", + }, + payload_fn=lambda: str(ag_buckets), + ) + + g = gm.graph + + for ag_nodes in ag_buckets: + merge_all_gather_bucket(g, ag_nodes, mode) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/control_dependencies.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/control_dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e3ca625c5d97bcd0e52508ed084f5bf82b2bb2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/control_dependencies.py @@ -0,0 +1,226 @@ +# mypy: allow-untyped-defs +""" +Effect ordering pass for inductor. + +This pass adds ordering dependencies to FX graphs using the control_deps HOP +for precise control over scheduling constraints. When you need exact ordering between +operations (e.g., collective_start -> mm -> wait), this pass wraps operations +with control_deps to make dependencies explicit. +""" + +from typing import Any + +import torch.fx as fx +from torch._higher_order_ops.utils import register_fake +from torch._ops import HigherOrderOperator +from torch.utils._ordered_set import OrderedSet + + +class ControlDeps(HigherOrderOperator): + """ + Higher-order operator that enforces ordering by making dependencies explicit. + + Schema: control_deps(additional_deps, target, *args, **kwargs) -> result + where: + - additional_deps: tuple of tensors that must be computed before this op + - subgraph: GraphModule containing the exact operation to execute + - args/kwargs: arguments for the target function + + This ensures all tensors in additional_deps are computed before the target + executes, creating explicit scheduling dependencies. + """ + + def __init__(self) -> None: + super().__init__("control_deps") + + def __call__(self, additional_deps, subgraph, *args, **kwargs): + """Call the operator with dependencies and subgraph. + + Args: + additional_deps: Tuple of tensors that must be computed first + subgraph: GraphModule containing the exact operation to execute + *args: Arguments to pass to the subgraph + """ + if not isinstance(additional_deps, (tuple, list)): + raise TypeError( + f"additional_deps must be tuple/list, got {type(additional_deps).__name__}" + ) + if not (isinstance(subgraph, fx.GraphModule) or callable(subgraph)): + raise TypeError( + f"subgraph must be GraphModule or callable, got {type(subgraph).__name__}" + ) + return super().__call__(additional_deps, subgraph, *args, **kwargs) + + +control_deps = ControlDeps() + + +# Register fake implementation for tracing +@register_fake(control_deps) +def _(additional_deps, subgraph, *args, **kwargs): + """Fake tensor implementation - execute the subgraph.""" + return subgraph(*args, **kwargs) + + +def get_subgraph_name(gm: fx.GraphModule, name): + name = f"subgraph_{name}" + + if not hasattr(gm, name): + return name + + i = 0 + while hasattr(gm, f"{name}_{i}"): + i += 1 + + return f"{name}_{i}" + + +def preserve_node_ordering( + graph: fx.Graph, + additional_deps_map: dict[fx.Node, OrderedSet[fx.Node]], + verbose: bool = False, +) -> None: + """ + Preserve node ordering using control_deps HOP with subgraph. + + This function wraps operations with control_deps that: + 1. Makes additional dependencies explicit (first argument) + 2. Creates a subgraph internally to preserve the exact original operation + 3. Preserves the original node names + + Args: + graph: The FX graph to modify + additional_deps_map: Mapping from dependent nodes to their dependencies + verbose: If True, print debug information + """ + if not additional_deps_map: + return + + # Track replacements so we can update dependencies + replacements: dict[fx.Node, fx.Node] = {} + + # Process each node that needs additional dependencies + for dependent_node, dep_nodes in additional_deps_map.items(): + assert dependent_node.op == "call_function", dependent_node.op + + original_name = dependent_node.name + original_args = dependent_node.args + original_kwargs = dependent_node.kwargs + original_meta = dependent_node.meta.copy() + + updated_dep_nodes = [replacements.get(dep, dep) for dep in dep_nodes] + + # Create a subgraph that preserves the exact original operation + subgraph_module = _create_subgraph_for_node(graph, dependent_node) + + owning_mod = graph.owning_module + assert owning_mod is not None + subgraph_attr_name = get_subgraph_name(owning_mod, original_name) + setattr(graph.owning_module, subgraph_attr_name, subgraph_module) + + # Create control_deps call with: + # 1. Additional dependencies as first arg (explicit) + # 2. Subgraph via get_attr (like b2b gemm pass) + # 3. Original arguments (only fx.Node args and kwargs are passed) + with graph.inserting_before(dependent_node): + # Create get_attr node for the subgraph + get_subgraph = graph.get_attr(subgraph_attr_name) + + # add additional args + node_args = [a for a in original_args if isinstance(a, fx.Node)] + for value in original_kwargs.values(): + if isinstance(value, fx.Node): + node_args.append(value) + + # Create with temporary name first + ordered_node = graph.call_function( + control_deps, + args=( + tuple(updated_dep_nodes), # additional_deps + get_subgraph, # subgraph via get_attr (like b2b gemm) + *node_args, # original node arguments (from both args and kwargs) + ), + kwargs={}, + name=f"__temp_{original_name}", # Temporary name to avoid conflict + ) + + # Copy metadata from original node + ordered_node.meta = original_meta + # this will be constrained on the target node in subgraph if it exists + ordered_node.meta.pop("eager_input_vals", None) + + # Replace all uses of the original node with the ordered version + dependent_node.replace_all_uses_with(ordered_node) + + # Remove the original node from the graph + graph.erase_node(dependent_node) + + # Now rename the ordered node to the original name + ordered_node.name = original_name # PRESERVE ORIGINAL NAME + + # Track the replacement for future dependencies + replacements[dependent_node] = ordered_node + + +def _create_subgraph_for_node(graph: fx.Graph, node: fx.Node) -> fx.GraphModule: + """ + Create a subgraph that exactly recreates a node's operation. + + The subgraph takes only the fx.Node arguments and recreates the operation + with the exact target, args structure, and kwargs. + + Args: + graph: The parent graph + node: The node to wrap in a subgraph + + Returns: + A GraphModule containing the subgraph + """ + # Get the owning module + # torch.distributed.breakpoint(0) + owning_module = graph.owning_module + + # Create a new graph for the subgraph + subgraph = fx.Graph(owning_module) + + new_args: list[Any] = [] + placeholder_idx = 0 + for _, arg in enumerate(node.args): + if not isinstance(arg, fx.Node): + new_args.append(arg) + continue + + placeholder = subgraph.placeholder(f"arg_{placeholder_idx}") + placeholder_idx += 1 + if "val" in arg.meta: + placeholder.meta.update(arg.meta) + new_args.append(placeholder) # type: ignore[arg-type] + + new_kwargs: dict[str, Any] = {} + for key, value in node.kwargs.items(): + if not isinstance(value, fx.Node): + new_kwargs[key] = value + continue + + placeholder = subgraph.placeholder(f"kwarg_{key}") + if "val" in value.meta: + placeholder.meta.update(value.meta) + + new_kwargs[key] = placeholder # type: ignore[assignment] + + # Recreate the exact original operation in the subgraph + assert callable(node.target) + result = subgraph.call_function( + node.target, + tuple(new_args), + new_kwargs, # type: ignore[arg-type] + ) + + # Copy metadata from the original node + result.meta.update(node.meta) + + out = subgraph.output(result) + if "val" in result.meta: + out.meta["val"] = result.meta["val"] + + return fx.GraphModule(owning_module, subgraph) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..44314b912786f9537286108dc33c94905a5db0de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py @@ -0,0 +1,589 @@ +# Owner(s): ["oncall: distributed"] +import collections +import inspect +import logging +import math +import operator +from collections.abc import Callable, Generator +from dataclasses import dataclass +from functools import partial +from typing import Any, cast + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from ..fx_utils import get_fake_args_kwargs +from ..virtualized import V + + +aten = torch.ops.aten +logger: logging.Logger = logging.getLogger("comm_fusion") + + +def move_block_after(block: list[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.append(node) + target_node = node + + +def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.prepend(node) + target_node = node + + +def call_function( + graph: fx.Graph, + target: str | Callable[..., Any], + args: tuple[fx.node.Argument, ...] | None = None, + kwargs: dict[str, fx.node.Argument] | None = None, +) -> fx.Node: + # We accept target as a str to avoid typing error as the type of + # a node.target is str | Callable[..., Any]. + # This also allows us to avoid writing check for every call. + if isinstance(target, str): + raise RuntimeError(f"Call function should not get a str target {target=}") + node = graph.call_function(target, args, kwargs) + _, args, kwargs = get_fake_args_kwargs(node) + with V.fake_mode: + node.meta["val"] = target(*args, **kwargs) + # node.meta["val"] may be a container. So we use tree_map here + # to recursively extract the tensor metadata. + node.meta["tensor_meta"] = tree_map( + _extract_tensor_metadata, (node.meta["val"],) + )[0] + return node + + +@dataclass(unsafe_hash=True) +class CommBlock: + shape: torch.Size | list[torch.Size] + node_list: list[fx.Node] + inputs: list[fx.Node] + wait_nodes: list[fx.Node] + comm_node: fx.Node + outputs: OrderedSet[fx.Node] + + +def get_comm_block(comm_node: fx.Node) -> CommBlock | None: + """ + Given a collective node (e.g., allreduce), find out all the nodes belong to + this communication. + + Args: + comm_node(fx.Node): The target communication/collective node. + Returns: + The CommBlock that encapsulates the related nodes (e.g., wait_node) of + the given comm_node. + """ + node_list = [] + wait_nodes = [] + inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs)) + input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)] + # If the users of the wait node are following items, we consinder them + # to be a part of the output. + intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias") + + first_user = next(iter(comm_node.users)) + if ( + len(comm_node.users) == 1 + and first_user.target is torch.ops._c10d_functional.wait_tensor.default + ): + # Collective with only one output + node_list = [comm_node, first_user] + wait_nodes.append(first_user) + elif len(comm_node.users) > 1 and first_user.target is operator.getitem: + # Collective with only more than one output + node_list.append(comm_node) + for user in comm_node.users: + if user.target != operator.getitem: + return None + if len(user.users) != 1: + return None + wait_node = next(iter(user.users)) + if wait_node.target != torch.ops._c10d_functional.wait_tensor.default: + return None + wait_nodes.append(wait_node) + node_list.append(user) + node_list.extend(wait_nodes) + else: + return None + + # Identify all the outputs of this collective block. + outputs = OrderedSet[fx.Node]() + nodes = collections.deque(wait_nodes) + while nodes: + node = nodes.popleft() + for user in node.users: + if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs): + nodes.append(user) + node_list.append(user) + else: + outputs.add(node) + break + + tensor_meta = input_nodes[0].meta["tensor_meta"] + shape: torch.Size | list[torch.Size] + if isinstance(tensor_meta, TensorMetadata): + shape = tensor_meta.shape + elif isinstance(tensor_meta, (list, tuple)): + shape = [tm.shape for tm in tensor_meta] + else: + logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta)) + return None + + return CommBlock( + shape=shape, + node_list=node_list, + wait_nodes=wait_nodes, + comm_node=comm_node, + inputs=input_nodes, + outputs=outputs, + ) + + +def get_all_comm_blocks( + graph: fx.Graph, + comm_ops: tuple[torch._ops.OpOverload, ...], + comm_filter: Callable[..., bool] | None = None, +) -> list[CommBlock]: + if comm_filter is None: + + def always_true(comm_block: CommBlock) -> bool: + return True + + comm_filter = always_true + + blocks = [] + for node in graph.nodes: + if node.target not in comm_ops: + continue + comm_block = get_comm_block(node) + if comm_block is not None and comm_filter(comm_block): + blocks.append(comm_block) + return blocks + + +def _fuse_allreduce_by_concat( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: list[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce using concat.""" + # Flatten all the inputs to the all_reduce nodes. + with graph.inserting_after(last_input_node): + cat_inputs = [] + for input_node in all_input_nodes: + assert isinstance(input_node.args[0], fx.Node) + input_node = input_node.args[0] + cat_inputs.append( + call_function(graph, aten.flatten.using_ints, (input_node,)) + ) + + # Concat all the flattened nodes. + with graph.inserting_after(cat_inputs[0]): + cat_node = call_function(graph, aten.cat, (cat_inputs,)) + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. + divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_after(cat_node): + div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0])) + + # Create a new Comm/all_reduce node. + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + with graph.inserting_after(div_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = div_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs) + + # Create a new Wait node. + with graph.inserting_after(fused_comm_node): + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + flatten_args[0] = fused_comm_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs) + + # Move the fused all_reduce and its args to right after the input node + nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node] + # pyrefly: ignore [bad-argument-type] + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape, + node_list=[fused_comm_node, fused_wait_node], + wait_nodes=[fused_wait_node], + comm_node=fused_comm_node, + inputs=[div_node], + outputs=OrderedSet([fused_wait_node]), + ) + + +def _fuse_with_coalesced_op( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: list[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce by coalesced.""" + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. + dividends = [div.args[0] for div in all_input_nodes] + divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_before(last_input_node): + last_input_node = call_function( + graph, aten._foreach_div.Scalar, (dividends, divisors[0]) + ) + input_node = last_input_node + + # Create a new Comm/all_reduce_coalesced node. + with graph.inserting_after(last_comm_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = input_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function( + graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs + ) + + # Create a new wait node. + getitem_nodes = [] + wait_nodes = [] + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + for idx in range(len(all_input_nodes)): + with graph.inserting_after(fused_comm_node): + gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx)) + getitem_nodes.append(gi_node) + flatten_args[0] = gi_node + args, kwargs = tree_unflatten(flatten_args, spec) + with graph.inserting_after(gi_node): + wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs)) + + # Move the new all_reduce_coalesced and its args to right after the input node + nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=[ + tm.shape + for tm in cast( + list[TensorMetadata], fused_comm_node.meta.get("tensor_meta") + ) + ], + node_list=[fused_comm_node] + getitem_nodes + wait_nodes, + wait_nodes=wait_nodes, + comm_node=fused_comm_node, + inputs=[input_node], + outputs=OrderedSet(wait_nodes), + ) + + +def _scatter_fused_allreduce_waits( + graph: fx.Graph, + fused_comm_block: CommBlock, + orig_comm_blocks: list[CommBlock], + node_indices: dict[fx.Node, int], + split_and_reshape: bool = True, +) -> None: + """ + Scatters the result of the fused communication node to the original users. + If the fused method is concat splitting the output and reshape will be inserted, + before inserting getitem. Otherwise getitem will be used as the users of the + wait node. + """ + + # Before we mass up the order, we need to get the index of the last wait node + # in orig_comm_blocks. This index will be later used to determine what users + # nodes need to be move to maintain a correct topological sort order. + last_wait_node_idx = 0 + # pyrefly: ignore [bad-assignment] + for node in graph.nodes: + last_wait_node_idx = max( + node_indices.get(node, last_wait_node_idx), last_wait_node_idx + ) + if node == orig_comm_blocks[-1].wait_nodes[0]: + break + + if split_and_reshape: + fused_wait_node = fused_comm_block.wait_nodes[0] + with graph.inserting_after(fused_wait_node): + split_node = call_function( + graph, + aten.split, + ( + fused_wait_node, + [math.prod(cast(list[int], cb.shape)) for cb in orig_comm_blocks], + ), + ) + with graph.inserting_after(split_node): + fused_outputs = [] + for idx, comm_block in enumerate(orig_comm_blocks): + split_idx_node = call_function( + graph, operator.getitem, (split_node, idx) + ) + with graph.inserting_after(split_idx_node): + fused_outputs.append( + call_function( + graph, aten.reshape, (split_idx_node, comm_block.shape) + ) + ) + else: + fused_outputs = fused_comm_block.wait_nodes + + # Scatter the fused outputs. + incorrect_order_nodes = [] + for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs): + # Some descendant users of the orig_comm_blocks may be scheduled before + # the fused all_reduce. For example, the user nodes of the very first + # all_reduce may be scheduled before the second all_reduce. Since the + # fused all_reduce is inserted right after the last all_reduce, the + # order can be wrong. + # `incorrect_order_nodes` records these nodes. + + orig_wait = comm_block.wait_nodes[0] + nodes = collections.deque(list(orig_wait.users)) + while nodes: + user_node = nodes.popleft() + if not isinstance(user_node, fx.Node): + continue + # pyrefly: ignore [unsupported-operation] + if node_indices[user_node] < last_wait_node_idx: + incorrect_order_nodes.append(user_node) + nodes.extend(list(user_node.users)) + + orig_wait.replace_all_uses_with(fused_output) + + last_fused_result = fused_outputs[0] + fused_outputs_set = OrderedSet(fused_outputs) + for node in graph.nodes: + if node in fused_outputs_set: + last_fused_result = node + + # Move the incorrect_order_nodes to right after the last fused_result. + incorrect_order_nodes = sorted( + incorrect_order_nodes, key=lambda node: node_indices[node] + ) + move_block_after(incorrect_order_nodes, last_fused_result) + + +def _fuse_allreduce( + graph: fx.Graph, + comm_blocks: list[CommBlock], + node_indices: dict[fx.Node, int], + use_concat: bool, +) -> CommBlock: + """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock.""" + + if len(comm_blocks) == 1: + return comm_blocks[0] + + # Find the last input node of all the CommBlocks. This node will be served + # as the inserting point of the new collective op. + last_input_node = comm_blocks[0].inputs[0] + last_input_index = -1 + all_input_nodes = [] + for comm_block in comm_blocks: + input_node = comm_block.inputs[0] + all_input_nodes.append(input_node) + index = node_indices[input_node] + if index >= last_input_index: + assert index != last_input_index + last_input_node = input_node + last_input_index = index + + if use_concat: + fused_comm_block = _fuse_allreduce_by_concat( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + else: + fused_comm_block = _fuse_with_coalesced_op( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + + _scatter_fused_allreduce_waits( + graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat + ) + + for comm_block in comm_blocks: + for wait in comm_block.wait_nodes: + graph.erase_node(wait) + graph.erase_node(comm_block.comm_node) + graph.eliminate_dead_code() + + return fused_comm_block + + +def _bucket_size_fusion( + graph: fx.Graph, comm_blocks: list[CommBlock], bucket_size_mb: int +) -> Generator[list[CommBlock], None, None]: + MB = 1024**2 + bucket_size = 1 * MB + bucket_cap_size = bucket_size_mb * MB + curr_size = 0 + curr_blocks = [] + + count = 0 + fuse_count = 0 + for i, block in enumerate(comm_blocks): + curr_blocks.append(block) + itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize + curr_size += cast(torch.Size, block.shape).numel() * itemsize + count += 1 + if curr_size < bucket_size and i != len(comm_blocks) - 1: + continue + + fuse_count += 1 + if torch.distributed.get_rank() == 0: + logger.info( + "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d", + fuse_count, + count, + curr_size, + bucket_size, + ) + + # Set the debug counters + counters["inductor"]["ddp_buckets"] = fuse_count + yield curr_blocks + + bucket_size = bucket_cap_size + curr_blocks = [] + curr_size = 0 + count = 0 + + +def _fuse_ddp_communication( + graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any] +) -> None: + for output in reversed(graph.nodes): + if output.op == "output": + break + + def ddp_reducer_filter(block: CommBlock) -> bool: + if ( + not isinstance(block.comm_node.args[0], fx.Node) + or block.comm_node.args[0].target != aten.div.Tensor + ): + return False + + if len(block.wait_nodes[0].users) != 1: + # gradient/wait node should only be used by one user + return False + + # Two cases: + # 1. gradient/wait node should be directly used by the output + # if gradient is None before bwd. + # 2. gradient/wait node should be directly used by copy_. + if ( + output not in block.wait_nodes[0].users + and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default + ): + return False + + return True + + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter) + node_indices = {node: i for i, node in enumerate(graph.nodes)} + + for block in algorithm_fn(graph, comm_blocks): + fusion_fn(graph, block, node_indices) + + +def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=False), + ) + + +def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=True), + ) + + +def schedule_comm_wait(graph: fx.Graph) -> None: + """ + Delay the execution of wait tensors of allreduce until its first user. + + This algorithm considers the intermediate users, like split, getitem, + of the wait node and schedule those intermediate users as well. + This will result in a better overlapping result. + """ + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + torch.ops._c10d_functional.all_reduce_coalesced.default, + torch.ops._c10d_functional.all_reduce_coalesced_.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops) + if not comm_blocks: + return + + # Find all the end users. + allreduce_users = OrderedSet[fx.Node]() + for allreduce in comm_blocks: + for output in allreduce.outputs: + allreduce_users.update(output.users) + + node_indices = {node: i for i, node in enumerate(graph.nodes)} + for allreduce in comm_blocks: + # Find the earliest/first user -- target_node. + assert len(allreduce.outputs) >= 1, ( + f"Found a allreduce that has zero outputs/users -- {allreduce}." + ) + # Initialize the target node to avoid typing issues. + target_node = next(iter(next(iter(allreduce.outputs)).users)) + target_node_index = 2**31 + for user in (user for output in allreduce.outputs for user in output.users): + index = node_indices[user] + if index < target_node_index: + target_node = user + target_node_index = index + + # Move wait nodes and all the subsequent nodes in the comm_block to + # before the first user -- target_node. + wait_idx = -1 + for wait_idx, node in enumerate(allreduce.node_list): + if node == allreduce.wait_nodes[0]: + break + assert wait_idx >= 0 + move_block_before(allreduce.node_list[wait_idx:], target_node) + + +def fuse_ddp_communication( + graph: fx.Graph, passes: list[Callable[..., None] | str], bucket_size_mb: int +) -> None: + for i, pa in enumerate(passes): + with GraphTransformObserver( + graph.owning_module, f"fuse_ddp_communication_pass_{i}" + ): + if isinstance(pa, str): + func = globals()[pa] + else: + func = pa + if "bucket_size_mb" in OrderedSet( + v.name for v in inspect.signature(func).parameters.values() + ): + func(graph, bucket_size_mb=bucket_size_mb) + else: + func(graph) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..3613ab1ed17b5e35815d1bca359b94b29b511abc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch import Tensor +from torch._dynamo.utils import counters, is_node_meta_valid +from torch.fx.experimental.symbolic_shapes import ( + statically_known_false, + statically_known_true, +) + +from .. import config +from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern +from .split_cat import construct_pattern_matcher_pass + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + +# TODO: need a better strategy for decomposing mm +# The following two constants are for CUDA device only +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 +# The following two constants are for CPU device only +CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1 +CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048 + +min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION +max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION +cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION +if "decompose_mm_pass" in config.post_grad_fusion_options: + min_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION) + max_other_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) + cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get( + "cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION + ) + cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get( + "cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: + return (a.device.type == b.device.type) and (b.device.type == device) + + +def realize_inputs(inputs: list[torch.fx.Node]): + for inp in inputs: + if isinstance(inp, torch.fx.node.Node): + inp.meta["inductor_realize_to_strides"] = True + + +def should_decompose_bmm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if len(mat1.shape) != 3 or len(mat2.shape) != 3: + return False + if check_device(mat1, mat2, device="cuda") or check_device( + mat1, mat2, device="xpu" + ): + if mat1.shape[0] < min_first_dimension_decomposition: + return False + # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION + # use bool() to deal with BooleanAtom type + if ( + bool(mat1.shape[1] < max_other_dimension_decomposition) + + bool(mat1.shape[2] < max_other_dimension_decomposition) + + bool(mat2.shape[2] < max_other_dimension_decomposition) + < 2 + ): + return False + return True + elif check_device(mat1, mat2, device="cpu"): + if ( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + and mat2.shape[0] <= cpu_max_first_dimension_decomposition + ): + return True + return False + + +def should_decompose_mm(mat1, mat2) -> bool: + """ + Determines whether matrix multiplication (mm) should be decomposed into pointwise operations + based on the input matrices' metadata, shapes, device placement, and configuration options. + Args: + mat1: The first matrix operand. Expected to be an object with a `.meta` attribute containing + a "val" key, or a tensor-like object with a `.shape` attribute. + mat2: The second matrix operand. Same requirements as `mat1`. + Returns: + bool: True if the matrix multiplication should be decomposed according to the following logic: + - Both inputs must have valid node metadata. + - Both matrices must be 2-dimensional. + - If the configuration option `skip_dynamic_shape_dim_check` is False: + - Decomposition is only considered for statically-shaped matrices. + - For CUDA devices: `mat1.shape[0]` must be at least `min_first_dimension_decomposition`, + and both dimensions of `mat2` must be less than `max_other_dimension_decomposition`. + - For CPU devices: All relevant dimensions must be less than or equal to their respective + CPU decomposition thresholds. + - If `skip_dynamic_shape_dim_check` is True: + - Decomposition is considered for dynamic shapes as well, using a combination of + `statically_known_true` and `statically_known_false` checks to handle uncertainty. + - The same dimension and device checks apply, but allow for dynamic/static uncertainty. + - Returns False if any of the above conditions are not met. + Notes: + - Relies on helper functions such as `is_node_meta_valid`, `check_device`, `statically_known_true`, + and `statically_known_false`, as well as configuration values like + `min_first_dimension_decomposition`, `max_other_dimension_decomposition`, etc. + - Designed for use in graph optimization or fusion passes where decomposing large or dynamic + matrix multiplications can improve performance or memory usage. + """ + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if len(mat1.shape) != 2 or len(mat2.shape) != 2: + return False + # case 1: we skip decompose mm if the input is dynamic shape + if not config.post_grad_fusion_options["decompose_mm_pass"].get( + "skip_dynamic_shape_dim_check", False + ): + return ( + ( + check_device(mat1, mat2, device="cuda") + or check_device(mat1, mat2, device="xpu") + ) + and statically_known_true( + mat1.shape[0] >= min_first_dimension_decomposition + ) + and statically_known_true(mat2.shape[0] < max_other_dimension_decomposition) + and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) + ) or ( + check_device(mat1, mat2, device="cpu") + and statically_known_true( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + ) + and statically_known_true( + mat2.shape[0] <= cpu_max_other_dimension_decomposition + ) + and statically_known_true( + mat2.shape[1] <= cpu_max_other_dimension_decomposition + ) + ) + # case 2: we decompose mm if the input is dynamic shape + else: + return ( + ( + check_device(mat1, mat2, device="cuda") + or check_device(mat1, mat2, device="xpu") + ) + and ( + statically_known_true( + mat1.shape[0] >= min_first_dimension_decomposition + ) + or not statically_known_false( + mat1.shape[0] >= min_first_dimension_decomposition + ) + ) + and ( + statically_known_true(mat2.shape[0] < max_other_dimension_decomposition) + or not statically_known_false( + mat2.shape[0] < max_other_dimension_decomposition + ) + ) + and ( + statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) + or not statically_known_false( + mat2.shape[1] < max_other_dimension_decomposition + ) + ) + ) or ( + check_device(mat1, mat2, device="cpu") + and ( + statically_known_true( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + ) + or not statically_known_false( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + ) + ) + and ( + statically_known_true( + mat2.shape[0] <= cpu_max_other_dimension_decomposition + ) + or not statically_known_false( + mat2.shape[0] <= cpu_max_other_dimension_decomposition + ) + ) + and ( + statically_known_true( + mat2.shape[1] <= cpu_max_other_dimension_decomposition + ) + or not statically_known_false( + mat2.shape[1] <= cpu_max_other_dimension_decomposition + ) + ) + ) + + +def print_decompose_pattern(match: Match, inputs: list[torch.fx.Node]): + node = match.nodes[-1] + log.debug( + "Decompose %s with input shape: %s", + node.target, + ", ".join( + str(input.meta["val"].shape) if "val" in input.meta else "None" + for input in inputs + ), + ) + + +@register_graph_pattern( + CallFunction(aten.bmm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to( + mat1.dtype + ) + + if should_decompose_bmm(mat1, mat2): + counters["inductor"]["decompose_bmm"] += 1 + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.addmm, Arg(), Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_addmm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, + mat3: torch.fx.Node, +): + def repl(mat1, mat2, mat3): + return ( + torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1 + ) + + if should_decompose_mm(mat2, mat3): + counters["inductor"]["decompose_addmm"] += 1 + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [mat1, mat2, mat3]) + print_decompose_pattern(match, [mat1, mat2, mat3]) + realize_inputs([mat1, mat2, mat3]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_mm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype) + + if should_decompose_mm(mat1, mat2): + counters["inductor"]["decompose_mm"] += 1 + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py new file mode 100644 index 0000000000000000000000000000000000000000..7b431c2f17117ae0c9e570072759a72417711562 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -0,0 +1,81 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Any + +import torch +from torch import SymBool, SymFloat, SymInt +from torch.types import py_sym_types +from torch.utils._ordered_set import OrderedSet + + +@dataclass +class _SymExprHash: + """ + Hash for a py_sym_types that will use the underlying sympy expression + """ + + sym_obj: SymInt | SymFloat | SymBool + + def __hash__(self) -> int: + return hash((type(self.sym_obj), self.sym_obj.node.expr)) + + def __eq__(self, value) -> bool: + if not isinstance(value, _SymExprHash): + return False + return self.sym_obj.node.expr == value.sym_obj.node.expr + + +class _SymHashingDict: + """ + Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse + existing sym proxies. + + SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail, + fallback to symnodes. + """ + + def __init__(self): + self.sym_hash_dict = {} + + def __setitem__(self, key, value): + self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value) + + def __getitem__(self, key): + return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)] + + def __contains__(self, key): + return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict + + def get(self, key, default=None): + return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default) + + def _wrap_to_sym_expr_hash(self, key): + return _SymExprHash(key) if isinstance(key, py_sym_types) else key + + +def dedupe_symints(graph: torch.fx.Graph): + """ + Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs. + + We only dedupe from graph inputs to avoid adding a potential dependency in the forward + from the backward. + + """ + + sym_dict = _SymHashingDict() + resolvable_from_input_symints = OrderedSet[Any]() + + for node in graph.nodes: + val = node.meta.get("val", None) + if val is None or not isinstance(val, py_sym_types): + continue + + if node.op == "placeholder": + resolvable_from_input_symints.add(node) + sym_dict[val] = node + elif existing_node := sym_dict.get(val): + node.replace_all_uses_with(existing_node) + graph.erase_node(node) + elif all(n in resolvable_from_input_symints for n in node.all_input_nodes): + sym_dict[val] = node + resolvable_from_input_symints.add(node) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..72c853f7e5f66c980222244e822942d2fad640f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config +from torch.func import functional_call + +from ..pattern_matcher import ( + CallFunctionVarArgs, + CallModuleVarArgs, + Match, + register_graph_pattern, +) +from .pre_grad import efficient_conv_bn_eval_pass + + +def efficient_conv_bn_eval( + bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module. + conv (nn.modules.conv._ConvNd): a conv module + x (torch.Tensor): Input feature map. + """ + + assert bn.running_var is not None + assert bn.running_mean is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv.weight + if conv.bias is not None: + bias_on_the_fly = conv.bias + else: + bias_on_the_fly = torch.zeros_like(bn.running_var) + + if bn.weight is not None: + bn_weight = bn.weight + else: + bn_weight = torch.ones_like(bn.running_var) + + if bn.bias is not None: + bn_bias = bn.bias + else: + bn_bias = torch.zeros_like(bn.running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv.weight.ndim - 1) + if isinstance(conv, nn.modules.conv._ConvTransposeNd): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn.running_mean + ) + + input = x + params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly} + output = functional_call(conv, params, input) + return output + + +def efficient_conv_bn_eval_decomposed( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv: torch._ops.OpOverload, + conv_weight, + conv_bias, + x, + conv_remainging_args, +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + """ + assert bn_running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv_weight + if conv_bias is not None: + bias_on_the_fly = conv_bias + else: + bias_on_the_fly = torch.zeros_like(bn_running_var) + + if bn_weight is None: + bn_weight = torch.ones_like(bn_running_var) + + if bn_bias is None: + bn_bias = torch.zeros_like(bn_running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv_weight.ndim - 1) + if "conv_transpose" in conv.__str__(): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn_running_mean + ) + + input = x + return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args)) + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.nn.functional.batch_norm, + ] + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 8 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-3]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch._C._nn.linear, + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_running_mean = bn_node.args[1] + bn_running_var = bn_node.args[2] + bn_weight = bn_node.args[3] + bn_bias = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.ops.aten.batch_norm.default, + ] + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 9 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-4]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch.ops.aten.linear.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_weight = bn_node.args[1] + bn_bias = bn_node.args[2] + bn_running_mean = bn_node.args[3] + bn_running_var = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallModuleVarArgs( + [ + nn.modules.batchnorm._BatchNorm, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + ], + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): + # We matched a BN node + bn_node = match.nodes[0] + graph = match.graph + gm = graph.owning_module + bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type] + + # We can only use efficient conv-bn for eval mode with track_running_stats + if not bn_mod.track_running_stats or bn_mod.training: + return + + # Check if the input is Conv + if bn_node.args: + input_node = bn_node.args[0] + else: + input_node = bn_node.kwargs["input"] + if input_node.op != "call_module": # type: ignore[union-attr] + return + if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr] + return + input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr] + supported_convs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ] + if not any(isinstance(input_mod, cls) for cls in supported_convs): + return + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + # Find a pair of conv and bn computation nodes to optimize. + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(conv_node): # type: ignore[arg-type] + # create `get_attr` node to access modules + # note that we directly call `create_node` to fill the `name` + # argument. `graph.get_attr` and + # `graph.call_function` does not allow the `name` argument. + conv_get_node = graph.create_node( + op="get_attr", + target=conv_node.target, # type: ignore[union-attr] + name="get_conv", + ) + bn_get_node = graph.create_node( + op="get_attr", target=bn_node.target, name="get_bn" + ) + if conv_node.args: # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + else: + conv_input = conv_node.kwargs["input"] # type: ignore[union-attr] + # prepare args for the fused function + args = (bn_get_node, conv_get_node, conv_input) + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval, + args=args, + name="efficient_conv_bn_eval", + ) + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fca2087a5d5220b7256f313bbb25d2d23d9ab7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py @@ -0,0 +1,310 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._inductor.compile_fx import fake_tensor_prop +from torch._inductor.utils import GPU_TYPES + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import ( + _return_true, + CallFunction, + fwd_only, + Ignored, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) + + +aten = torch.ops.aten + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + +binary_folding_pass = PatternMatcherPass() + + +def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs): + """ + Passes that are applied to the graph to freeze pass. + """ + + from ..freezing import constant_fold + + lazy_init() + # We need a few rounds of binary folding to get rid of all the + # unnecessary nodes, but may need a good method to chose the rounds number. + # works like: conv+binary+binary. + binary_folding = counters["inductor"]["binary_folding"] + fake_tensor_prop(gm, aot_example_inputs, True) + + torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_computation_ops( + gm + ) + for _ in range(4): + constant_fold(gm) + # Make sure meta['val'] is properly set for all nodes + fake_tensor_prop(gm, aot_example_inputs, True) + binary_folding_pass.apply(gm.graph) # type: ignore[arg-type] + # If we don't have binary folding, we don't need to run the pass again. + # TODO: remove the need to run fake_tensor_prop on the whole model. + if counters["inductor"]["binary_folding"] == binary_folding: + break + binary_folding = counters["inductor"]["binary_folding"] + + torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_computation_ops( + gm + ) + + constant_fold(gm) + fake_tensor_prop(gm, aot_example_inputs, True) + + for pattern in pass_patterns: + pattern.apply(gm.graph) # type: ignore[arg-type] + + # The CPU weight packing always assume the conv's weight is channels last, + # So make sure the layout_optimization is on when doing it. + if ( + torch._C._has_mkldnn + and config.cpp.weight_prepack + and config.layout_optimization + ): + from .mkldnn_fusion import _eliminate_duplicate_packed_nodes + + _eliminate_duplicate_packed_nodes(gm) + + stable_topological_sort(gm.graph) + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn and config.cpp.weight_prepack: + from .mkldnn_fusion import _mkldnn_weight_pack_init + + _mkldnn_weight_pack_init() + + from .binary_folding import binary_folding_init + + addmm_patterns_init() + binary_folding_init() + + +def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): + while pass_number > len(pass_patterns) - 1: + pass_patterns.append(PatternMatcherPass()) + return register_graph_pattern( + pattern, + extra_check=extra_check, + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[pass_number], + ) + + +def register_binary_folding_pattern(pattern, extra_check=_return_true): + return register_graph_pattern( + pattern, + extra_check=extra_check, + # pyrefly: ignore [bad-argument-type] + pass_dict=binary_folding_pass, + ) + + +@functools.cache +def addmm_patterns_init(): + """ + addmm related patterns. + To avoid duplication, also includes int8 WoQ GEMM pattern without bias. + """ + device = next( + (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu" + ) + val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + scale = functools.partial(torch.empty, (10,), device=device, requires_grad=False) + + def check_int8_woq_concat_linear_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if not is_cpu or not config.cpp.enable_concat_linear: + # Currently, this pattern is only supported on CPU + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + if not all( + match.kwargs[wgt].target is torch.ops.prims.convert_element_type.default + for wgt in weight_inputs + ): + return False + + if not all( + next(iter(match.kwargs[wgt]._input_nodes.keys())).meta["val"].dtype + is torch.int8 + for wgt in weight_inputs + ): + return False + + if not all( + match.kwargs[wgt].meta["val"].dtype is torch.bfloat16 + for wgt in weight_inputs + ): + return False + + return True + + def check_concat_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if is_cpu and not config.cpp.enable_concat_linear: + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + equal_shape_inputs = [weight_inputs] + + if "b1" in match.kwargs: + bias_inputs = ["b1", "b2"] + if "b3" in match.kwargs: + bias_inputs.append("b3") + + equal_shape_inputs.append(bias_inputs) + + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + + if not all( + inp.op == "get_attr" + and inp.meta["val"].shape == inps[0].meta["val"].shape + for inp in inps + ): + return False + return True + + def int8_woq_fusion_pattern(inp, w1, w2, w3, s1, s2, s3): + return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3) + + def int8_woq_fusion_replacement(inp, w1, w2, w3, s1, s2, s3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_s = torch.cat((s1, s2, s3), dim=0) + mm = (inp @ cat_w).mul(cat_s) + n1, n2 = w1.size(1), w2.size(1) + return mm.tensor_split([n1, n1 + n2], dim=-1) + + register_replacement( + # pyrefly: ignore [bad-argument-type] + int8_woq_fusion_pattern, + # pyrefly: ignore [bad-argument-type] + int8_woq_fusion_replacement, + [val(), val(), val(), val(), scale(), scale(), scale()], + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + pass_patterns[0], + extra_check=check_int8_woq_concat_linear_weights, + exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"), + ) + + def matmul_fuse_pattern(inp, w1, w2, w3): + return (inp @ w1, inp @ w2, inp @ w3) + + def matmul_replacement(inp, w1, w2, w3): + cat_t = torch.cat((w1, w2, w3), dim=1) + mm = inp @ cat_t + return mm.chunk(3, dim=1) + + register_replacement( + # pyrefly: ignore [bad-argument-type] + matmul_fuse_pattern, + # pyrefly: ignore [bad-argument-type] + matmul_replacement, + [val(), val(), val(), val()], + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3"), + ) + + def matmul_fuse_pattern_two(inp, w1, w2): + return (inp @ w1, inp @ w2) + + def matmul_replacement_two(inp, w1, w2): + cat_t = torch.cat((w1, w2), dim=1) + mm = inp @ cat_t + return mm.chunk(2, dim=1) + + register_replacement( + # pyrefly: ignore [bad-argument-type] + matmul_fuse_pattern_two, + # pyrefly: ignore [bad-argument-type] + matmul_replacement_two, + [val(), val(), val()], + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2"), + ) + + def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): + return ( + aten.addmm(b1, inp, w1), + aten.addmm(b2, inp, w2), + aten.addmm(b3, inp, w3), + ) + + def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_b = torch.cat((b1, b2, b3)) + return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + + register_replacement( + # pyrefly: ignore [bad-argument-type] + addmm_fuse_pattern_second, + # pyrefly: ignore [bad-argument-type] + addmm_fuse_replacement_second, + [val() for _ in range(7)], + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), + ) + + +def same_dtype(match): + return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"] + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + Ignored(), + KeywordArg("dtype"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[0], + extra_check=same_dtype, +) +def unnecessary_dtype_convert(match: Match, **kwargs): + """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding""" + graph = match.graph + node = match.output_node() + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + graph.erase_node(node) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/fsdp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..1e71c350ed7b67b47e7a77af7cbd4b93bfc48f98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/fsdp.py @@ -0,0 +1,115 @@ +import logging +from collections.abc import Callable + +import torch +from torch._inductor.fx_passes.bucketing import ( + bucket_all_gather_by_mb, + bucket_reduce_scatter_by_mb, + BucketMode, + merge_all_gather, + merge_reduce_scatter, +) + + +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def is_graph_input(node: torch.fx.Node) -> bool: + return node.op == "placeholder" + + +def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool: + # Assume all_gather_into_tensor input is either graph input + # or dtype conversion of graph input + ag_node = wait.args[0] # type: ignore[arg-type, union-attr] + return ( + is_graph_input(ag_node.args[0]) # type: ignore[arg-type, union-attr] + or ( # type: ignore[arg-type, union-attr] + ag_node.args[0].op == "call_function" # type: ignore[arg-type, union-attr] + and ag_node.args[0].target # type: ignore[arg-type, union-attr] + == torch.ops.prims.convert_element_type.default # type: ignore[arg-type, union-attr] + and is_graph_input(ag_node.args[0].args[0]) # type: ignore[arg-type, union-attr] + ) + ) + + +def is_graph_output(node: torch.fx.Node) -> bool: + return all(user.op == "output" for user in node.users) + + +def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool: + if is_graph_output(wait): + return True + + if len(wait.users) == 1: + user = next(iter(wait.users)) + assert user is not None + return ( + is_graph_output(user) + and user.op == "call_function" + and user.target is torch.ops.prims.convert_element_type.default + ) + + return False + + +def bucket_fsdp_all_gather( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, + mode: BucketMode = "default", +) -> None: + """ + Bucketing pass for SimpleFSDP all_gather ops. + + Attributes: + gm (torch.fx.GraphModule): Graph module of the graph. + bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that + takes in bucket id and returns size of a bucket in megabytes. + """ + if bucket_cap_mb_by_bucket_idx is None: + from torch._inductor.fx_passes.bucketing import ( + bucket_cap_mb_by_bucket_idx_default, + ) + + bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default + assert bucket_cap_mb_by_bucket_idx is not None + ag_buckets = bucket_all_gather_by_mb( + gm, + bucket_cap_mb_by_bucket_idx, + filter_wait_node=is_fsdp_all_gather_wait, + ) + if len(ag_buckets) == 0: + return + merge_all_gather(gm, ag_buckets, mode) + + +def bucket_fsdp_reduce_scatter( + gm: torch.fx.GraphModule, + bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, + mode: BucketMode = "default", +) -> None: + """ + Bucketing pass for SimpleFSDP reduce_scatter ops. + + Attributes: + gm (torch.fx.GraphModule): Graph module of the graph. + bucket_cap_mb_by_bucket_idx (Callable[[int], float] | None): callback function that + takes in bucket idx and returns size of a bucket in megabytes. By default + torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used. + + """ + if bucket_cap_mb_by_bucket_idx is None: + from torch._inductor.fx_passes.bucketing import ( + bucket_cap_mb_by_bucket_idx_default, + ) + + bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default + rs_buckets = bucket_reduce_scatter_by_mb( + gm, + bucket_cap_mb_by_bucket_idx, + filter_wait_node=is_fsdp_reduce_scatter_wait, + ) + if len(rs_buckets) == 0: + return + merge_reduce_scatter(gm, rs_buckets, mode) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9a09d2531348849ed997fc762aef44a09c43e6a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py @@ -0,0 +1,1152 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import logging +import math + +import torch + +from ..._dynamo.utils import counters +from ..pattern_matcher import ( + filter_nodes, + fwd_only, + gen_register_replacement, + joint_fwd_bwd, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + +_scaled_dot_product_attention = aten.scaled_dot_product_attention + + +def _sfdp_pattern_1(query, key, value, inv_scale): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_1(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_2(query, key, value, scale_factor): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(scale_factor) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_2(query, key, value, scale_factor): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale_factor) + .softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_5(query, key, value, attn_mask): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + # attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ value + + +def _sfdp_replacement_5(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_7(query, key, value, dropout_p): + # in real workloads inputs to matmul are permuted + # causing matmul to expand to a series of expand and clone calls + # we want the same to happen during pattern tracing + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_7(query, key, value, dropout_p): + # sdpa prefers inputs in permuted format + # it makes a copy to put them in this format + # if they aren't already + # to make replacement efficient ensure that inputs to sdpa + # are in required order + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_8(query, key, value): + # no dropout version of pattern 7 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_8(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_9(query, key, value, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_9(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_10(query, key, value): + # no dropout version of 9 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_10(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_11(query, key, value, inv_scale): + # Mainly for huggingface models + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v) + + +def _sfdp_replacement_11(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(v) + + +def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_13(query, key, value, dropout_p): + attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1) + attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p) + return torch.bmm(attn_weight, value) + + +def _sfdp_replacement_13(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + dropout_p=dropout_p, + scale=1.0, + ).squeeze(0) + + +def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): + # for BertLarge + # Permutations are needed to create clones in graph. + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask) + .softmax(dim=-1) + .matmul(v) + ) + + +def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale): + # for DistilBert + # Permutations are needed to create clones in graph. + # Ref: https://github.com/pytorch/pytorch/issues/119911 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v + + +def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): + # for BertLarge with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + torch.nn.functional.dropout( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax( + dim=-1 + ), + dropout_p, + ) + .to(dtype=query.dtype) + .matmul(v) + ) + + +def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p): + # for DistilBert with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): + # for hf_GPT2 with dropout (introduces clone node) for inference + # it also returns permuted key & value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + return ( + ( + torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul( + value + ) + ), + key, + value, + ) + + +def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + permuted_key = key.transpose(1, 2) + permuted_value = value.transpose(1, 2) + return ( + _scaled_dot_product_attention( + query.transpose(1, 2), + permuted_key, + permuted_value, + attn_mask=causal_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ), + permuted_key, + permuted_value, + ) + + +def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p): + # for token-classification+gpt2 / text-generation+gpt2 + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + attn_weights = attn_weights + attn_mask + attn_weights = attn_weights.softmax(dim=-1).type(value.dtype) + return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value) + + +def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = torch.where(causal_mask, attn_mask, fill_value) + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ) + + +def _sfdp_pattern_20(query, key, value, attn_mask, dropout_p): + # for DistilBert with dropout transformers==4.44.2 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + q = q.div(math.sqrt(q.size(-1))) + scores = q @ k.transpose(-2, -1) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(query.size(-1)), + ) + + +def _sfdp_pattern_21(query, key, value, attn_mask): + # for T5 with inplace add + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value) + + +def _sfdp_replacement_21(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + is_causal=False, + scale=1.0, + ) + + +def _sfdp_pattern_22(query, key, value, attn_mask): + # for T5 with inplace add and return key and value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_22(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_pattern_23(query, key, value): + # for T5 with inplace add and + # return key and value and + # attn_mask is generated by atem.full(..., 0) + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + fp32_score = score.float() + score = fp32_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_23(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_pattern_24(query, key, value, attention_mask): + """ + this pattern is for MBartForCausalLM/PLBartForCausalLM. + attn_mask has a different dtype with QKV. + there is no scale in sdpa. + """ + bs = query.size(0) + n_head = query.size(1) + seq_len = query.size(2) + head_size = query.size(3) + q = query.view(bs * n_head, -1, head_size) + k = key.reshape(bs * n_head, -1, head_size) + v = value.reshape(bs * n_head, -1, head_size) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask + attn_weights = attn_weights.view(bs * n_head, seq_len, -1) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + if query.dtype == torch.half: + attn_weights = attn_weights.to(torch.half) + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.view(bs, n_head, seq_len, head_size) + return attn_output + + +def _sfdp_replacement_24(query, key, value, attention_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask.to(dtype=query.dtype), + is_causal=False, + scale=1, + ) + + +def _sfdp_params_check(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_mask_node = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. + if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. + if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + # When we tensorify floats we end up turning floats + # into 0d scalar tensors. It doesn't make any sense + # to have a 0d scalar tensor attention mask so + # conveniently we can insert this check to get + # tests that erroneously passing in a float + # attention mask to fail as expected. + or attn_mask.dim() == 0 + ): + return False + return True + + +def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False): + def fn(match): + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + if scale_factor_op is not None: + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? + if not isinstance(scale_factor, (float, int)): + return False + return _sfdp_params_check(match) + + return fn + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def _get_sfdp_patterns(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values don't actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + # attn_mask + b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) + # need 2d attn_mask to generate patterns with view op + m_inp_2d = functools.partial(torch.empty, (2, 4), device=device) + # inv_scale + c_inp = functools.partial(torch.tensor, 2.0, device=device) + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + d = {"dropout_p": 0.113377} + + # we could also generate all these patterns in 3d.. TODO + g_3d_inp = functools.partial( + torch.empty, (1024, 128, 128), device=device, requires_grad=True + ) + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. + g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + + # softmax will generate a dtype conversion on inputs if they are in half, + # but will not in float, so we generate a pattern for both + for dtype in [torch.float, torch.half]: + g = functools.partial(g_inp, dtype=dtype) + b = functools.partial(b_inp, dtype=dtype) + b_float = functools.partial(b_inp, dtype=torch.float) + b_bool = functools.partial(b_inp, dtype=torch.bool) + m = functools.partial(m_inp, dtype=dtype) + m_float = functools.partial(m_inp, dtype=torch.float) + m_bool = functools.partial(m_inp, dtype=torch.bool) + m_2d = functools.partial(m_inp_2d, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + g_3d = functools.partial(g_3d_inp, dtype=dtype) + g_bs1 = functools.partial(g_bs1_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float) + m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool) + + candidates = [ + ( + _sfdp_pattern_1, + _sfdp_replacement_1, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_2, + _sfdp_replacement_2, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_3, + _sfdp_replacement_3, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_4, + _sfdp_replacement_4, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_5, + _sfdp_replacement_5, + [g(), g(), g(), b()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_6, + _sfdp_replacement_6, + [g(), g(), g(), b()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_7, + _sfdp_replacement_7, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_8, + _sfdp_replacement_8, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_9, + _sfdp_replacement_9, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_10, + _sfdp_replacement_10, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_11, + _sfdp_replacement_11, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_12, + _sfdp_replacement_12, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_13, + _sfdp_replacement_13, + [g_3d(), g_3d(), g_3d()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_14, + _sfdp_replacement_14, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_15, + _sfdp_replacement_15, + [g(), g(), g(), m_2d(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_17, + _sfdp_replacement_17, + [g(), g(), g(), m_2d(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g(), g(), g(), m_bool()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_19, + _sfdp_replacement_19, + [g(), g(), g(), b_bool(), b_float()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_20, + _sfdp_replacement_20, + [g(), g(), g(), m_2d()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_21, + _sfdp_replacement_21, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_21, + _sfdp_replacement_21, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_22, + _sfdp_replacement_22, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_22, + _sfdp_replacement_22, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_23, + _sfdp_replacement_23, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_23, + _sfdp_replacement_23, + [g_bs1(), g_bs1(), g_bs1()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_24, + _sfdp_replacement_24, + [g(), g(), g(), b_float()], + {}, + _sfdp_extra_check, + ), + ] + mask_fp32_patterns = ["pattern_16"] + if dtype == torch.half: + # Add inputs of bf16 q/k/v and fp32 mask, for models like albert. + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + + for pattern, replacement, args, workaround, extra_check in candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + assert isinstance(workaround, dict) + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + if ( + any(p in name for p in mask_fp32_patterns) + and args[3].dtype == torch.float32 + ): + name += "_mask_fp32" + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield ( + training_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + }, + ) + + if workaround: + assert len(workaround) == 1 and "dropout_p" in workaround + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield ( + inference_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, + }, + ) + + +@functools.cache +def _sfdp_init(): + for key, register_replacement_kwargs in _get_sfdp_patterns(): + gen_register_replacement(key, **register_replacement_kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/graph_view.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/graph_view.py new file mode 100644 index 0000000000000000000000000000000000000000..5758551a9b8a5cad4f2a5aa1a21357a9ab12cfcc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/graph_view.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import itertools +import re +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch.fx as fx # noqa: TC001 +from torch.utils._ordered_set import OrderedSet + + +if TYPE_CHECKING: + from collections.abc import Callable + + +def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]: + nn_stack = node.meta.get("nn_module_stack", "") + if nn_stack: + return list(nn_stack.values()) + + fwd_nn_stack = node.meta.get("fwd_nn_module_stack", "") + if fwd_nn_stack: + return list(fwd_nn_stack.values()) + + return [] + + +def _addindent(s_: str, num_spaces: int) -> str: + s: list[str] = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first: str = s.pop(0) + s: list[str] = [(num_spaces * " ") + line for line in s] + joint_s: str = "\n".join(s) + joint_s = first + "\n" + joint_s + return joint_s + + +class GraphView: + """ + A hierarchical class for organizing and managing torch.fx nodes by their module stack. + + This class provides a tree-like structure where each node in the hierarchy corresponds + to a module or submodule in a traced FX graph. Each `GraphView` instance can hold a list + of FX nodes (`self.data`) belonging to that module scope, maintain a unique set of nodes + (`self.unique_nodes`), and manage its child containers (`self.children`). + + Attributes: + name (str): The name of the module or container scope. + klass (type[Any]): The class type associated with this module/container. + data (list[fx.Node]): A list of FX graph nodes belonging to this module. + unique_nodes (OrderedSet[fx.Node]): A deduplicated set of nodes to ensure no duplicates. + children (dict[str, GraphView]): A mapping of child module names to their corresponding GraphView instances. + """ + + def __init__(self, name: str, klass: type[Any]) -> None: + self.name: str = name + self.klass: type[Any] = klass + self.data: list[fx.Node] = [] + self.unique_nodes: OrderedSet[fx.Node] = OrderedSet() + self.children: dict[str, GraphView] = {} + + def add(self, data: fx.Node) -> None: + if data not in self.unique_nodes: + self.data.append(data) + self.unique_nodes.add(data) + + def get_child( + self, module_stack: str, klass: Optional[type[Any]] = None + ) -> GraphView: + if module_stack not in self.children: + new_stack = GraphView(module_stack, klass or self.klass) + self.children[module_stack] = new_stack + return self.children[module_stack] + + def __getitem__(self, name: str) -> GraphView: + return self.children[name] + + def __getattr__(self, name: str) -> GraphView: + return self.children[name] + + def __repr__(self) -> str: + child_lines: list[str] = [] + for name, child in self.children.items(): + mod_str = repr(child) + mod_str = _addindent(mod_str, 2) + child_lines.append(f"({name}): {mod_str}") + main_str = f"{self.klass.__name__}(" + if child_lines: + main_str += "\n " + "\n ".join(child_lines) + "\n" + main_str += ")" + return main_str + + +def _clean_stack_name(stack_name: str) -> str: + """ + Clean up FX node's nn_module_stack metadata string to match the module name hierarchies + + Example: + Input: "L['self']._modules['layers']['0']._modules['attention']" + Output: "layers.0.attention" + """ + cleaned = re.sub(r"^L\['self'\]\.?", "", stack_name) + parts = re.findall(r"\['([^']+)'\]", cleaned) + return ".".join(parts) if parts else cleaned + + +def _is_root(stack: str) -> bool: + return stack == "" + + +def make_graph_view( + graph: fx.Graph, + module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None, +) -> Optional[GraphView]: + """ + Code from: https://github.com/meta-pytorch/autoparallel/pull/158 + + Make a graph view from the fx.Graph. This is a tree structure that + represents the module hierarchy of the graph, and enables us to + easily find the nodes that belong to each module, and gives a slightly + easier way of visualize different parts of the graph by extracting + subgraphs that belong to a particular module FQN. + + For example, if we have the following model with module hierarchy: + + Transformer( + (tok_embeddings): Embedding(128256, 4096) + (layers): ModuleDict( + (0): TransformerBlock( + (attention): Attention( + (wq): Linear(in_features=4096, out_features=4096, bias=False) + (wk): Linear(in_features=4096, out_features=1024, bias=False) + (wv): Linear(in_features=4096, out_features=1024, bias=False) + (wo): Linear(in_features=4096, out_features=4096, bias=False) + (sdpa): ScaledDotProductAttention() + ) + (feed_forward): FeedForward( + (w1): Linear(in_features=4096, out_features=14336, bias=False) + (w2): Linear(in_features=14336, out_features=4096, bias=False) + (w3): Linear(in_features=4096, out_features=14336, bias=False) + ) + (attention_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True) + (ffn_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True) + ) + ) + (norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True) + (output): Linear(in_features=4096, out_features=128256, bias=False) + ) + + Then we can get a GraphView for the fx.Graph that enables us to do + + graph_view = make_graph_view(graph) + subgraph = get_subgraph_by_path(graph_view, "layers.0") + + where subgraph contains all the nodes that belong to this region + + module_stack_fn: Optional callable for extracting module hierarchy information from nodes. + + Signature: Callable[[fx.Node], list[tuple[str, type[Any]]]] + + Takes an FX node and returns a list of (module_path, module_class) tuples representing + the nested module hierarchy for that node, ordered from outermost to innermost scope. + + - module_path (str): Dot-separated path identifying the module in the hierarchy + (e.g., "layers.0.attention.wq") + - module_class (type): The Python class type of the module + + This enables custom logic for determining module membership, useful for: + - Graphs without standard nn_module_stack metadata + - Filtering or grouping nodes by custom criteria + + Example of getting the module stack from annotation: + + def module_stack_fn(node): + module_stack = node.meta.get("custom", {}).get("module_path", "") + return [(module_stack, torch.nn.Module)] + + If None, defaults to extracting from node.meta["nn_module_stack"] or + node.meta["fwd_nn_module_stack"]. + """ + + def nn_module_stack_meta(node: fx.Node) -> list[tuple[str, type[Any]]]: + result = [] + for module_stack, module_class in _get_module_stack(node): + module_stack = _clean_stack_name(module_stack) + result.append((module_stack, module_class)) + return result + + if module_stack_fn is None: + module_stack_fn = nn_module_stack_meta + nodes: list[fx.Node] = list(graph.nodes) + nodes_by_module_stack_root: GraphView | None = None + for node in nodes: + for module_stack, module_class in module_stack_fn(node): + nodes_by_module_stack: GraphView | None = nodes_by_module_stack_root + for name in module_stack.split("."): + if nodes_by_module_stack is None: + nodes_by_module_stack = GraphView(name, module_class) + nodes_by_module_stack_root = nodes_by_module_stack + if _is_root(module_stack): + new_stack: GraphView = nodes_by_module_stack + else: + new_stack = nodes_by_module_stack.get_child(name, module_class) + nodes_by_module_stack = new_stack + nodes_by_module_stack.add(node) + + return nodes_by_module_stack_root + + +def get_subgraph_by_path( + graph_view: GraphView, paths: Union[str, list[str]] +) -> list[fx.Node]: + """ + Get subgraph by path(s). + Args: + graph_view (object): Root graph view object. + paths (str or list of str): Path(s) to subgraph. + Returns: + list[fx.Node]: fx nodes belong to the subgraph + """ + + def get_node_by_path(node: GraphView, path: str) -> GraphView: + for p in path.split("."): + if p in node.children: + node = node.children[p] + else: + return GraphView("", object) + return node + + if isinstance(paths, list): + nodes = list( + itertools.chain.from_iterable( + get_node_by_path(graph_view, p).data for p in paths + ) + ) + return nodes + else: + node = get_node_by_path(graph_view, paths) + return node.data diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f46d4d3ba216f15da9464e6052c36bdaa8b7c68a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py @@ -0,0 +1,1440 @@ +# mypy: allow-untyped-defs +import collections +import logging +import operator +from collections import OrderedDict +from collections.abc import Iterable, Iterator +from typing import Any + +import torch +from torch._dynamo.utils import counters, is_node_meta_valid +from torch._logging import trace_structured +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..pattern_matcher import ( + CallFunctionVarArgs, + get_arg_value, + stable_topological_sort, +) +from ..utils import OPTIMUS_EXCLUDE_POST_GRAD + + +try: + # importing this will register fbgemm lowerings for inductor + import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 + + has_fbgemm = True +except Exception: + has_fbgemm = False + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + +DEFAULT_BETA = 1 +DEFAULT_ALPHA = 1 + +MIN_FUSE_SET_SIZE = 5 +MAX_FUSE_SET_SIZE = 300 +MAX_FUSE_SEARCH_DEPTH = 5 +# The maximum tensor size that can go into the fusion group +MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# Whether we only fuse nodes with same parent node +FUSE_NODES_WITH_SAME_PARENT = False +# Whether we enable the add broadcast in batch linear +SHAPE_BROADCAST_BATCH_LINEAR = False +# Whether we enable the fuse nodes with same users +Fuse_NODES_WITH_SAME_USERS = False + +# exclude these nodes from BFS +# excluding get item improves optimizer compilation time by 60s +SEARCH_EXCLUSIONS = OrderedSet([operator.getitem]) + + +default_graph_search_options = { + "min_fuse_set_size": MIN_FUSE_SET_SIZE, + "max_fuse_set_size": MAX_FUSE_SET_SIZE, + "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, + "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, + "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, + "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, + "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, +} + +graph_search_options = default_graph_search_options + + +def update_stack_example_value(node, metadata, dim=0, op=torch.stack): + """ + Update the example value of the node in the graph to enable followup split cat opt. + """ + if node is not None and hasattr(node, "meta"): + if op is torch.stack: + example_value = torch.stack(metadata, dim=dim) + elif op is torch.unbind: + example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] + else: + return + node.meta["example_value"] = example_value + + +def update_pointwise_example_value(pointwise_node, input, other, op): + """ + Update the example value of the add node in the graph to enable followup split cat opt. + """ + if pointwise_node is not None and hasattr(pointwise_node, "meta"): + if op is torch.add: + example_value = torch.add(input, other) + elif op is torch.mul: + example_value = torch.mul(input, other) + else: + return + pointwise_node.meta["example_value"] = example_value + + +class GroupBatchFusionBase: + def __init__(self, **kwargs) -> None: + self.graph_search_options = kwargs.pop( + "graph_search_options", default_graph_search_options + ) + + def match(self, node): + raise NotImplementedError("match called on base") + + def fuse(self, graph, subset): + raise NotImplementedError("fuse called on base") + + +PRE_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {} +POST_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {} + + +def register_fusion(name: str, pre_grad=True): + def decorator(fusion_cls: GroupBatchFusionBase): + if pre_grad: + PRE_GRAD_FUSIONS[name] = fusion_cls + else: + POST_GRAD_FUSIONS[name] = fusion_cls + return fusion_cls + + return decorator + + +def list_group_batch_fusions(pre_grad=True) -> list[str]: + if pre_grad: + return list(PRE_GRAD_FUSIONS.keys()) + else: + return list(POST_GRAD_FUSIONS.keys()) + + +def decompose_stack(graph: torch.fx.GraphModule, input_tensors: list[Any]) -> Any: + unsqueezed_inputs = [] + unsqueezed_inputs_meta = [] + for input_tensor in input_tensors: + unsqueezed_input = graph.call_function( # type: ignore[operator] + aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} + ) + unsqueezed_inputs.append(unsqueezed_input) + unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] + unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) + stacked_inputs = graph.call_function( # type: ignore[operator] + aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} + ) + stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] + return stacked_inputs + + +class GroupFusion(GroupBatchFusionBase): + """ + Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm. + """ + + +class BatchFusion(GroupBatchFusionBase): + """ + Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm. + """ + + +class BatchPointwiseOpsFusionFactory(BatchFusion): + def __init__(self, op, **kwargs) -> None: + super().__init__(**kwargs) + self.op = op + + +@register_fusion("batch_linear_post_grad", pre_grad=False) +class PostGradBatchLinearFusion(BatchFusion): + """ + Fuse ops in a batch way in post grad (aten level). + """ + + def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + # pyre-fixme[7]: Incompatible return type + return ( + node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA + and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA # type: ignore[return-value] + ) + + def _is_input_2d(self, input: torch.fx.Node) -> bool: + input_shapes = input.meta["val"].shape + return ( + len(input_shapes) == 2 + and isinstance(input_shapes[0], int) + and isinstance(input_shapes[1], int) + ) + + def match(self, node: torch.fx.Node) -> tuple[str, int, int, int, bool, str] | None: + if CallFunctionVarArgs(aten.mm).match(node): + input_m, weight_m = node.args + bias_m = None + + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias_m, input_m, weight_m = node.args + else: + return None + # get the user of the node + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users] + else: + users = "" # type: ignore[assignment] + # only handle the cases where inputs are 2D tensors + if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] + return None + m, k = input_m.meta["val"].shape # type: ignore[union-attr] + n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) + return batch_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_nodes = [] + batch_inputs_meta = [] + batch_weights_meta = [] + batch_biases_meta = [] + + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + elif CallFunctionVarArgs(aten.mm.default).match(node): + input, weight = node.args + bias = None + batch_nodes.append(node) + batch_inputs.append(input) # type: ignore[possibly-undefined] + batch_weights.append(weight) # type: ignore[possibly-undefined] + batch_biases.append(bias) # type: ignore[possibly-undefined] + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] + if bias is not None: # type: ignore[possibly-undefined] + batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] + else: + batch_biases_meta.append(None) + + with graph.inserting_before(subset[-1]): # type: ignore[operator] + fused_inputs = decompose_stack(graph, batch_inputs) + fused_weights = decompose_stack(graph, batch_weights) + fused_inputs_meta_val = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + fused_weights_meta_val = torch.stack( + [weight["val"] for weight in batch_weights_meta] + ) + fused_bmm = graph.call_function( # type: ignore[operator] + aten.bmm, + args=(fused_inputs, fused_weights), + ) + fused_bmm.meta["val"] = aten.bmm( + fused_inputs_meta_val, fused_weights_meta_val + ) + for i, original_mm in enumerate(batch_nodes): + has_bias = False + with graph.inserting_after(fused_bmm): # type: ignore[operator] + new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) # type: ignore[operator] + new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) + if batch_biases[i]: + has_bias = True + # broadcast the bias to the same shape as the mm output + if self.graph_search_options.get( + "shape_broadcast_batch_linear", False + ): + broadcast_shape = torch.broadcast_shapes( + batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape + ) + broadcast_bias = graph.call_function( # type: ignore[operator] + aten.broadcast_to.default, + args=(batch_biases[i],), + kwargs={"size": broadcast_shape}, + ) + broadcast_bias.meta["val"] = aten.broadcast_to( + batch_biases_meta[i]["val"], broadcast_shape + ) # type: ignore[assignment] + new_bias_add = graph.call_function( # type: ignore[operator] + aten.add.Tensor, args=((broadcast_bias, new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + broadcast_bias.meta["val"], new_mm.meta["val"] + ) + else: + new_bias_add = graph.call_function( # type: ignore[operator] + aten.add, args=((batch_biases[i], new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + batch_biases_meta[i]["val"], new_mm.meta["val"] + ) + new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] + original_mm.replace_all_uses_with(new_mm_cont) + new_mm_cont.meta.update(original_mm.meta) + graph.erase_node(original_mm) # type: ignore[operator] + counters["inductor"]["batch_linear_post_grad"] += 1 + + +@register_fusion("group_linear", pre_grad=False) +class GroupLinearFusion(GroupFusion): + def _addmm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] + return ( + node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA + and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA + and len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def _mm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + return ( + len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def match(self, node: torch.fx.Node) -> tuple[str, bool] | None: + if CallFunctionVarArgs(aten.mm.default).match( + node + ) and self._mm_node_can_be_fused(node): + group_key = ("group_linear", True) + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias = node.args[0] + group_key = ("group_linear", bias is None) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + group_inputs = [] + group_weights = [] + group_biases = [] + group_nodes = [] + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + else: + assert CallFunctionVarArgs(aten.mm.default).match(node) + input, weight = node.args + bias = None + + group_nodes.append(node) + group_inputs.append(input) + group_weights.append(weight) + group_biases.append(bias) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + + with graph.inserting_before(subset[0]): # type: ignore[operator] + fused_mm = graph.call_function( # type: ignore[operator] + torch.ops.fbgemm.gmm.default, + args=(group_inputs, group_weights, group_biases), + kwargs={"smart_fused": True}, + ) + + for i, original_mm in enumerate(group_nodes): + with graph.inserting_after(fused_mm): # type: ignore[operator] + new_mm = graph.call_function(operator.getitem, args=(fused_mm, i)) # type: ignore[operator] + original_mm.replace_all_uses_with(new_mm) + new_mm.meta.update(original_mm.meta) + graph.erase_node(original_mm) # type: ignore[operator] + counters["inductor"]["group_linear"] += 1 + + +class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise math operator (e.g., add, mul) in post grad pass. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def _pointwise_node_can_be_fused(self, node: torch.fx.Node): + # note: we only consider the case where the inputs are tensors + # for mixed precision training, we need to make sure the inputs + # of the aten.cat when do the stack should be the same dtype + # otherwise, the output of the aten.cat may be not the same as + # its inputs, and cause dtype not same error in mm or addmm + input, other = node.args + return ( + input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] + # input and other can be scalars, where they have no attribute 'meta' + if hasattr(input, "meta") + and hasattr(other, "meta") + and is_node_meta_valid(input) # type: ignore[arg-type, union-attr] + and is_node_meta_valid(other) # type: ignore[arg-type, union-attr] + # torch.SymInt or torch.SymFloat object has no attribute 'shape' + and isinstance(input.meta["val"], torch.Tensor) # type: ignore[union-attr] + and isinstance(other.meta["val"], torch.Tensor) # type: ignore[union-attr] + else False + ) + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(self.op).match( + node + ) and self._pointwise_node_can_be_fused(node): + alpha = node.kwargs.get("alpha", DEFAULT_ALPHA) + rounding_mode = node.kwargs.get("rounding_mode", None) + input, other = node.args + shape = list(input.meta["val"].shape) # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # only consider the linear case so far + # pyre-fixme[16] + if input.target is aten.select or other.target is aten.select: # type: ignore[union-attr] + parent = ( + # pyre-fixme[16] + input.args[0] # type: ignore[union-attr] + # pyre-fixme[16] + if input.target is aten.select # type: ignore[union-attr] + else other.args[0] # type: ignore[union-attr] + ) + else: + parent = "" + else: + parent = "" + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(shape), + str(input.meta["val"].dtype), # type: ignore[union-attr] + str(other.meta["val"].dtype), # type: ignore[union-attr] + str(alpha), + str(rounding_mode), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_inputs, batch_others = [], [] + alpha = subset[0].kwargs.get("alpha", DEFAULT_ALPHA) + batch_inputs_meta, batch_others_meta = [], [] + + for node in subset: + input, other = node.args + batch_inputs.append(input) + batch_others.append(other) + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = decompose_stack(graph, batch_inputs) + stack_others = decompose_stack(graph, batch_others) + stack_inputs_meta = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + stack_others_meta = torch.stack( + [other["val"] for other in batch_others_meta] + ) + + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs, stack_others), + kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, + ) + batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) + for i, original_add in enumerate(subset): + with graph.inserting_after(batch_op): # type: ignore[operator] + new_add = graph.call_function( # type: ignore[operator] + torch.ops.aten.select, args=((batch_op, 0, i)) + ) + original_add.replace_all_uses_with(new_add) + new_add.meta.update(original_add.meta) + graph.erase_node(original_add) # type: ignore[operator] + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +@register_fusion("batch_linear_lhs") +class BatchLinearLHSFusion(BatchFusion): + """ + Batch linear left-hand side fusion. This pass tries to fuse the following patterns: + + torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn) + -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1)) + + We have a separate pass to eliminate contiguous transpose in a generic way. + """ + + def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None: + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + bias = get_arg_value(node, 2, "bias") + group_key = ("batch_linear_lhs", bias is None, input) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_input = None + batch_weights, batch_weights_meta = [], [] + batch_biases, batch_biases_meta = [], [] + split_sections = [] + for node in subset: + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + batch_nodes.append(node) + if batch_input is None: + batch_input = input + else: + assert batch_input is input + batch_weights.append(weight) + batch_weights_meta.append(weight.meta["example_value"]) + if bias: + batch_biases.append(bias) + batch_biases_meta.append(bias.meta["example_value"]) + split_sections.append(weight.meta["example_value"].shape[0]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + cat_weights = graph.call_function( # type: ignore[operator] + torch.cat, args=(batch_weights,), kwargs={"dim": 0} + ) + cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0) + transposed_weights = graph.call_function( # type: ignore[operator] + torch.transpose, args=(cat_weights, 0, 1) + ) + transposed_weights.meta["example_value"] = torch.transpose( + cat_weights.meta["example_value"], 0, 1 + ) + if len(batch_biases) > 0: + cat_biases = graph.call_function( # type: ignore[operator] + torch.cat, args=(batch_biases,), kwargs={"dim": 0} + ) + cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0) + fused_lhs = graph.call_function( # type: ignore[operator] + torch.addmm, + args=(cat_biases, batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.addmm( + cat_biases.meta["example_value"], + batch_input.meta["example_value"], # type: ignore[union-attr] + transposed_weights.meta["example_value"], + ) + else: + fused_lhs = graph.call_function( # type: ignore[operator] + torch.mm, + args=(batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.mm( + batch_input.meta["example_value"], # type: ignore[union-attr] + transposed_weights.meta["example_value"], + ) + fused_lhs_list = graph.call_function( # type: ignore[operator] + torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1} + ) + + for i, node in enumerate(batch_nodes): + with graph.inserting_after(fused_lhs_list): # type: ignore[operator] + new_node = graph.call_function( # type: ignore[operator] + operator.getitem, args=(fused_lhs_list, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_linear_lhs"] += 1 + + +# Poor person's check for if a node in the graph mutates its input. +# (the graph is torch IR, so we will see torch fns and python operators) +def _is_mutable_node(tgt): + if str(tgt).endswith("_"): + # e.g. torch.mul_, torch.Tensor.mul_ + return True + if ( + hasattr(tgt, "__module__") + and tgt.__module__ == "_operator" + and tgt.__name__.startswith("i") + ): + # e.g. operator.iand, operator.imul + return True + return False + + +def is_linear_node_can_be_fused(node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + return ( + is_node_meta_valid(node) + and is_node_meta_valid(input) + and is_node_meta_valid(weight) + and len(input.meta["example_value"].shape) == 2 + and len(weight.meta["example_value"].shape) == 2 + # the mm -> bmm transform adds an unbind() op, + # which is not safe for autograd when the output of the mm is mutated. + # don't pattern match if any users of the mm mutate the input. + and not any(_is_mutable_node(user.target) for user in node.users) + ) + + +@register_fusion("batch_linear") +class PreGradBatchLinearFusion(BatchFusion): + """ + Batch linear fusion in pre grad pass. + Fuse linear with same size with torch.baddmm + """ + + def _getitem_args(self, getitem_node: torch.fx.Node): + if getitem_node.target != operator.__getitem__ or ( + getitem_node.op != "call_function" + ): + return None + return getitem_node.args[0] + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users] + else: + users = "" # type: ignore[assignment] + group_key = ( + "batch_linear", + self._getitem_args(input), + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape), + bias is None, + str(users), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_inputs_metadata = [] + batch_weights_metadata = [] + batch_biases_metadata = [] + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + weight = get_arg_value(node, 1, "weight") + batch_weights.append(weight) + batch_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 2, "bias") + batch_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + batch_biases_metadata.append(bias.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + stack_weights = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weights, batch_weights_metadata) + transpose_weight = graph.call_function( # type: ignore[operator] + torch.transpose, args=(stack_weights, 1, 2) + ) + transpose_weight.meta["example_value"] = torch.transpose( + stack_weights.meta["example_value"], 1, 2 + ) + if all(bias is None for bias in batch_biases): + bmm = graph.call_function( # type: ignore[operator] + torch.bmm, + args=(stack_inputs, transpose_weight), + ) + bmm.meta["example_value"] = torch.bmm( + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + else: + stack_biases = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_biases, batch_biases_metadata) + unsqueeze_biases = graph.call_function( # type: ignore[operator] + torch.unsqueeze, args=(stack_biases, 1) + ) + unsqueeze_biases.meta["example_value"] = torch.unsqueeze( + stack_biases.meta["example_value"], 1 + ) + bmm = graph.call_function( # type: ignore[operator] + torch.baddbmm, + args=(unsqueeze_biases, stack_inputs, transpose_weight), + ) + try: + # it will have runtime error to broadcast when it has dynamic shape included + # in the meta data, so we need to skip the update meta data + bmm.meta["example_value"] = torch.baddbmm( + unsqueeze_biases.meta["example_value"], + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + except Exception as e: + log.debug( + f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 + ) + bmm_meta = None + + bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) # type: ignore[operator] + if bmm_meta is not None: + bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) + for i, linear in enumerate(batch_nodes): + with graph.inserting_after(bmm): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(bmm, i)) # type: ignore[operator] + linear.replace_all_uses_with(getitem) + getitem.meta.update(linear.meta) + graph.erase_node(linear) # type: ignore[operator] + counters["inductor"]["batch_linear"] += 1 + + +@register_fusion("batch_layernorm") +class BatchLayernormFusion(BatchFusion): + """ + Batch layer norm fusion in pre grad pass + """ + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 2, "weight") + bias = get_arg_value(node, 3, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users] + else: + users = "" # type: ignore[assignment] + group_key = ( + ( + "batch_layernorm", + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape) + if weight is not None + else "", + str(bias.meta["example_value"].shape) if bias is not None else "", + str(get_arg_value(node, 1, "normalized_shape")), + str(get_arg_value(node, 4, "eps")), + str(users), + ) + if "example_value" in input.meta + and is_node_meta_valid(weight) + and is_node_meta_valid(bias) + else None + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + group_inputs = [] + group_shapes = [] + group_weights = [] + group_biases = [] + group_epss = [] + group_nodes = [] + group_inputs_metadata = [] + group_biases_metadata = [] + group_weights_metadata = [] + for node in subset: + group_nodes.append(node) + input = get_arg_value(node, 0, "input") + group_inputs.append(input) + group_inputs_metadata.append(input.meta["example_value"]) + group_shapes.append(get_arg_value(node, 1, "normalized_shape")) + weight = get_arg_value(node, 2, "weight") + group_weights.append(weight) + if weight is not None and hasattr(weight, "meta"): + group_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 3, "bias") + group_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + group_biases_metadata.append(bias.meta["example_value"]) + eps = get_arg_value(node, 4, "eps") + if eps is None: + eps = 1e-5 + group_epss.append(eps) + stack_dim = -1 - len(group_shapes[-1]) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + if all(weight is None for weight in group_weights): + group_weights = None # type: ignore[assignment] + assert all(eps == group_epss[0] for eps in group_epss), ( + "all epsilon values must be equal" + ) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_input = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim} + ) + update_stack_example_value(stack_input, group_inputs_metadata, stack_dim) + if group_weights is not None: + stack_weight = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weight, group_weights_metadata) + else: + stack_weight = None + if group_biases is not None: + stack_bias = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_bias, group_biases_metadata) + else: + stack_bias = None + + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.nn.functional.layer_norm, + args=(stack_input, group_shapes[-1]), + kwargs={"eps": group_epss[-1]}, + ) + batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"] + + if group_weights is not None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + # pyrefly: ignore [missing-attribute] + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + # pyrefly: ignore [missing-attribute] + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + elif group_weights is not None and group_biases is None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + # pyrefly: ignore [not-callable] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + # pyrefly: ignore [missing-attribute] + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + elif group_weights is None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + # pyrefly: ignore [not-callable] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + # pyrefly: ignore [missing-attribute] + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + + batch_layer_norm_unbind = graph.call_function( # type: ignore[operator] + torch.unbind, + args=(batch_layer_norm,), + kwargs={"dim": stack_dim}, + ) + update_stack_example_value( + batch_layer_norm_unbind, + batch_layer_norm.meta["example_value"], + op=torch.unbind, + dim=stack_dim, + ) + + for i, node in enumerate(group_nodes): + with graph.inserting_after(batch_layer_norm_unbind): # type: ignore[operator] + new_node = graph.call_function( # type: ignore[operator] + operator.getitem, args=(batch_layer_norm_unbind, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_layernorm"] += 1 + + +class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass. + We fuse it in random place, and the introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # pyre-fixme[16] + parent = node.args[0] + parent = parent.target if parent is not None else "" # type: ignore[union-attr] + else: + parent = "" + # for relu op, we also use the inplace to construct the key + group_key = ( + "batch_" + self.op.__name__.lower().split(".")[0], + str(input.meta["example_value"].shape), + str(node.kwargs.get("inplace", False)), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + if self.op is torch.nn.functional.relu: + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + kwargs={"inplace": subset[0].kwargs.get("inplace", False)}, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], + # pyrefly: ignore [bad-argument-type] + inplace=subset[0].kwargs.get("inplace", False), + ) + else: + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"] + ) + unbind_op = graph.call_function( # type: ignore[operator] + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. + The introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + # we batch the ops with same parent to enable followup split cat + parent = node.args[0] + parent = ( + parent.target # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False) + else "" + ) + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(input.meta["val"].shape), + str(node.kwargs.get("inplace", False)), + # pyre-fixme[16] + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["val"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = decompose_stack(graph, batch_inputs) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(batch_op): # type: ignore[operator] + getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +class BatchMathOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch simple match related ops such as nan_to_num in pre grad pass. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # check the input has the same shape and its users have the same target + # check all clamp operators have the same min and max values, and + # nan_to_num operators use the same default value. + child = next(iter(node.users.keys())) + group_key = ( + str(input.meta["example_value"].shape) + + str(node.kwargs) + + str(child.target) + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + kwargs = subset[0].kwargs + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + kwargs=kwargs, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], **kwargs + ) + unbind_op = graph.call_function( # type: ignore[operator] + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +@register_fusion("batch_tanh") +class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.tanh, **kwargs) + + +@register_fusion("batch_sigmoid") +class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.sigmoid, **kwargs) + + +@register_fusion("batch_relu") +class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.nn.functional.relu, **kwargs) + + +@register_fusion("batch_detach") +class BatchDetachPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.detach, **kwargs) + + +@register_fusion("batch_nan_to_num") +class BatchNanToNumPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.nan_to_num, **kwargs) + + +@register_fusion("batch_clamp") +class BatchClampPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.clamp, **kwargs) + + +@register_fusion("batch_dropout") +class BatchDropoutPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.nn.functional.dropout, **kwargs) + + +@register_fusion("batch_aten_tanh", pre_grad=False) +class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.tanh.default, **kwargs) + + +@register_fusion("batch_aten_sigmoid", pre_grad=False) +class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sigmoid.default, **kwargs) + + +@register_fusion("batch_aten_relu", pre_grad=False) +class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.relu.default, **kwargs) + + +@register_fusion("batch_aten_add", pre_grad=False) +class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.add.Tensor, **kwargs) + + +@register_fusion("batch_aten_sub", pre_grad=False) +class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sub.Tensor, **kwargs) + + +@register_fusion("batch_aten_div", pre_grad=False) +class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.div.Tensor, **kwargs) + + +@register_fusion("batch_aten_mul", pre_grad=False) +class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.mul.Tensor, **kwargs) + + +class _OrderedSet: + def __init__(self, param=None) -> None: + if param: + self.rep = OrderedDict(dict.fromkeys(param)) + else: + self.rep = OrderedDict() + + def __contains__(self, o) -> bool: + return o in self.rep + + def __len__(self) -> int: + return self.rep.__len__() + + def append(self, o): + self.rep[o] = None + + def __iter__(self): + return self.rep.keys().__iter__() + + +def find_independent_subset_greedy( + node_list: Iterable[torch.fx.Node], + graph_search_options: dict[str, Any], +) -> Iterator[Iterable[torch.fx.Node]]: + """ + Yields a list of subsets of `node_list` where no element in the subset + depends on any other element in the subset. This results in a set of + independent nodes which can be fused together. + + The order of `node_list` is preserved within each subset so we can benefit + from split-cat elimination in later passes. + + During iteration it is only safe to mutate the graph by changing the nodes + that have been returned. + + graph_search_options: + - min_fuse_set_size: Minimum size of the subset to consider. Subsets below + this size will be ignored. + - max_fuse_set_size: Maximum size of the subset to consider. Subsets will + be broken to be at most this size. + """ + + # Compute all the children of `node` which are members of + # `interesting_nodes`. + def find_dependent_nodes(node, interesting_nodes): + visited_node_set = OrderedSet[torch.fx.Node]() + dep_set = OrderedSet[torch.fx.Node]() + + work = [node] + while work: + node = work.pop() + for input_node in node.all_input_nodes: + if input_node in interesting_nodes: + dep_set.add(input_node) + + if input_node not in visited_node_set: + visited_node_set.add(input_node) + work.append(input_node) + + return dep_set + + min_fuse_set_size = graph_search_options["min_fuse_set_size"] + max_fuse_set_size = graph_search_options["max_fuse_set_size"] + + # node_list needs to be a set because we only track the nodes that are left + # in it (and we want to do the `in` on a set, not a list). But we want to + # keep the correct order. + node_list = _OrderedSet(node_list) + + cache: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = {} + while node_list: + subset: list[torch.fx.Node] = [] + subset_deps = OrderedSet[torch.fx.Node]() + + next_round_node_list = _OrderedSet() + for node in node_list: + if len(subset) >= max_fuse_set_size or node in subset_deps: + next_round_node_list.append(node) + continue + + dep_set = cache.pop(node, None) + if dep_set is None: + dep_set = find_dependent_nodes(node, node_list) + + if not dep_set.intersection(subset): + subset.append(node) + subset_deps.update(dep_set) + else: + next_round_node_list.append(node) + cache[node] = dep_set + + if len(subset) >= min_fuse_set_size: + # Careful here - the caller uses the subsets to fuse nodes together + # so we need to clear any cache entry that contains one of the + # returned nodes because the dependency list could be different + # (larger) after the merge. + cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)} + yield subset + + node_list = next_round_node_list + + +def get_fusion_candidates( + rule: GroupBatchFusionBase, + root_node: torch.fx.Node, + fused_set: OrderedSet[torch.fx.Node], +) -> collections.defaultdict[Any, list[torch.fx.Node]]: + """ + Search fusion candidates for a specific rule using BFS starting from the root node. + We only search the subgraph within graph_search_options["max_fuse_search_depth"]. + """ + q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque() + + candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = ( + collections.defaultdict(list) + ) + + if root_node.target in SEARCH_EXCLUSIONS: + return candidate_dict + + visited_set = OrderedSet[torch.fx.Node]() + + for next_node in root_node.all_input_nodes: + q.append((1, next_node)) + visited_set.add(next_node) + + while len(q) > 0: + depth, node = q.popleft() + + if node in fused_set: + continue + + key = rule.match(node) + if key is not None: + candidate_nodes = candidate_dict[key] + if node not in candidate_nodes: + candidate_nodes.append(node) + else: + if depth < rule.graph_search_options["max_fuse_search_depth"]: + for next_node in node.all_input_nodes: + if next_node not in visited_set: + visited_set.add(next_node) + q.append((depth + 1, next_node)) + + return candidate_dict + + +def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase): + stable_topological_sort(graph) # type: ignore[arg-type] + fused_set = OrderedSet[torch.fx.Node]() + log_to_scuba = False + + for node in reversed(graph.nodes): # type: ignore[arg-type] + candidates = get_fusion_candidates(rule, node, fused_set) + + for key, candidate_nodes in candidates.items(): + if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]: + continue + + for subset in find_independent_subset_greedy( + candidate_nodes, rule.graph_search_options + ): + rule.fuse(graph, subset) + fused_set.update(subset) + log.debug( + f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}" # noqa: G004 + ) + log_to_scuba = True + if log_to_scuba: + from torch.fx._lazy_graph_module import _LazyGraphModule + + # Force graph to re-compile otherwise the output python code may be broken + gm = graph._owning_module + if isinstance(gm, _LazyGraphModule): + _LazyGraphModule.recompile() + else: + assert isinstance(gm, torch.fx.GraphModule) + gm.recompile() + graph_str = gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + name = f"optimus_{str(rule.__class__.__name__)}" + if "MTIA" in name: + name = f"cff_{str(rule.__class__.__name__)}" + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": name, + "encoding": "string", + }, + payload_fn=lambda: graph_str, + ) + + +def generate_fusion_from_config(config_options: dict[str, Any], pre_grad=True): + fusions: list[GroupBatchFusionBase] = [] + for name, options in config_options.items(): + # we skip all patterns from pattern_matcher passes (e.g., split_cat) + if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS: + continue + fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name] + _options = graph_search_options.copy() + _options.update(options) + fusions.append(fusion_cls(graph_search_options=_options)) # type: ignore[operator] + return fusions + + +def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): + fusions: list[GroupBatchFusionBase] = [] + # we keep all current pre grad fusions to keep + # current implementation, will remove this later + if pre_grad: + fusions += generate_fusion_from_config( + config.pre_grad_fusion_options, pre_grad=True + ) + else: + fbgemm_fusion_keys = [ + x + for x in config.post_grad_fusion_options + if ( + x not in OPTIMUS_EXCLUDE_POST_GRAD + and config.post_grad_fusion_options[x].get("require_fbgemm", False) + ) + ] + fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in fbgemm_fusion_keys + } + non_fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in config.post_grad_fusion_options + if fusion not in fbgemm_fusion_keys + } + fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) + if has_fbgemm: + fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) + + for i, rule in enumerate(fusions): + with GraphTransformObserver( + graph.owning_module, + f"group_batch_fusion_{i}", + ): + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..021abb0d6b13bd94c146b9a058c058745252e904 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py @@ -0,0 +1,1048 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +import typing +from collections import Counter +from collections.abc import Sequence +from typing import Any + +import torch +import torch._guards +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters +from torch._inductor.constant_folding import ConstantFolder +from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict +from torch._inductor.utils import get_gpu_type +from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + statically_known_true, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..pattern_matcher import ( + Arg, + CallFunction, + init_once_fakemode, + KeywordArg, + Match, + MULTIPLE, + PatternMatcherPass as PatternMatcherPassBase, + register_graph_pattern, + stable_topological_sort, +) +from .decompose_mem_bound_mm import check_device +from .replace_random import replace_random_passes + + +PatternMatcherPass = functools.partial( + PatternMatcherPassBase, subsystem="joint_graph_passes" +) + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten +prims = torch.ops.prims + +pass_patterns = [ + patterns, + PatternMatcherPass(), +] + + +@init_once_fakemode +def lazy_init(): + from .fuse_attention import _sfdp_init + from .misc_patterns import _misc_patterns_init + from .pad_mm import _pad_mm_init + + _pad_mm_init() + _sfdp_init() + _misc_patterns_init() + + +def remove_no_ops( + gm: torch.fx.GraphModule, + zeros: OrderedSet[torch.fx.Node], + ones: OrderedSet[torch.fx.Node], +): + with torch.utils._python_dispatch._disable_current_modes(): + "Removes no-ops: (+ 0, - 0, * 1, / 1)" + graph = gm.graph + + def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")): + if any(not isinstance(t, torch.Tensor) for t in (t1, t2)): + return False + for field in fields: + if getattr(t1, field) != getattr(t2, field): + return False + return True + + def replace_no_op(node, replace_input_index): + replacement = node.args[replace_input_index] + + # https://github.com/pytorch/pytorch/issues/86128 causes + # non-Tensor inputs even for ops with only Tensor inputs. + # TODO - decompose/type promote to avoid this + if not all(isinstance(arg, torch.fx.Node) for arg in node.args): + return + + if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]): + if fake_tensors_eq( + node.meta["val"], + replacement.meta["val"], + ("shape", "device"), + ): + with graph.inserting_after(node): + replacement = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(replacement, node.meta["val"].dtype), + ) + else: + return + + node.replace_all_uses_with(replacement) + replacement.meta.update(node.meta) + graph.erase_node(node) + + for node in graph.find_nodes(op="call_function", target=aten.add.Tensor): + # TODO handle Tensor-Scalar adds, it's a different schema + if len(node.args) == 2: + if ( + not any(e in zeros for e in node.args) + or node.kwargs.get("alpha", 1) != 1 + ): + continue + + replace_index = 1 if node.args[0] in zeros else 0 + replace_no_op(node, replace_index) + + for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor): + if len(node.args) == 2: + if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1: + continue + + replace_no_op(node, 0) + + for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor): + if len(node.args) == 2: + if not any(e in ones for e in node.args): + continue + + replace_input_index = 1 if node.args[0] in ones else 0 + replace_no_op(node, replace_input_index) + + for node in graph.find_nodes(op="call_function", target=aten.div.Tensor): + if len(node.args) == 2 and node.args[1] in ones: + replace_no_op(node, 0) + + # meta tensors returned from the graph have no data and can be replaced with empty_strided + for output_node in graph.find_nodes(op="output"): + had_meta_return = False + + def visit(n): + nonlocal had_meta_return + val = n.meta.get("val") + if isinstance(val, torch.Tensor) and val.device.type == "meta": + with graph.inserting_before(output_node): + n.replace_all_uses_with( + graph.call_function( + torch.ops.aten.empty_strided.default, + args=(val.size(), val.stride()), + kwargs={"dtype": val.dtype, "device": val.device}, + ) + ) + had_meta_return = True + + torch.fx.map_arg(output_node.args, visit) + if had_meta_return: + graph.eliminate_dead_code() + + +def remove_redundant_views(gm: torch.fx.GraphModule): + """ + Removes redundant views by reusing existing ones. + """ + with torch.utils._python_dispatch._disable_current_modes(): + # A dictionary mapping a tensor to all aliased views. + views: dict[torch.fx.Node, dict[torch.dtype, torch.fx.Node]] = {} + graph = gm.graph + + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten.view.dtype + ): + src = node.args[0] + to_type = node.args[1] + existing_views = views.get(src) + is_needed = True + + if existing_views: + # Replace the view with the an existing view if available. + alias = existing_views.get(to_type) + if alias: + is_needed = False + node.replace_all_uses_with(alias) + alias.meta.update(node.meta) + graph.erase_node(node) + else: + from_type = src.meta["val"].dtype + existing_views = {from_type: src} + views[src] = existing_views + + if is_needed: + # Save the new alias but do not replace existing one. + existing_views.setdefault(to_type, node) + views[node] = existing_views + + # Clean up unused views. + while True: + unused_views = [alias for alias in views if not alias.users] + if len(unused_views) == 0: + break + for unused in unused_views: + views.pop(unused) + graph.erase_node(unused) + + +class UniformValueConstantFolder(ConstantFolder): + """ + Runs constant folding and replaces tensors that have a uniform value + with a tensor constructor call: aten.full([shape], value, ...) + """ + + def __init__(self, gm, skip_constructors=False) -> None: + super().__init__(gm, skip_constructors) + self.node_storages_ptrs: dict[torch.fx.Node, int] = {} + self.constant_data_ptrs: dict[torch.fx.Node, StorageWeakRef] = {} + # we may constant fold a tensor which in the graph has a sym size + # see: [constant folding refining of symints] + self.node_replacements_shapes: dict[torch.fx.Node, list[int]] = {} + + # initialize symint -> node mapping so that we can + # use symint nodes in full constructors + self.symint_nodes = _SymHashingDict() + for n in self.module.graph.nodes: # type: ignore[union-attr] + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + if n.meta["val"] not in self.symint_nodes: + self.symint_nodes[n.meta["val"]] = n + + # reference from torch/_funtorch/partitioners.py:get_default_op_list + self.view_op_packets = [ + aten.squeeze, + aten.unsqueeze, + aten.alias, + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + + self.indexing_op_packets = OrderedSet( + [ + aten.slice, + ] + ) + + self._add_peephole_patterns() + + def _add_peephole_patterns(self) -> None: + """ + Add peephole patterns for nodes where we can infer constant value even if some inputs + of the node are unknown. + """ + for op in itertools.chain( + self.module.graph.find_nodes( # type: ignore[operator, union-attr] + op="call_function", target=torch.ops.aten.mul.Tensor + ), + self.module.graph.find_nodes( # type: ignore[operator, union-attr] + op="call_function", target=torch.ops.aten.mul.Scalar + ), + ): + tensor_val = op.meta.get("val", None) + if not isinstance(tensor_val, torch.Tensor): + continue + + def is_zero_int(arg: Any) -> bool: + return isinstance(arg, int) and arg == 0 + + if not any(is_zero_int(a) for a in op.args): + continue + + t = torch.full( + [1], # shape + 0, # value + dtype=tensor_val.dtype, + device=tensor_val.device, + pin_memory=False, + ) + self.add_node_replacement(op, t) + + def _support_dynamic_shape(self): + return True + + def insertable_tensor_check(self, t: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor.flatten()[0].item() + self.node_replacements_shapes[node] = node.meta["val"].shape + self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage()) + + def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + env[n] = n.meta["val"] + else: + env[n] = self.unknown_value + + def _deduce_value(self, node: torch.fx.Node): + # deduce value for full-like nodes + # 1. for constructors, substitute value is a tensor of size [1] + # 2. for view ops/indexing, substitute value is the same as the input + # 3. for pointwise ops, run node to get the substitute value + # 4. deal with some special ops + # otherwise, stop deduce value and return unknown value + + # TODO: cat, more indexing + # TODO - do on cpu to avoid syncs + + # single-elem attrs + if node.op == "get_attr" or ( + node.op == "call_function" + and node.target is torch.ops.aten.lift_fresh_copy.default + ): + out = super(ConstantFolder, self).run_node(node) + if isinstance(out, torch.Tensor) and out.numel() == 1: + return out + + # handle device_put op + if node.target == prims.device_put.default: + return super(ConstantFolder, self).run_node(node) + + # constructors ops + if ( + node.op == "call_function" + and node.target is aten.full.default + and len(node.args) == 2 + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + value = args[1] + # Don't specialize symbolic value. + if not isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + new_args = [[1], value] + return aten.full.default(*new_args, **node.kwargs) + + # handle before view ops because this changes value + if node.target is aten.view.dtype: + return super(ConstantFolder, self).run_node(node) + + # view ops, return input tensor, the first argument + if hasattr(node.target, "overloadpacket") and ( + node.target.overloadpacket in self.view_op_packets + or node.target.overloadpacket in self.indexing_op_packets + ): + assert isinstance(node.args[0], torch.fx.Node) + return self.env[node.args[0]] + + # we don't want to return unknown value for symints so that we can + # still constant fold through their use in constructors or views + # if we see them in a pointwise node (e.g., tensor * symint) + # we will bail + if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt): + return node.meta["val"] + + # pointwise ops + if isinstance(node.target, torch._ops.OpOverload) and ( + torch.Tag.pointwise in node.target.tags + or node.target is torch.ops.aten.scalar_tensor.default + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs): + return self.unknown_value + + # we run the ops with dim 1, so remove memory_format to avoid error + kwargs = dict(kwargs) + kwargs.pop("memory_format", None) + + return node.target(*args, **kwargs) + + return self.unknown_value + + +def constant_fold_uniform_value(gm: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops." + aten = torch.ops.aten + + # Constant folding can leak memory, especially with repeated compilation, so we are only going to + # remove constants which can be replaced with a constructor. + cf = UniformValueConstantFolder(gm) + cf.run() + + node_replacements = cf.node_replacements + + # note: [constant folding refining of symints] + # constant folding will partially evaluate a graph such that values which have dependencies which + # are entirely known at compile time may also become compile time constants. in some cases, + # this will include symints which we had not yet previously deduced are guaranteed a + # constant value and is then deduced in constant folding. an example is: + # unbacked_symint_eq_11 = torch.full((), 11).item() + # torch.full((unbacked_symint_eq_11,), 0) + node_replacements_shapes = cf.node_replacements_shapes + + graph = gm.graph + + zeros = OrderedSet[Any]() + ones = OrderedSet[Any]() + + # Got failures in `test_is_set_to_cuda` if we change aliasing on constants, + # so just constant-ify if a Tensor is unaliased + constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter() + + for node in cf.node_replacements: + constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1 + + for node, value in node_replacements.items(): + # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now + # hasn't shown up to be important yet + if "val" not in node.meta: + # This can only happen in AOTI + continue + + fake_tensor = node.meta["val"] + if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format): + continue + + # TODO - not sure about lossy uint->python value->uint conversions + if fake_tensor.dtype in ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ): + continue + + if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1: + continue + + with graph.inserting_after(node): + # the conversion from tensor and back to value can be lossy, just use the original full ctor value + if ( + node.op == "call_function" + and node.target is aten.full.default + and len(node.args) == 2 + ): + value = node.args[1] + + # refines symints, see [constant folding refining of symints] above + for runtime_size, compile_time_size in zip( + node_replacements_shapes[node], fake_tensor.shape + ): + torch._check(runtime_size == compile_time_size) + + # replace SymInt as Node before creating a new full node + # e.g. (1, s0) -> (1, arg0_1) + node_shape = node_replacements_shapes[node] + if not all( + not isinstance(s, torch.SymInt) or s in cf.symint_nodes + for s in node_shape + ): + continue + + shapes = [ + cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s + for s in node_replacements_shapes[node] + ] + + # zeros and ones just get traced into full, so we insert those + new_node = graph.call_function( + aten.full.default, + args=(shapes, value), + kwargs={ + "dtype": fake_tensor.dtype, + "layout": torch.strided, + "device": fake_tensor.device, + "pin_memory": False, + }, + ) + + new_node.meta.update(node.meta) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if value == 0: + zeros.add(new_node) + elif value == 1: + ones.add(new_node) + + remove_no_ops(gm, zeros, ones) + remove_redundant_views(gm) + + +def canonicalize_quant_mapping(gm: torch.fx.GraphModule): + """ + + + torch.ops.higher_order.invoke_quant_packed(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1)); + -> + torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); + """ + graph = gm.graph + invoke_quant_invocations = graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_quant_packed + ) + for invoke_quant in invoke_quant_invocations: + kwargs = dict(invoke_quant.kwargs) + + quant_options_node = kwargs.pop("quant_options", None) + if quant_options_node is not None: + assert isinstance(quant_options_node, torch.fx.Node) + quant_options = torch._higher_order_ops.InvokeQuant( + *invoke_quant.kwargs["quant_options"].args, + **invoke_quant.kwargs["quant_options"].kwargs, + ) + else: + quant_options = torch._higher_order_ops.InvokeQuant() + + subgraph, *args = invoke_quant.args + with gm.graph.inserting_before(invoke_quant): + invoke_quant_replacement = graph.call_function( + torch._higher_order_ops.invoke_quant, + (subgraph, *args), + # pyrefly: ignore [bad-argument-type] + kwargs, + ) + invoke_quant_replacement.meta.update(subgraph.meta) + invoke_quant_replacement.meta["quant_options"] = quant_options + + invoke_quant.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(invoke_quant) + + if quant_options_node and len(quant_options_node.users) == 0: + graph.erase_node(quant_options_node) + + first_user = next(iter(invoke_quant_replacement.users)) + + if ( + len(invoke_quant_replacement.users) == 1 + and len(subgraph.users) == 1 + and first_user.target is operator.getitem + and first_user.args[1] == 0 + ): + subgraph_graph = getattr(gm, subgraph.target) + output_node = torch._inductor.utils.output_node(subgraph_graph) + assert ( + isinstance(output_node.args[0], (list, tuple)) + and len(output_node.args[0]) == 1 + ) + + unpacked_output = output_node.args[0][0] + output_node.args = (unpacked_output,) + if "val" in output_node.meta: + output_node.meta["val"] = output_node.meta["val"][0] + subgraph_graph.recompile() + + invoke_quant_replacement.meta.update(first_user.meta) + first_user.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(first_user) + + +def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule): + """ + Canonicalization passes that will run immediately after aot autograd + tracing. Thsis must be run before all other graph passes. + """ + canonicalize_quant_mapping(gm) + + +def joint_graph_passes(graph: torch.fx.GraphModule): + """ + Run FX transformations on the joint forwards+backwards graph. + """ + GraphTransformObserver = functools.partial( + torch.fx.passes.graph_transform_observer.GraphTransformObserver, + subsystem="joint_graph_passes", + ) + + lazy_init() + count = 0 + + # must occur before other passes + canonicalize_aten_ir_passes(graph) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + + from .post_grad import remove_noop_ops + + GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops) + + if config.joint_graph_constant_folding: + GraphTransformObserver(graph, "constant_fold_uniform_value").apply_gm_pass( + constant_fold_uniform_value + ) + + if config.pattern_matcher: + for i, patterns in enumerate(pass_patterns): + maybe_count = GraphTransformObserver( + graph, f"pass_pattern_{i}" + ).apply_graph_pass(patterns.apply) + count += maybe_count if maybe_count is not None else 0 + + if not config.fallback_random: + # not trying into the bisector because decomps may have already affected rng reproducibility + # we'll instead explicitly turn off the config + count += replace_random_passes(graph) + + if config.joint_custom_post_pass is not None: + GraphTransformObserver(graph, "joint_custom_post_pass").apply_graph_pass( + config.joint_custom_post_pass + ) + count += 1 + + if count: + stable_topological_sort(graph.graph) + graph.graph.lint() + graph.recompile() + return graph + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.iota.default, + KeywordArg("length"), + start=KeywordArg("start"), + step=KeywordArg("step"), + dtype=KeywordArg("dtype"), + device=KeywordArg("device"), + requires_grad=KeywordArg("requires_grad"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, +) +def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad): + """ + Eager supports: + + aten.index(cuda_tensor, torch.arange(..., device="cpu")) + + But this results in an implicit host-device-copy and breaks cudagraphs. + Rewrite the arange to use CUDA. + """ + (node,) = match.nodes + user_devices = OrderedSet[torch.device]() + for user in node.users: + if ( + user.op == "call_function" + and user.target in (aten.index.Tensor, aten.index_put.default) + and hasattr(user.meta.get("val"), "device") + ): + user_devices.add(user.meta["val"].device) # type: ignore[union-attr] + else: + return # bail out + + if len(user_devices) == 1 and "val" in node.meta: + (user_device,) = user_devices + if device.type != user_device.type: + repl = match.graph.call_function( + torch.ops.prims.iota.default, + (length,), + { + "start": start, + "step": step, + "dtype": dtype, + "device": user_device, + "requires_grad": requires_grad, + }, + ) + repl.meta.update(node.meta) + repl.meta["val"] = repl.meta["val"].to(user_device) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + KeywordArg("arg"), + KeywordArg("dtype1"), + ), + KeywordArg("dtype2"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, +) +def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): + """Remove chain of dtype conversions often created by AMP""" + graph = match.graph + node = match.output_node() + allowed = torch.float16, torch.bfloat16, torch.float32, torch.float64 + if dtype1 in allowed and dtype2 in allowed: + repl = graph.call_function( + torch.ops.prims.convert_element_type.default, (arg, dtype2) + ) + repl.meta.update(node.meta) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +def definitely_equal( + old_sizes: Sequence[torch.SymInt | int], + new_sizes: Sequence[torch.SymInt | torch.fx.Node | int], +) -> bool: + """ + Leverage guard_or_true/false to compare if two lists of int/symint are equal. + Useful to compare sizes, strides etc. + + Can handle -1 in new_sizes which happens in the size arguments of a + view op. old_sizes is supposed to be the tensor shape and should not + contain -1. + + new_sizes can contains fx.Node when dynamic shape is enabled. In that + case new_sizes[i].meta['val'] contains the real torch.SymInt. + """ + + num_neg1 = 0 + + if len(old_sizes) != len(new_sizes): + return False + + for lhs_item, rhs_item in zip(old_sizes, new_sizes): + if isinstance(rhs_item, torch.fx.Node): + rhs_item = rhs_item.meta["val"] + + assert isinstance(lhs_item, (int, torch.SymInt)), type(lhs_item) + assert isinstance(rhs_item, (int, torch.SymInt)), type(rhs_item) + + # It still makes sense to call guard_or_true/false since lhs_item + # rhs_item are torch.SymInt rather than sympy expressions when + # dynamic shape is enabled. + if guard_or_false(lhs_item == rhs_item): + continue + + if guard_or_true(rhs_item != -1): + return False + + num_neg1 += 1 + + if num_neg1 > 1: + return False + return True + + +@register_graph_pattern( + CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, +) +def pointless_view(match: Match, arg, size): + """Remove no-op view""" + node = match.output_node() + arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] + if definitely_equal(arg_size, size): + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + aten.view.default, + CallFunction(aten.view.default, KeywordArg("arg"), KeywordArg("size1")), + KeywordArg("size2"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, +) +def pointless_view_pair(match: Match, arg, size1, size2): + """ + Remove a pair of views that are pointless. + """ + node = match.output_node() + arg_size = list(arg.meta["val"].shape) + if definitely_equal(arg_size, size2): + node.replace_all_uses_with(arg) + match.erase_nodes() + counters["inductor"]["removed_pointless_view_pair"] += 1 + + +@register_graph_pattern( + CallFunction( + aten.permute.default, + CallFunction(aten.permute.default, KeywordArg("arg"), KeywordArg("perm1")), + KeywordArg("perm2"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, +) +def pointless_permute_pair(match: Match, arg, perm1, perm2): + rank = len(perm1) + assert len(perm2) == rank + + for i in range(rank): + if perm1[perm2[i]] != i: + return # bail out + node = match.output_node() + node.replace_all_uses_with(arg) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + aten.bmm, + Arg(), + Arg(), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, +) +def bmm_to_mm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + """Convert bmm to mm when batch size is 1""" + + def repl(a, b): + return torch.mm(a.squeeze(0), b.squeeze(0)).unsqueeze(0) + + if ( + check_device(mat1.meta["val"], mat2.meta["val"], get_gpu_type()) + and statically_known_true(mat1.meta["val"].shape[0] == 1) + and statically_known_true(mat2.meta["val"].shape[0] == 1) + ): + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [mat1, mat2]) + + +# When softmax is used with temperature or other scaling, we get the pattern +# +# scale(x) - scale(x).amax(dim, keepdim=True) +# +# which is expected to be at most zero, but we may end up with numerical +# discrepancies # between the recomputed values of scale(x) inside and out +# of the reduction, # depending on compiler optimizations, e.g. use of fma +# instructions. +# +# Here we replace it with the mathematically equivalent, +# +# scale(x - x.amax(dim, keepdim=True)) +# +# which is more stable as we only compute the scaling once. +# +# NOTE: This pattern must come after fused attention matching! + + +def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False): + # Allow matching inp * other and other * input + if reverse: + scaled = CallFunction( + linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE + ) + else: + scaled = CallFunction( + linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE + ) + if to_dtype: + scaled = CallFunction( + prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE + ) + amax = CallFunction( + aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim") + ) + return CallFunction(aten.sub.Tensor, scaled, amax) + + +def _other_is_broadcasted_in_dim(match): + # Check that the scaling factor is constant across the reduction dim, + # so scaling doesn't change which index corresponds to the maximum value + other = match.kwargs["other"] + if isinstance(other, (int, float)): + return True + + inp = match.kwargs["inp"] + if not all(isinstance(x, torch.fx.Node) for x in (inp, other)): + return False + + inp_example = inp.meta["val"] + other_example = other.meta["val"] + if isinstance(other_example, (torch.SymInt, torch.SymFloat)): + return True + + if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)): + return False + + inp_ndim = inp_example.ndim + other_shape = other_example.shape + if inp_ndim < len(other_shape): + return False + + # Pad other_shape to the same ndim as inp + other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape) + + dim = match.kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + + if any(d >= len(other_shape) for d in dim): + return False + + return all(statically_known_true(other_shape[d] == 1) for d in dim) + + +def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: int | float | torch.Tensor + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + # pyrefly: ignore [unsupported-operation] + return (inp - max_) * (sign * other) + + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [inp, other]) + + +for reverse, to_dtype in itertools.product((False, True), repeat=2): + register_graph_pattern( + _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(mul_softmax_pattern) + + +def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: int | float | torch.Tensor + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + # pyrefly: ignore [unsupported-operation] + return (inp - max_) / (sign * other) + + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [inp, other]) + + +for to_dtype in (False, True): + register_graph_pattern( + _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(div_softmax_pattern) + + +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, torch.fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedious. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_graph_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise operation in joint graph. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. + """ + from torch._inductor import metrics + + # pyrefly: ignore # bad-assignment + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + # Create a replacement that uses torch.where for the pointwise operation + def repl_fn(shape, background_val, dim, selector, val): + # Create a tensor of indices for the scatter dimension + length = shape[dim] + indices = torch.arange(length, device=selector.device, dtype=torch.int64) + + # Reshape indices to have size 'length' at dim, then broadcast + view_shape = [1] * len(shape) + view_shape[dim] = length + indices_view = indices.view(*view_shape) + + # Broadcast selector to match full tensor shape + selector_expanded = selector.expand(shape) + + # Create a mask for where to scatter + mask = selector_expanded == indices_view + + # Use torch.where to implement the scatter pointwise operation + return torch.where(mask, val, background_val) + + # replace the scatter operation with pointwise equivalent + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl_fn, [shape, background_val, dim, selector, val]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/memory_estimator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/memory_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..e887d4bf62c8e11196ac5b2740c0ef3c39e64def --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/memory_estimator.py @@ -0,0 +1,454 @@ +import itertools +import logging +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.fx as fx +from torch.fx.experimental.symbolic_shapes import hint_int +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map_only + + +log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class StorageKey: + storage: torch.UntypedStorage + device: torch.device + + def __hash__(self) -> int: + return self.storage._cdata + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StorageKey): + return False + return ( + self.storage._cdata == other.storage._cdata and self.device == other.device + ) + + +class GraphAliasTracker: + """ + Tracks storage allocation and usage relationships in an FX graph. + + Differentiates between: + - Fresh allocations: nodes that allocate new storage (not views/aliases) + - Uses: nodes that use a storage as input + """ + + def __init__(self, nodes: list[fx.Node]): + # Map from node to the fresh storages it allocates (not views/aliases) + self.node_to_fresh_allocations: dict[fx.Node, OrderedSet[StorageKey]] = {} + + # Map from storage to the node that originally allocated it + self.storage_to_allocator: dict[StorageKey, fx.Node] = {} + + # Map from node to all storages it uses as inputs + self.node_to_storage_uses: dict[fx.Node, OrderedSet[StorageKey]] = {} + + # Map from storage to all nodes that use it + self.storage_to_uses: dict[StorageKey, OrderedSet[fx.Node]] = defaultdict( + OrderedSet + ) + + # Map from storage to the last node that uses it + self.storage_to_last_user: dict[StorageKey, fx.Node] = {} + + # Map from node to storages that have their last use at that node + self.node_to_storages_last_used: dict[fx.Node, OrderedSet[StorageKey]] = ( + defaultdict(OrderedSet) + ) + + # Track all output storages for each node (for building usage graph) + self.node_to_output_storages: dict[fx.Node, OrderedSet[StorageKey]] = {} + + # First pass: build storage allocations and track uses + for node in nodes: + # Get output storages + output_storages = self._get_output_storages(node) + self.node_to_output_storages[node] = output_storages + + # Track fresh allocations + fresh_allocations: OrderedSet[StorageKey] = OrderedSet() + for storage_key in output_storages: + if storage_key not in self.storage_to_allocator: + self.storage_to_allocator[storage_key] = node + fresh_allocations.add(storage_key) + self.node_to_fresh_allocations[node] = fresh_allocations + + # Track input storage uses (safe because inputs were already processed) + input_storages = self._get_input_storages(node) + self.node_to_storage_uses[node] = input_storages + for storage_key in input_storages: + self.storage_to_uses[storage_key].add(node) + + # Second pass: find last users (iterate in reverse) + for node in reversed(nodes): + input_storages = self.node_to_storage_uses[node] + for storage_key in input_storages: + if storage_key not in self.storage_to_last_user: + self.storage_to_last_user[storage_key] = node + self.node_to_storages_last_used[node].add(storage_key) + + @staticmethod + def _get_output_storages(node: fx.Node) -> OrderedSet[StorageKey]: + """ + Get all storages from a node's outputs. + + Uses pytree to handle arbitrary nested structures. + """ + val = node.meta.get("val") + if val is None: + return OrderedSet() + + storages: OrderedSet[StorageKey] = OrderedSet() + + def collect_storage(tensor: torch._subclasses.FakeTensor) -> None: + storages.add(StorageKey(tensor.untyped_storage(), tensor.device)) + + # Use tree_map_only to handle FakeTensors in nested structures + tree_map_only(torch._subclasses.FakeTensor, collect_storage, val) + + return storages + + def _get_input_storages(self, node: fx.Node) -> OrderedSet[StorageKey]: + """ + Get all storages from a node's inputs. + """ + input_storages: OrderedSet[StorageKey] = OrderedSet() + + for input_node in node.all_input_nodes: + input_storages.update(self.node_to_output_storages[input_node]) + + return input_storages + + def get_fresh_allocations(self, node: fx.Node) -> OrderedSet[StorageKey]: + """Get all fresh storage allocations by this node (not views/aliases).""" + return self.node_to_fresh_allocations[node] + + def get_storage_uses(self, node: fx.Node) -> OrderedSet[StorageKey]: + """Get all storages that this node uses as inputs.""" + return self.node_to_storage_uses[node] + + def get_storages_last_used( + self, + node: fx.Node, + ) -> OrderedSet[StorageKey]: + """ + Get storages whose last use is at this node. + """ + return self.node_to_storages_last_used[node] + + +def _size_of_default(num_bytes: int | torch.SymInt) -> int: + return hint_int(num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback) + + +def device_filter(device: torch.device) -> bool: + return device.type != "cpu" + + +def build_memory_profile( + graph: fx.Graph, + is_releasable: Callable[[fx.Node], bool], + size_of: Callable[[int | torch.SymInt], int] | None = None, +) -> list[int]: + """ + Function to estimate the memory profile of an input FX graph. + + Args: + - graph (fx.Graph): The input FX graph for which the memory profile + is to be estimated. + - is_releasable (Callable[[fx.Node], bool]): A function that + determines if a node's memory can be released (e.g. primal nodes + cannot be released). + - size_of (Callable[[int | torch.SymInt], int]): A function that converts + byte counts (possibly symbolic) to concrete integers. + + Returns: + - List[int]: A list representing the memory profile over the execution + of the graph, where each entry corresponds to the memory usage at + a particular point in the execution. + """ + + size_of = size_of or _size_of_default + nodes = list(graph.nodes) + alias_info = GraphAliasTracker(nodes) + + # Build memory profile + current_memory = 0 + + for node in itertools.chain( + graph.find_nodes(op="placeholder"), graph.find_nodes(op="get_attr") + ): + for storage_key in alias_info.get_fresh_allocations(node): + if device_filter(storage_key.device): + current_memory += size_of(storage_key.storage.nbytes()) + + memory_profile = [current_memory] + + for node in nodes: + if node.op in ("placeholder", "get_attr", "output"): + continue + + # Process allocations + for storage_key in alias_info.get_fresh_allocations(node): + if device_filter(storage_key.device): + current_memory += size_of(storage_key.storage.nbytes()) + + memory_profile.append(current_memory) + + # Process deallocations + for storage_key in alias_info.get_storages_last_used(node): + allocator = alias_info.storage_to_allocator[storage_key] + if is_releasable(allocator): + if device_filter(storage_key.device): + current_memory -= size_of(storage_key.storage.nbytes()) + + memory_profile.append(current_memory) + + return memory_profile + + +def get_fwd_bwd_interactions( + fwd_graph: fx.Graph, + bwd_graph: fx.Graph, + size_of: Callable[[int | torch.SymInt], int] | None = None, +) -> tuple[int, OrderedSet[str]]: + """ + Analyze the interactions between the forward (fwd) and backward (bwd) graphs + to determine memory usage characteristics. + + Args: + - fwd_graph (fx.Graph): The forward graph representing the forward pass. + - bwd_graph (fx.Graph): The backward graph representing the backward pass. + - size_of (Callable[[int | torch.SymInt], int]): A function that converts + byte counts (possibly symbolic) to concrete integers. + + Returns: + - tuple[int, OrderedSet[str]]: A tuple containing: + 1. The baseline memory usage during the backward pass, accounting for + storages that persist from the forward pass (i.e., in fwd output but + not in bwd input). + 2. A set of node names whose storage cannot be released during the bwd pass. + These include nodes that use storage from primals or are in bwd input + but not in fwd output. + """ + + size_of = size_of or _size_of_default + + # Build alias info for forward graph + fwd_nodes = list(fwd_graph.nodes) + fwd_alias_info = GraphAliasTracker(fwd_nodes) + + # Identify storages allocated by primal placeholder nodes + primal_storages: OrderedSet[StorageKey] = OrderedSet() + for node in fwd_graph.find_nodes(op="placeholder"): + if node.name.startswith("primals"): + primal_storages.update(fwd_alias_info.get_fresh_allocations(node)) + + # Get storages in forward output + fwd_output_node = next(iter(reversed(fwd_graph.nodes)))[-1] + assert fwd_output_node.op == "output" + fwd_output_storages = fwd_alias_info.get_storage_uses(fwd_output_node) + + # Node names that should not be deleted during memory profile estimation of bwd_graph + do_not_delete: OrderedSet[str] = OrderedSet() + + # Collect all storages in backward inputs and identify nodes to not delete + bwd_input_storages: OrderedSet[StorageKey] = OrderedSet() + for node in bwd_graph.find_nodes(op="placeholder"): + node_storages = GraphAliasTracker._get_output_storages(node) + bwd_input_storages.update(node_storages) + + # Check if this node uses primal storage + if node_storages & primal_storages: + do_not_delete.add(node.name) + + # Check if this node's storages are not in forward outputs + # (meaning it's an external input to backward pass) + if not (node_storages & fwd_output_storages): + do_not_delete.add(node.name) + + # Calculate baseline memory: storages in fwd output but not in bwd input + # These storages persist throughout the backward pass + baseline_storages = fwd_output_storages - bwd_input_storages + bwd_baseline_memory = 0 + for storage_key in baseline_storages: + if storage_key.device.type != "cpu": + bwd_baseline_memory += size_of(storage_key.storage.nbytes()) + + return bwd_baseline_memory, do_not_delete + + +def _is_releasable(n: fx.Node) -> bool: + # Storages of primals cannot be released during fwd or bwd pass. + return not n.name.startswith("primals") + + +def get_peak_memory( + fwd_graph: fx.Graph, + bwd_graph: fx.Graph, +) -> int: + fwd_peak_memory = max(build_memory_profile(fwd_graph, _is_releasable)) + + bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions( + fwd_graph, + bwd_graph, + ) + + def _is_bwd_releasable(n: fx.Node) -> bool: + # Storages of nodes in bwd_do_not_delete cannot be released + # during the bwd pass. + return _is_releasable(n) and n.name not in bwd_do_not_delete + + bwd_peak_memory = bwd_baseline_memory + max( + build_memory_profile(bwd_graph, _is_bwd_releasable) + ) + return max( + fwd_peak_memory, + bwd_peak_memory, + ) + + +class MemoryTracker: + """ + Tracks memory usage for alternative scheduling orders of an FX graph. + + This class enables tracking memory usage as nodes are scheduled in a different + order than the original graph. + """ + + def __init__( + self, + graph: fx.Graph, + is_releasable: Callable[[fx.Node], bool] | None = None, + device_filter: Callable[[torch.device], bool] | None = None, + ): + """ + Initialize memory tracker for alternative scheduling of the given graph. + + Args: + graph: FX graph to track memory for under alternative scheduling + is_releaseable: do we consider this input to the graph to release memory + upon final use, or is allocated for the duration of the graph ? + by default, we assume all nodes but those that start with "primals" to be releasable + device_filter: Function to determine which devices to track (default: non-CPU) + """ + + self.graph = graph + self.nodes = list(graph.nodes) + self.device_filter = device_filter or (lambda device: device.type != "cpu") + self.scheduled: OrderedSet[fx.Node] = OrderedSet() + + # Memory tracking using GraphAliasTracker + self.alias_tracker = GraphAliasTracker(self.nodes) + self.current_live_storages: OrderedSet[StorageKey] = OrderedSet() + self.current_memory_bytes = 0 + self.is_releasable = _is_releasable if is_releasable is None else is_releasable + + # Initialize live storages with placeholders and get_attr nodes + for node in self.nodes: + if node.op in ("placeholder", "get_attr"): + fresh_allocations = self.alias_tracker.get_fresh_allocations(node) + for storage_key in fresh_allocations: + if self.device_filter(storage_key.device): + self.current_live_storages.add(storage_key) + self.current_memory_bytes += self._get_storage_size(storage_key) + + self.peak_memory = self.current_memory_bytes + + log.debug( + "Memory tracker initialized with initial memory: %d MB", + self.current_memory_bytes // (1024 * 1024), + ) + + def schedule_node(self, node: fx.Node) -> None: + """ + Schedule a node and update memory tracking for the new scheduling order. + + Args: + node: The node being scheduled (potentially out of original order) + """ + assert node not in self.scheduled, "should not schedule node twice" + self.scheduled.add(node) + self._update_memory_for_node(node) + + def get_current_memory_bytes(self) -> int: + """Get current live memory in bytes under the current scheduling.""" + return self.current_memory_bytes + + def _get_storage_size(self, storage_key: StorageKey) -> int: + """Get the size of a storage in bytes, handling symbolic shapes.""" + size_bytes = storage_key.storage.nbytes() + return hint_int( + size_bytes, fallback=torch._inductor.config.unbacked_symint_fallback + ) + + def _get_storages_freed_by_node(self, node: fx.Node) -> OrderedSet[StorageKey]: + """Get storages that would be freed if we schedule this node.""" + freed_storages: OrderedSet[StorageKey] = OrderedSet() + + input_storages = self.alias_tracker.get_storage_uses(node) + for storage_key in input_storages: + if not self.device_filter(storage_key.device): + continue + + # Invariant: if a node uses a storage, it must be live + assert storage_key in self.current_live_storages, ( + "all input storages should be currently allocated" + ) + + if not self.is_releasable( + self.alias_tracker.storage_to_allocator[storage_key] + ): + continue + + all_uses = self.alias_tracker.storage_to_uses[storage_key] + + # If no more unscheduled uses remain, the storage can be freed + if all(u in self.scheduled for u in all_uses): + freed_storages.add(storage_key) + + return freed_storages + + def _update_memory_for_node(self, node: fx.Node) -> None: + """Update memory tracking when a node is scheduled.""" + if node.op in ("placeholder", "get_attr", "output"): + return + + # Add fresh allocations + fresh_allocations = self.alias_tracker.get_fresh_allocations(node) + alloc_bytes = 0 + for storage_key in fresh_allocations: + if ( + self.device_filter(storage_key.device) + and storage_key not in self.current_live_storages + ): + size = self._get_storage_size(storage_key) + self.current_live_storages.add(storage_key) + self.current_memory_bytes += size + alloc_bytes += size + + self.peak_memory = max(self.current_memory_bytes, self.peak_memory) + + # Remove storages that are no longer used + storages_to_free = self._get_storages_freed_by_node(node) + freed_bytes = 0 + for storage_key in storages_to_free: + if storage_key in self.current_live_storages: + size = self._get_storage_size(storage_key) + self.current_live_storages.remove(storage_key) + self.current_memory_bytes -= size + freed_bytes += size + + log.debug( + "Scheduled %s: memory change %d allocs, %d frees, current memory: %d MB", + node.name, + len(fresh_allocations), + len(storages_to_free), + self.current_memory_bytes // (1024 * 1024), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc5503d4815b6a37d1e0daa9c5ffaad4498539f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -0,0 +1,1114 @@ +# mypy: allow-untyped-defs +import logging +import operator +from collections import defaultdict +from dataclasses import dataclass, field +from math import prod +from typing import Any, cast + +import torch +from torch.utils._ordered_set import OrderedSet + +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunction, + Ignored, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternExpr, + PatternMatcherPass, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +patterns = PatternMatcherPass() + + +def _is_last_dim(t: torch.Tensor, dim: int) -> bool: + return dim == t.ndim - 1 or dim == -1 + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: list[torch.fx.Node], target) -> list[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> OrderedSet[torch.fx.Node]: + ancestors = OrderedSet[torch.fx.Node]() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return OrderedSet(node for node in ancestors if node.op != "placeholder") + + +def _get_tensor(node: torch.fx.Node) -> torch.Tensor: + val = node.meta["val"] + assert isinstance(val, torch.Tensor) + return val + + +@dataclass +class _AllGatherMatch: + match: Match + shard_node: torch.fx.Node + ag_node: torch.fx.Node + res_node: torch.fx.Node + gather_dim: int + group_name: "torch.distributed.distributed_c10d.GroupName" + + def replace_with(self, new_node: torch.fx.Node) -> None: + self.res_node.replace_all_uses_with(new_node) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_all_gather_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def make_zero_dim_all_gather_pattern(shard): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + shard, + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim == 0 + zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard")) + + def make_all_gather_split_pattern(shard): + return CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + make_zero_dim_all_gather_pattern(shard), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ) + + def make_cat_pattern(splits): + return CallFunction( + aten.cat.default, + ListOf(splits), + KeywordArg("gather_dim"), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + non_zero_dim_all_gather_pattern = make_cat_pattern( + make_all_gather_split_pattern(KeywordArg("shard")), + ) + + # Match a zero-dim all-gather in which the data is transferred as uint8 and + # viewed back as the original dtype. + zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_zero_dim_all_gather_pattern( + KeywordArg("shard"), + ), + Ignored(), + ) + + # Match a non-zero dim all-gather in which the data is transferred as uint8 + # and viewed back as the original dtype. + non_zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_cat_pattern( + CallFunction( + aten.view.dtype, + make_all_gather_split_pattern( + KeywordArg("shard"), + ), + Ignored(), + ), + ), + Ignored(), + ) + + # If two patterns with the same res_node_target have the same suffix, the + # longer pattern should appear first in the list. + # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1) + # should appear before (2) in the list. + res_node_target_to_patterns = { + aten.cat.default: [ + (non_zero_dim_all_gather_pattern, 0), + ], + aten.view.dtype: [ + (non_zero_dim_type_erased_all_gather_pattern, 0), + (zero_dim_type_erased_all_gather_pattern, 0), + ], + c10d.wait_tensor.default: [ + (zero_dim_all_gather_pattern, 0), + ], + } + + # Match in reverse to ensure longer patterns is prioritized + all_gathers = [] + visited_ag_nodes = OrderedSet[torch.fx.Node]() + for node in reversed(graph.nodes): + for target, patterns in res_node_target_to_patterns.items(): + if node.target != target: + continue + for pattern, ag_node_idx in patterns: + match = pattern.match(node) + if not match: + continue + + assert isinstance(match, Match) + ag_node = match.nodes[ag_node_idx] + assert ag_node.target == c10d.all_gather_into_tensor.default + + if ag_node in visited_ag_nodes: + continue + visited_ag_nodes.add(ag_node) + + ag_match = _AllGatherMatch( + match=match, + shard_node=match.kwargs["shard"], + ag_node=ag_node, + res_node=node, + gather_dim=match.kwargs.get("gather_dim", 0), + group_name=match.kwargs["group_name"], + ) + all_gathers.append(ag_match) + + return list(reversed(all_gathers)) + + +@dataclass +class _ReduceScatterMatch: + match: Match + input_node: torch.fx.Node + reduce_scatter_node: torch.fx.Node + wait_tensor_node: torch.fx.Node + reduce_op: str + scatter_dim: int + group_name: "torch.distributed.distributed_c10d.GroupName" + + def replace_with(self, new_node: torch.fx.Node) -> None: + # Replace all uses of the result node (wait_tensor) with the fused node. + self.wait_tensor_node.replace_all_uses_with(new_node) + + # If the reduce-scatter result is saved for backward, save the fused node for backward instead. + self._update_save_for_backward(new_node) + + def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: + """ + If the output node is a user of the reduce_scatter node (indicating the reduce_scatter + result is saved for backward), this method will update the output node to use the fused node instead. + """ + output_node = None + for user in self.reduce_scatter_node.users: + if user.target == "output": + output_node = user + break + if output_node is not None: + output_node.replace_input_with(self.reduce_scatter_node, new_node) + + # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not + # saved for backward anymore. + assert len(self.reduce_scatter_node.users) == 1, ( + "Reduce scatter node has multiple users, this is not expected" + ) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_reduce_scatter_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def reduce_scatter_template(inp: PatternExpr, users: int): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + inp, + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + _users=users, + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + KeywordArg("input"), users=1 + ) + + # Two users will occur when the reduce-scatter result is saved for backward + zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + KeywordArg("input"), users=2 + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=1, + ) + + # Two users will occur when the reduce-scatter result is saved for backward + non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=2, + ) + + reduce_scatters = [] + for node in reversed(graph.nodes): + if node.target == c10d.wait_tensor.default: + if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + return list(reversed(reduce_scatters)) + + +@dataclass +class _Matmul: + nodes: list[torch.fx.Node] + arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False) + A_node: torch.fx.Node + B_node: torch.fx.Node + pre_mm_reshape: torch.fx.Node | None + post_mm_reshape: torch.fx.Node | None + + def __post_init__(self): + assert len(self.nodes) in (1, 3) + if len(self.nodes) == 1: + assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default) + else: + assert self.nodes[0].target is aten.reshape.default + assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) + assert self.nodes[2].target is aten.reshape.default + self.arg_ancestor_nodes = _find_ancestors(self.B_node) + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + graph = new_node.graph + + # For 2D-matmuls, we simply replace the mm node with `new_node`. + if len(self.nodes) == 1: + mm_node = self.nodes[0] + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + mm_node.replace_all_uses_with(new_node) + graph.erase_node(mm_node) + return + + # An ND-matmul is reshape -> mm -> reshape sequence. We first replace + # the second reshape node with `new_node`. Then, we ensure that the + # original mm node in the sequence ends up with zero users by replacing + # it with a reverse reshape of `new_node`. + graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + assert output_reshape_node.target is aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(_get_tensor(mm_node).shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + def erase(self) -> None: + for node in reversed(self.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_Matmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten.mm.default, + aten.reshape.default, + ) + mm_node = match[0] if len(match) == 1 else match[1] + return _Matmul( + nodes=match, + A_node=cast("torch.fx.Node", match[0].args[0]), + B_node=cast("torch.fx.Node", mm_node.args[1]), + # _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method. + # TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes. + pre_mm_reshape=None, + post_mm_reshape=None, + ) + + +@dataclass +class _ScaledMatmul(_Matmul): + A_scale_node: torch.fx.Node + B_scale_node: torch.fx.Node + bias_node: torch.fx.Node | None + result_scale_node: torch.fx.Node | None + out_dtype: torch.dtype | None + use_fast_accum: bool + pre_mm_reshape: torch.fx.Node | None + post_mm_reshape: torch.fx.Node | None + + def __post_init__(self): + super().__post_init__() + self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node) + self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_ScaledMatmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten._scaled_mm.default, + aten.reshape.default, + ) + + def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any: + if idx >= len(node.args): + return default + return node.args[idx] + + # Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern. + # We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to + # produce the correct output shapes, reduce-scatter along the correct dimensions, etc. + is_reshape_mm_reshape_pattern = match[0].target is aten.reshape.default + mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0] + pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None + post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None + A_node = cast("torch.fx.Node", mm_node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + A_scale_node = cast("torch.fx.Node", mm_node.args[2]) + B_scale_node = cast("torch.fx.Node", mm_node.args[3]) + + return _ScaledMatmul( + nodes=match, + A_node=A_node, + B_node=B_node, + A_scale_node=A_scale_node, + B_scale_node=B_scale_node, + bias_node=get_arg(mm_node, 4, None), + result_scale_node=get_arg(mm_node, 5, None), + out_dtype=get_arg(mm_node, 6, None), + use_fast_accum=get_arg(mm_node, 7, False), + pre_mm_reshape=pre_mm_reshape, + post_mm_reshape=post_mm_reshape, + ) + + +def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]: + if node.target != aten.reshape.default: + return [] + + matches = [] + for mm_node in node.users: + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + continue + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + # Since the reshape -> mm -> reshape pattern would be subsumed into + # the fused op, we only match the patterns where the shape of the + # second reshape is matches the mm result produced by the fused op. + matmul_input_node = cast("torch.fx.Node", node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + matmul_out_shape = torch.Size( + [ + *_get_tensor(matmul_input_node).shape[:-1], + _get_tensor(B_node).shape[-1], + ] + ) + if _get_tensor(reshape_node).shape != matmul_out_shape: + continue + matches.append([node, mm_node, reshape_node]) + # If for some rare reason mm_node is being reshaped by two + # different reshape nodes, we only include mm_node once in the + # parsing result. + break + + matmuls = [] + for match in matches: + mm_node = match[1] + if mm_node.target is aten.mm.default: + matmul = _Matmul.from_match(match) + matmuls.append(matmul) + elif mm_node.target is aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match(match) + matmuls.append(matmul) + else: + raise AssertionError( + "Expect the node's target to be either aten.mm.default or " + f"aten._scaled_mm.default. Got {mm_node.target}." + ) + return matmuls + + +def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]: + """ + Find the matmuls that use `node` as the lhs argument. + """ + matmuls = [] + for user in node.users: + # ND matmuls + if user.target is aten.reshape.default: + matmuls.extend(_find_reshape_mm_reshape(user)) + # 2D matmuls + elif user.target is aten.mm.default: + matmul = _Matmul.from_match(match=[user]) + matmuls.append(matmul) + elif user.target is aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match([user]) + matmuls.append(matmul) + return matmuls + + +def _insert_fused_all_gather_matmul( + graph: torch.fx.Graph, + matmuls: list[_Matmul], + shard_node: torch.fx.Node, + gather_dim: int, + group_name: "torch.distributed.distributed_c10d.GroupName", +) -> torch.fx.Node: + mm_types = OrderedSet(map(type, matmuls)) + assert len(mm_types) == 1 + mm_type = next(iter(mm_types)) + if mm_type == _Matmul: + B_nodes = [matmul.B_node for matmul in matmuls] + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + kwargs={"return_A": True}, + ) + elif mm_type == _ScaledMatmul: + scaled_matmuls = cast("list[_ScaledMatmul]", matmuls) + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_scaled_matmul.default, + args=( + shard_node, + [matmul.B_node for matmul in scaled_matmuls], + scaled_matmuls[0].A_scale_node, + [matmul.B_scale_node for matmul in scaled_matmuls], + gather_dim, + group_name, + [matmul.bias_node for matmul in scaled_matmuls], + [matmul.result_scale_node for matmul in scaled_matmuls], + [matmul.out_dtype for matmul in scaled_matmuls], + [matmul.use_fast_accum for matmul in scaled_matmuls], + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {mm_type}") + + +def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None: + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... + + into + + A, Cs = torch.ops.symm_mem.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_shard_for_fused_all_gather_matmul, + ) + + shard_node, ag_node, ag_res_node, gather_dim, group_name = ( + all_gather.shard_node, + all_gather.ag_node, + all_gather.res_node, + all_gather.gather_dim, + all_gather.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + filter_matmul = None + if _is_last_dim(_get_tensor(shard_node), gather_dim): + # Decomposed mms should not be too small + if _get_tensor(shard_node).shape[-1] < 1024: + return + + # scaled_mm is not supported yet for last dim + def _filter_out_scaled_matmul(matmul: _Matmul): + return not isinstance(matmul, _ScaledMatmul) + + filter_matmul = _filter_out_scaled_matmul + + # Find consumer matmuls + matmuls = _find_consumer_matmuls(ag_res_node) + + # The matmuls are only fusible if non-A args don't depend on the all-gather + # result node + matmuls = [ + matmul + for matmul in matmuls + if all_gather.res_node not in matmul.arg_ancestor_nodes + ] + + if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1: + return + + if _is_last_dim(_get_tensor(shard_node), gather_dim) and len( + all_gather.res_node.users + ) > len(matmuls): + # The result of ag-split-cat is used not only in matmuls. + # Then it has to be materialized, which can have overhead. + return + + if filter_matmul and not filter_matmul(matmuls[0]): + return + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + if not _is_last_dim(_get_tensor(shard_node), gather_dim): + if "val" in shard_node.meta: + restrided = restride_A_shard_for_fused_all_gather_matmul( + _get_tensor(shard_node), + gather_dim, + ) + shard_node = graph.call_function( + inductor_prims.force_stride_order, + args=(shard_node, restrided.stride()), + ) + + fused_node = _insert_fused_all_gather_matmul( + graph, matmuls, shard_node, gather_dim, group_name + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes, idx), + ) + matmul.replace_with(new_out_node) + matmul.erase() + all_gather.replace_with(new_ag_node) + all_gather.erase() + + # If the new_ag_node has no users, we tell the fused op to not return + # it. This creates more optimization opportunities. + if len(new_ag_node.users) == 0: + graph.erase_node(new_ag_node) + kwargs = dict(fused_node.kwargs) + if "return_A" in kwargs: + kwargs["return_A"] = False + fused_node.kwargs = kwargs + + # Raise ancestors of non-A args that are topologically ordered between + # ag_res_node and the matmul above fused_node. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + OrderedSet(x for matmul in matmuls for x in matmul.arg_ancestor_nodes), + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + +def _scatter_dim_after_reshape( + reshape_node: torch.fx.Node, orig_scatter_dim: int +) -> int: + """ + Given a reshape node and the original scatter dim for the target tensor, + returns the new scatter dim for the reshaped tensor. + """ + # if there was no pre-mm reshape, scatter dim will not change. + if not reshape_node: + return orig_scatter_dim + + reshape_op_output_tensor = _get_tensor(reshape_node) + assert reshape_op_output_tensor.ndim == 2, ( + "reshape must produce 2D tensor for scaled_mm" + ) + + assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg" + input_tensor_node = cast(torch.fx.Node, reshape_node.args[0]) + reshape_op_input_tensor = _get_tensor(input_tensor_node) + assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, ( + "reshape must be from 3D+ to 2D" + ) + + # Note: for a N-D tensor to be reshaped into 2D, either the leading dims or ending dims must + # be collapsed to a single dim. First determine which of these happened. + input_shape = reshape_op_input_tensor.shape + output_shape = reshape_op_output_tensor.shape + leading_dims_collapsed = output_shape[0] == prod(input_shape[:-1]) + + # Case 1: scatter dim 0 always maps to 0 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == 0: + return 0 + + # Case 2: scatter dim "ndim-1" always maps to 1 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == reshape_op_input_tensor.ndim - 1: + return 1 + + # Case 3: scatter dim was one of the middle dims (between 0 and ndim-1). + # if the leading dims were collapsed, the new scatter dim will be 0. + # if the ending dims were collapsed, the new scatter dim will be 1. + return 0 if leading_dims_collapsed else 1 + + +def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None: + """ + Returns producer matmul node if found, otherwise returns None. + """ + if node.target is aten.mm.default: + return _Matmul.from_match(match=[node]) + elif node.target is aten._scaled_mm.default: + return _ScaledMatmul.from_match(match=[node]) + elif node.target is aten.reshape.default: + reshape_node_1 = node + + mm_node = reshape_node_1.args[0] + assert isinstance(mm_node, torch.fx.Node) + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + return None + + reshape_node_0 = mm_node.args[0] + assert isinstance(reshape_node_0, torch.fx.Node) + if reshape_node_0.target != aten.reshape.default: + return None + + if mm_node.target is aten.mm.default: + return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1]) + elif mm_node.target is aten._scaled_mm.default: + return _ScaledMatmul.from_match( + match=[reshape_node_0, mm_node, reshape_node_1] + ) + return None + + +def _insert_fused_matmul_reduce_scatter( + graph: torch.fx.Graph, + matmul: _Matmul, + reduce_op: str, + orig_scatter_dim: int, + group_name: "torch.distributed.distributed_c10d.GroupName", + scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern + output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern +) -> torch.fx.Node: + if type(matmul) is _Matmul: + return graph.call_function( + torch.ops.symm_mem.fused_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + reduce_op, + orig_scatter_dim, + group_name, + ), + ) + elif type(matmul) is _ScaledMatmul: + return graph.call_function( + torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + matmul.A_scale_node, + matmul.B_scale_node, + reduce_op, + orig_scatter_dim, + scatter_dim_after_reshape, + group_name, + output_shape, + matmul.bias_node, + matmul.result_scale_node, + matmul.out_dtype, + matmul.use_fast_accum, + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {type(matmul)}") + + +def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: + """ + Fused the pattern + + reduce_scatter_tensor(A @ B, scatter_dim, group_name) + + into + + torch.ops.symm_mem.fused_matmul_reduce_scatter( + A, B, scatter_dim, group_name, + ) + + Returns boolean indicating if fusion was successful or not. + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_for_fused_matmul_reduce_scatter, + ) + + ( + input_node, + _reduce_scatter_node, + rs_wait_tensor_node, + reduce_op, + orig_scatter_dim, + group_name, + ) = ( + reduce_scatter.input_node, + reduce_scatter.reduce_scatter_node, + reduce_scatter.wait_tensor_node, + reduce_scatter.reduce_op, + reduce_scatter.scatter_dim, + reduce_scatter.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + filter_matmul = None + if _is_last_dim(_get_tensor(input_node), orig_scatter_dim): + # scaled_mm is not supported yet for last dim mm+rs + def _filter_out_scaled_matmul(matmul: _Matmul): + return not isinstance(matmul, _ScaledMatmul) + + filter_matmul = _filter_out_scaled_matmul + + # Currently fused_matmul_reduce_scatter doesn't return the matmul result, + # so we can't apply the fusion if the matmul result is used by multiple + # users. This is not a fundamental limitation of the fused op and can be + # addressed if needed. + if len(input_node.users) != 1: + log.warning( + "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion." + ) + return + + matmul = _find_producer_matmul(input_node) + + if matmul is None: + log.warning( + "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion" + ) + return + + if filter_matmul and not filter_matmul(matmul): + return + + if rs_wait_tensor_node in matmul.arg_ancestor_nodes: + log.warning( + "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion" + ) + return + + # We need to track 3 values for the fused scaled mm reduce scatter implementation: + # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims. + # 2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a,b,d) scaled mm op. + # 3. Store expected potentially 3D+ mm output shape, so we can reshape the 2D mm output to the intended + # 3D+ shape before applying reduce-scatter, and to prevent shape errors with subsequent ops. + + # If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scattter dim after the reshape + # for the fused matmul reduce scatter implementation to use. + if matmul.pre_mm_reshape: + scatter_dim_after_maybe_reshape = _scatter_dim_after_reshape( + matmul.pre_mm_reshape, orig_scatter_dim + ) + else: + scatter_dim_after_maybe_reshape = orig_scatter_dim + + # If the 2D mm output was reshaped from 2D -> 3D+, we need to store the intended output shape for the + # fused matmul reduce scatter implementation to use. + if matmul.post_mm_reshape: + output_shape = list(_get_tensor(matmul.post_mm_reshape).shape) + else: + A_orig_shape = list(_get_tensor(matmul.A_node).shape) + B_shape = list(_get_tensor(matmul.B_node).shape) + output_shape = [*A_orig_shape[:-1], B_shape[-1]] + + graph = rs_wait_tensor_node.graph + with graph.inserting_before(rs_wait_tensor_node): + # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter + if "val" in matmul.A_node.meta: + restrided = restride_A_for_fused_matmul_reduce_scatter( + _get_tensor(matmul.A_node), + scatter_dim_after_maybe_reshape, + ) + matmul.A_node = graph.call_function( + inductor_prims.force_stride_order, + args=(matmul.A_node, restrided.stride()), + ) + + # Replace matched subgraph with fused matmul reduce scatter node + fused_node = _insert_fused_matmul_reduce_scatter( + graph, + matmul, + reduce_op, + orig_scatter_dim, + group_name, + scatter_dim_after_maybe_reshape, + output_shape, + ) + reduce_scatter.replace_with(fused_node) + reduce_scatter.erase() + matmul.erase() + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + matmul.arg_ancestor_nodes, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + log.debug("successfully fused matmul reduce scatter") + + +def _get_node_to_ancestors( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: + """ + Compute the ancestors for all nodes in a graph. + """ + node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated] + for node in graph.nodes: + node_to_ancestors[node] = OrderedSet(node.all_input_nodes) + for dep in node.all_input_nodes: + node_to_ancestors[node] |= node_to_ancestors[dep] + + return node_to_ancestors + + +def _get_collective_to_overlappable_nodes( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, list[torch.fx.Node]]: + """ + For each collective in the graph, find nodes that are neither ancestors nor + descendants of the collective. + """ + + def is_collective(node) -> bool: + # Only consider all-gather and reduce-scatter in the context of + # micro-pipeline TP. + return node.target in [ + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + + node_to_ancestors = _get_node_to_ancestors(graph) + collective_to_overlappable_nodes = defaultdict(list) + for node in graph.nodes: + if not is_collective(node): + continue + for x in graph.nodes: + if ( + node not in node_to_ancestors[x] + and x not in node_to_ancestors[node] + and x.op == "call_function" + ): + collective_to_overlappable_nodes[node].append(x) + + return collective_to_overlappable_nodes + + +def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]: + """ + Find all unexposed collectives in the graph. + + Because we don't have the runtime estimate, this function is a rough + estimation using the following strong/hand-wavy assumptions: + + - Only a predefined set of "compute intensive" operation can hide a collective. + - Any "compute intensive" operation can hide exactly one collective. + """ + + def _is_compute_intensive(node: torch.fx.Node) -> bool: + return node.target is torch.ops.aten.mm.default + + collective_to_overlapping_candidates = defaultdict(list) + available_nodes = OrderedSet[torch.fx.Node]() + collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph) + for collective, overlappable_nodes in collective_to_overlappable_nodes.items(): + candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)] + collective_to_overlapping_candidates[collective] = candidates + available_nodes.update(candidates) + + unexposed_collectives = [] + for ( + collective, + overlapping_candidates, + ) in collective_to_overlapping_candidates.items(): + # Each collective consumes exactly one overlapping candidate + for x in overlapping_candidates: + if x in available_nodes: + unexposed_collectives.append(collective) + available_nodes.remove(x) + break + return unexposed_collectives + + +def micro_pipeline_tp_pass(graph: torch.fx.Graph): + all_gathers = find_all_gather_patterns(graph) + reduce_scatters = find_reduce_scatter_patterns(graph) + + # When a collective can be hidden through either simple overlapping or + # micro-pipeline TP, we prefer simple overlapping to avoid the overhead + # associated with decomposition. If reorder_for_compute_comm_overlap is + # enabled, we identify collectives that can be hidden through simple + # overlapping and exclude them from micro-pipeline TP candidates. + if config.reorder_for_compute_comm_overlap: + unexposed_collectives = _get_unexposed_collectives(graph) + all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] + reduce_scatters = [ + x + for x in reduce_scatters + if x.reduce_scatter_node not in unexposed_collectives + ] + + if not all_gathers and not reduce_scatters: + log.warning( + "async TP found no matching all-gather/reduce-scatter patterns for fusion" + ) + + for all_gather in all_gathers: + fuse_all_gather_matmul(all_gather) + + for reduce_scatter in reduce_scatters: + fuse_matmul_reduce_scatter(reduce_scatter) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0981e72e8b2f1e4f4d618c7bcc4dc0afa970c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,139 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._dynamo.utils import counters +from torch._ops import OpOverload, OpOverloadPacket +from torch.utils._ordered_set import OrderedSet + +from ..pattern_matcher import fwd_only, register_replacement + + +aten = torch.ops.aten + + +@functools.cache +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + # pyrefly: ignore [bad-argument-type] + randperm_index_add_pattern, + # pyrefly: ignore [bad-argument-type] + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + register_replacement( + # pyrefly: ignore [bad-argument-type] + randperm_index_pattern, + # pyrefly: ignore [bad-argument-type] + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) + + +class NumpyCompatNormalization: + numpy_compat: dict[str, tuple[str, ...]] = { + "dim": ("axis",), + "keepdim": ("keepdims",), + "input": ("x", "a", "x1"), + "other": ("x2",), + } + inverse_mapping: dict[str, str] + cache: dict["torch.fx.graph.Target", OrderedSet[str]] + + def __init__(self) -> None: + self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"] + self.inverse_mapping = {} + for actual_kwarg, numpy_kwargs in self.numpy_compat.items(): + for numpy_kwarg in numpy_kwargs: + assert numpy_kwarg not in self.inverse_mapping + self.inverse_mapping[numpy_kwarg] = actual_kwarg + + def __call__(self, graph: torch.fx.Graph): + for node in graph.nodes: + if node.op != "call_function": + continue + if isinstance(node.target, (OpOverload, OpOverloadPacket)): + # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't. + continue + kwargs = node.kwargs + + if node.target in self.cache: + replaceable_kwargs = self.cache[node.target] + else: + signatures = torch.fx.operator_schemas.get_signature_for_torch_op( + node.target + ) + signatures = () if signatures is None else signatures + replaceable_kwargs = OrderedSet() + for sig in signatures: + for param_name in sig.parameters: + if param_name in self.numpy_compat: + replaceable_kwargs.update(self.numpy_compat[param_name]) + + self.cache[node.target] = replaceable_kwargs + + if not replaceable_kwargs: + continue + + new_kwargs = {} + kwargs_changed = False + for k, v in kwargs.items(): + if k in replaceable_kwargs: + kwargs_changed = True + new_kwargs[self.inverse_mapping[k]] = v + else: + new_kwargs[k] = v + + if kwargs_changed: + node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs) + counters["inductor"]["numpy_compat_normalization"] += 1 + + +numpy_compat_normalization = NumpyCompatNormalization() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8f729596cbb1f180d377a8e895b3a2fe12c8e1be --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -0,0 +1,1585 @@ +# mypy: allow-untyped-defs +import functools +import operator +from functools import reduce +from typing import Any, TYPE_CHECKING + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.utils._ordered_set import OrderedSet + +from .. import ir, mkldnn_ir +from ..lowering import lowerings as L +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + get_arg_value, + KeywordArg, + MULTIPLE, +) +from ..utils import ( + is_mkldnn_bf16_supported, + is_mkldnn_fp16_supported, + SUPPORTED_MKLDNN_DEVICES, +) +from ..virtualized import ops, V +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern +from .quantization import ( + _register_int8_woq_concat_linear_pattern, + _register_quantization_lowerings, + _register_quantization_weight_pack_pass, + _register_woq_lowerings, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + + +if torch._C._has_mkldnn: + aten = torch.ops.aten + mkldnn = torch.ops.mkldnn + prims = torch.ops.prims + + _conv_args = [Arg() for _ in range(10)] + _linear_args = [Arg() for _ in range(6)] + _conv_transpose_args = [Arg() for _ in range(11)] + + class MkldnnDeviceOpBase: + def get_linear_transpose_weight(self, weight_node): + raise NotImplementedError + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + raise NotImplementedError + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + raise NotImplementedError + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + raise NotImplementedError + + class CpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def get_linear_transpose_weight(self, weight_node): + packed_weight_node = weight_node + assert packed_weight_node.target == mkldnn._reorder_linear_weight + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.target is aten.permute.default + return transpose_weight_node + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + packed_weight_op = mkldnn._reorder_convolution_weight + if is_transposed: + packed_weight_op = mkldnn._reorder_convolution_transpose_weight + + # mkldnn_reorder_conv_weight(self, padding, stride, dilation, groups, input_size) + packed_weight_inputs = (weight,) + tuple(constant_args) + (input_size,) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + + # MKL packed matrix can't be copied to a different address because the internal implementation + # depends on the alignment of internally-stored metadata. + # In aot mode, we need to firstly save the packed weight, when loading it, + # it will be in a different address which doesn't work. + # Disable MKL prepack linear in AOT mode. + # Disable MKL prepack linear when batch_size has free symbols. + packed_weight_op = ( + mkldnn._reorder_linear_weight + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation + or has_free_symbols(batch_size) + ) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node) + transpose_weight_node = packed_weight_node.args[0] + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation + or has_free_symbols(batch_size) + ): + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + + return graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + + class XpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + assert not is_transposed, ( + "'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." + ) + return weight + + def _get_mkldnn_device_op(device_type: str) -> MkldnnDeviceOpBase: + """ + Returns the MKLDNN device operation class based on the current device type. + """ + if device_type == "cpu": + return CpuMkldnnDeviceOp() + elif device_type == "xpu": + return XpuMkldnnDeviceOp() + else: + raise RuntimeError(f"MKLDNN is not supported on {device_type} device.") + + def _is_valid_grouped_gemm_fusion(computation_nodes): + """ + Here we check: + 1. More than 1 GEMM nodes has been found. + 2. All the GEMM nodes share the same activation. + 3. All the GEMM nodes have same weight size but different wgt node. + """ + computation_op = mkldnn._linear_pointwise.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + wgt_size = wgt.meta.get("val").size() # type: ignore[union-attr] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act + and (node.args[1].meta.get("val").size() == wgt_size) + and (node.args[1] != wgt or gemm_idx == 0) + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + def grouped_gemm_pass(graph: torch.fx.Graph): + """ + Group GEMM has multi output nodes which is complicated to define a Pattern. + Use below way to connect the pattern to the lowering. + TODO: Use MultiOutputPattern, current limitation is the pattern requires + fixed number of output nodes. Extend to support Group GEMM for pattern matcher. + """ + computation_op = mkldnn._linear_pointwise.default + from ..mkldnn_lowerings import grouped_gemm_lowering + + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_grouped_gemm_fusion(users): + with graph.inserting_before(node): + grouped_gemm_node = graph.create_node( + "call_function", + grouped_gemm_lowering, + ( + act, + [user.args[1] for user in users], + [user.args[2] for user in users], + ), + ) + grouped_gemm_node.meta["val"] = [ + user.meta["val"] for user in users + ] + with graph.inserting_after(grouped_gemm_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + grouped_gemm_node, + gemm_idx, + ), + ) + user.replace_all_uses_with(get_item) + graph.erase_node(user) + return + + def _conv_call(users=1): + return CallFunction( + mkldnn._convolution_pointwise.default, *_conv_args, _users=users + ) + + def _linear_call(users=1): + return CallFunction( + mkldnn._linear_pointwise.default, *_linear_args, _users=users + ) + + def _conv_transpose_call(users=1): + return CallFunction( + mkldnn._convolution_transpose_pointwise.default, + *_conv_transpose_args, + _users=users, + ) + + def _to_float(input_call, users=1): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_float"), + _users=users, + ) + + def _to_bf16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_bf16"), + _users=1, + ) + + def _to_fp16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_fp16"), + _users=1, + ) + + def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): + # only insert to_dtype if lowp_dtype is True + computation_call = ( + _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) + ) + out = unary_fusion(computation_call) + if lowp_dtype == torch.bfloat16: + return _to_bf16(out) + elif lowp_dtype == torch.float16: + return _to_fp16(out) + else: + return out + + def _gelu_fusion_1(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.erf, + CallFunction(aten.mul, computation_call, 0.7071067811865476), + ), + 1, + ), + ) + + def _gelu_fusion_2(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.tanh, + CallFunction( + aten.mul, + CallFunction( + aten.add, + computation_call, + CallFunction( + aten.mul, + CallFunction( + aten.mul, + CallFunction( + aten.mul, computation_call, computation_call + ), + computation_call, + ), + 0.044715, + ), + ), + 0.7978845608028654, + ), + ), + 1, + ), + ) + + def _hardswish_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + ), + 6, + ) + + def _silu_fusion(computation_call): + return CallFunction( + aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call) + ) + + def _hardsigmoid_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + 6, + ) + + def _leaky_relu_fusion(computation_call): + return CallFunction( + aten.where, + CallFunction(aten.gt, computation_call, 0), + computation_call, + CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), + ) + + def _hardtanh_fusion(computation_call): + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + + def _combined_fusion(computation_call, elementwise_op): + return CallFunction(elementwise_op, computation_call) + + # binary_op(other, computation_op) + def _binary_fusion_v1(computation_call, binary_fn): + return CallFunction(binary_fn, KeywordArg("other"), computation_call) + + # binary_op(computation_op, other) + def _binary_fusion_v2(computation_call, binary_fn): + return CallFunction(binary_fn, computation_call, KeywordArg("other")) + + def _is_single_computation_op(computation_op, lowp_dtype=None): + def fn(match): + computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + + if len(computation_nodes) < 1: + return False + if any(n.args[-3] != "none" for n in computation_nodes): + return False + return True + + return fn + + def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): + def fn(match): + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) + computation_node = filter_nodes(match.nodes, computation_op)[0] + if lowp_dtype: + conversion_dtype_nodes = filter_nodes( + match.nodes, prims.convert_element_type.default + ) + if len(conversion_dtype_nodes) != 2: + return False + # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 + if computation_node == conversion_dtype_nodes[0].args[0]: + to_float = conversion_dtype_nodes[0].args[1] + to_lp = conversion_dtype_nodes[1].args[1] + else: + to_float = conversion_dtype_nodes[1].args[1] + to_lp = conversion_dtype_nodes[0].args[1] + matched = matched and to_float == torch.float and to_lp == lowp_dtype + return matched + + return fn + + def _register_unary_fusion_lowering( + pattern, unary_attr, computation_op, lowp_dtype=None + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), + ) + def fn(match, *args, **kwargs): + computation_args = list(args)[:-3] + [ + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + return L[computation_op](*computation_args) + + return fn + + def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + negative_slope = kwargs.get("negative_slope") + if isinstance(negative_slope, ir.TensorBox): + matched = False + else: # inp is a Number + matched = True + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + if matched: + computation_args = computation_args[:-3] + [ + "leaky_relu", + [negative_slope], + "", + ] + return L[computation_op](*computation_args) + else: + # computation_args += ["none", [], ""] + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.where]( + L[aten.gt](out, 0), + out, + L[aten.mul](out, negative_slope), + ) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + if isinstance(min_value, ir.TensorBox) or isinstance( + max_value, ir.TensorBox + ): + matched = False + else: # inp is a Number + assert max_value is not None + matched = min_value <= max_value + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + if matched: + computation_args = computation_args[:-3] + [ + "hardtanh", + [min_value, max_value], + "", + ] + return L[computation_op](*computation_args) + else: + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + _binary_attr = { + aten.add: "add", + ops.add: "add", + aten.sub: "sub", + ops.sub: "sub", + } + + def _is_valid_binary(match, computation_op, binary_op): + binary_nodes = filter_nodes(match.nodes, binary_op) + if len(binary_nodes) < 1: + return False + + def get_meta_value(argument: torch.fx.node.Argument): + # Only torch.fx.Node is expected to have meta. + if isinstance(argument, torch.fx.Node): + return argument.meta.get("val", None) + return None + + if any( + not isinstance(get_meta_value(n.args[0]), torch.Tensor) + or not isinstance(get_meta_value(n.args[1]), torch.Tensor) + for n in binary_nodes + ): + return False + # check alpha is one. + if any( + get_arg_value(n, 2, kwarg_name="alpha") != 1.0 + and get_arg_value(n, 2, kwarg_name="alpha") is not None + for n in binary_nodes + ): + return False + + def _check_input_sizes(n, computation_op): + # Check if the tensor shape of the 'other' node is the same as or + # can be broadcasted to the tensor shape of the computation node. + computation_node = ( + n.args[0] if n.args[1] is match.kwargs["other"] else n.args[1] + ) + assert computation_node.target == computation_op + computation_node_size = get_meta_value(computation_node).size() + if computation_op is mkldnn._linear_pointwise.default: + broadcast_sizes = [] + if len(computation_node_size) >= 2: + broadcast_sizes = [ + torch.Size( + [1 for _ in range(len(computation_node_size) - 1)] + + [computation_node_size[-1]] + ), + ] + else: + assert len(computation_node_size) > 2 + broadcast_sizes = [ + torch.Size( + [computation_node_size[0], computation_node_size[1]] + + [1 for _ in range(len(computation_node_size) - 2)] + ), + torch.Size( + [1, computation_node_size[1]] + + [1 for _ in range(len(computation_node_size) - 2)] + ), + torch.Size([1 for _ in range(len(computation_node_size))]), + ] + return ( + get_meta_value(match.kwargs["other"]).size() + in [ + computation_node_size, + ] + + broadcast_sizes + ) + + if any( + not _check_input_sizes(n, computation_op) + or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device + or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype + for n in binary_nodes + ): + return False + # check args[0] and args[1] is not same + if any(n.args[0] == n.args[1] for n in binary_nodes): + return False + return True + + def _is_valid_computation_binary(computation_op, binary_op, other_index=None): + def fn(match): + if not _is_single_computation_op(computation_op)(match): + return False + if not _is_valid_binary(match, computation_op, binary_op): + return False + return True + + return fn + + def _get_remaining_users(extra_input_node, compute_node): + # Think about this pattern: + # ReLU + # / \ + # Conv1 + # / \ + # Conv2 + # \ / + # Add + # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. + # The Conv1 is the ancestor node of the current compute node (Conv2). + # This indicates that the buffer of ReLU has completed all its usage, + # So we can safely make changes to it now by doing Conv2->Add inplace fusion. + # Take above case as example: + # * extra_input_node: ReLU + # * compute_node: Conv2 + # _get_remaining_users will return the users of extra_input_node which are not + # ancestor node of compute_node. + def _is_ancestor_node(_current_node, _ancestor_node): + # Check whether _ancestor_node is the ancestor node of _current_node + _node_list = [_current_node] + _visited_nodes = OrderedSet[torch.fx.Node]() + while len(_node_list) != 0: + _current_node = _node_list.pop(0) + if _current_node not in _visited_nodes: + _visited_nodes.add(_current_node) + if _current_node == _ancestor_node: + return True + elif isinstance( + _current_node, torch.fx.Node + ) and _current_node.op not in ["placeholder", "output", "get_attr"]: + for input in _current_node.all_input_nodes: + _node_list.append(input) # noqa: PERF402 + return False + + return [ + user + for user in list(extra_input_node.users) + if not _is_ancestor_node(compute_node, user) + ] + + def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): + def fn(match): + if not _is_valid_computation_binary(computation_op, binary_op)(match): + return False + binary_nodes = filter_nodes(match.nodes, binary_op) + + def _get_compute_node(_binary_node, _other_index): + assert len(_binary_node.all_input_nodes) == 2, ( + "Binary node should have 2 input nodes." + ) + _compute_index = 1 if (_other_index == 0) else 0 + return _binary_node.args[_compute_index] + + def _other_input_not_inplaceable(_binary_node, _other_index): + _compute_node = _get_compute_node(_binary_node, _other_index) + return ( + len( + _get_remaining_users( + _binary_node.args[_other_index], _compute_node + ) + ) + > 1 + or _binary_node.args[_other_index] == _compute_node.args[0] + ) + + if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): + return False + if any( + # pyrefly: ignore [missing-attribute] + n.args[other_index].op in ["placeholder", "output"] + for n in binary_nodes + ): + return False + return True + + return fn + + def _register_binary_unary_fusion_lowering( + pattern, + computation_op, + binary_op, + fusion_op, + unary_attr=None, + ): + @register_lowering_pattern( + pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) + return L[fusion_op](*computation_args) + + return fn + + def _can_be_inplace(_other): + return not ( + isinstance(_other.data, ir.BaseView) + or len(_other.get_inputs_that_alias_output()) > 0 + ) + + def _qlinear_binary_can_be_inplace(_other): + if isinstance(_other.data, ir.BaseView): + + def unwrap_buffer(data): + if isinstance(data, ir.StorageBox): + return data.data + return data + + data = _other.data.unwrap_view() + if isinstance(unwrap_buffer(data), ir.CppTemplateBuffer): + # It can be inplaced when _other is the 2D to 3D view of + # a CppTemplateBuffer because if there is a view of CppTemplateBuffer, + # CppTemplateBuffer will not be used directly but the view. + return True + else: + # The case of QLinearPointwiseBinaryPT2E(sum) -> QLinearPointwiseBinaryPT2E(sum) + # is similar to CppTemplateBuffer above. + # The output of previous QLinearPointwiseBinaryPT2E is + # the input x2 of current QLinearPointwiseBinaryPT2E. + # Use V.graph.operations to check if _other is a view of the output + # of previous QLinearPointwiseBinaryPT2E (the inputs[6]). + for op in V.graph.operations: + if ( + isinstance(op, mkldnn_ir.QLinearPointwiseBinaryPT2E) + and unwrap_buffer(data) == op.inputs[6] # type: ignore[attr-defined] + ): + return True + return False + elif len(_other.get_inputs_that_alias_output()) > 0: + return False + else: + return True + + def _register_binary_unary_maybe_inplace_fusion_lowering( + pattern, + computation_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + unary_attr=None, + other_index=None, + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_binary_inplace( + computation_op, binary_op, other_index + ), + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) + # Make sure the other is not an alias or mutation(fx side doesn't has such info). + other.realize() + if not _can_be_inplace(other) or other.data.shape != list( + match.nodes[0].meta["val"].size() + ): + return L[outplace_fusion_op](*computation_args) + return L[inplace_fusion_op](*computation_args) + + return fn + + computation_ops = [ + mkldnn._convolution_pointwise.default, + mkldnn._linear_pointwise.default, + mkldnn._convolution_transpose_pointwise.default, + ] + + class UnaryAttr: + def __init__( + self, op_name: str, scalars_attr=None, algorithm_attr=None + ) -> None: + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + def _register_unary_fusion(): + computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] + + def _unary_fusion_patterns(lowp_dtype): + replacement_unary_fusion_patterns = { + UnaryAttr("gelu", algorithm_attr="tanh"): [ + _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("gelu", algorithm_attr="none"): [ + _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardswish"): [ + _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardsigmoid"): [ + _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("swish"): [ + _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + } + if not lowp_dtype: + call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] + replacement_unary_fusion_patterns.update( + { + UnaryAttr("relu"): [ + _combined_fusion(u, aten.relu) for u in call_user1 + ], + UnaryAttr("sigmoid"): [ + _combined_fusion(u, aten.sigmoid) for u in call_user1 + ], + UnaryAttr("tanh"): [ + _combined_fusion(u, aten.tanh) for u in call_user1 + ], + } + ) + + return replacement_unary_fusion_patterns + + for lowp_dtype in [torch.bfloat16, torch.float16, None]: + replace_patterns = _unary_fusion_patterns(lowp_dtype) + for unary_attr, patterns in replace_patterns.items(): + _register_unary_fusion_lowering( + patterns[0], unary_attr, computation_ops[0], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[1], unary_attr, computation_ops[1], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[2], unary_attr, computation_ops[2], lowp_dtype + ) + _leaky_relu_patterns = [ + _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): + _register_leaky_relu_fusion_lowering( + pattern, computation_op, lowp_dtype + ) + hardtanh_patterns = [ + _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(hardtanh_patterns, computation_ops): + _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) + + def _register_inplace_fusion(): + binary_ops = [aten.add, ops.add] + inplace_fusion_op = mkldnn._convolution_pointwise_.binary + outplace_fusion_op = mkldnn._convolution_pointwise.binary + conv_call = _conv_call(users=1) + conv_op = computation_ops[0] + for binary_op in binary_ops: + binary_v1 = _binary_fusion_v1(conv_call, binary_op) + binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + ) + binary_v2 = _binary_fusion_v2(conv_call, binary_op) + binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + ) + + def _register_binary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [ + mkldnn._convolution_pointwise.binary, + mkldnn._linear_pointwise.binary, + ] + _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern = _binary_fusion_v2(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + for binary_op in [aten.add, ops.add]: + pattern = _binary_fusion_v1(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + def _register_binary_unary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [mkldnn._convolution_pointwise.binary] + _computation_user_1 = [_conv_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern_v1 = _combined_fusion( + _binary_fusion_v2(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v1, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + for binary_op in [aten.add, ops.add]: + pattern_v2 = _combined_fusion( + _binary_fusion_v1(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v2, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + + def _recover_linear(): + # convert reshape+linear+reshape to a single linear for applying fusion path. + # concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_number=1) -> _recover_linear(pass_number=2) + @register_freezing_graph_pattern( + CallFunction( + aten.reshape.default, + CallFunction( + mkldnn._linear_pointwise.default, + CallFunction( + aten.reshape.default, + Arg(), + KeywordArg("reshape_1"), + _users=MULTIPLE, + ), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("reshape_2"), + ), + pass_number=2, + ) + def reshape_linear_reshape_pattern(match, *args, **kwargs): + def get_val(val): + return val if isinstance(val, int) else val.meta.get("val") + + reshape_1 = kwargs.get("reshape_1") + reshape_2 = kwargs.get("reshape_2") + assert isinstance(reshape_1, list) + assert isinstance(reshape_2, list) + assert len(reshape_1) == 2 + + graph = match.graph + reshape_2_node = match.output_node() + linear_input_node = reshape_2_node.args[0].args[0].args[0] + # check linear's input's shape[:-1] == reshape_2[:-1] + # and check product(reshape_2[:-1]) == reshape_1[0] + can_remove_reshape = linear_input_node.meta.get("val").shape[ + :-1 + ] == torch.Size([get_val(val) for val in reshape_2[:-1]]) + can_remove_reshape = can_remove_reshape and ( + reduce( + operator.mul, + [get_val(val) for val in reshape_2[:-1]], + ) + == get_val(reshape_1[0]) + ) + + if can_remove_reshape: + repl = graph.call_function(mkldnn._linear_pointwise.default, args) + repl.meta.update(reshape_2_node.meta) + reshape_2_node.replace_all_uses_with(repl) + old_linear_node = reshape_2_node.args[0] + reshape_1_node = old_linear_node.args[0] + graph.erase_node(reshape_2_node) + graph.erase_node(old_linear_node) + if len(reshape_1_node.users) == 0: + graph.erase_node(reshape_1_node) + counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"] += 1 + counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"] += len( + match.nodes + ) + + def is_linear_add_bias(match): + add_node = match.output_node() + linear_node = add_node.args[0] + device_type = add_node.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + transpose_weight_node = mkldnn_device_op.get_linear_transpose_weight( + linear_node.args[1] + ) + weight_meta = transpose_weight_node.args[0].meta.get("val") + bias_node = add_node.args[1] + if isinstance(bias_node, int): + # we only folding bias if it is a constant + return False + bias_meta = add_node.args[1].meta.get("val") + if weight_meta is None or bias_meta is None: + return False + + if bias_meta.dtype != weight_meta.dtype: + return False + return ( + linear_node.args[2] is None + and bias_meta.dim() == 1 + and bias_meta.size(0) == weight_meta.size(1) + ) + + # convert linear+bias to a single linear for applying fusion path. + @register_freezing_graph_pattern( + CallFunction( + aten.add.Tensor, + CallFunction(mkldnn._linear_pointwise.default, *_linear_args), + Arg(), + ), + pass_number=2, + extra_check=is_linear_add_bias, + ) + def linear_bias_pattern(match, *args): + graph = match.graph + add_node = match.output_node() + linear_node = add_node.args[0] + new_args = list(linear_node.args) + new_args[2] = add_node.args[1] + repl = graph.call_function( + mkldnn._linear_pointwise.default, tuple(new_args) + ) + repl.meta.update(add_node.meta) + add_node.replace_all_uses_with(repl) + match.erase_nodes() + counters["inductor"]["mkldnn_linear_bias_matcher_count"] += 1 + counters["inductor"]["mkldnn_linear_bias_matcher_nodes"] += len(match.nodes) + + def _is_packable_mkldnn_rnn_layer(match): + lstm_node = match.output_node() + POS_WEIGHTS = [1, 2] + POS_INPUTS = [0, 5, 6] + POS_ARGS = POS_WEIGHTS + POS_INPUTS + # Weights should be Constant + if any( + lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS + ): + return False + + # Meta info for weights and inputs should be available + if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): + return False + + # Check device + if any( + lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" + for POS_ARG in POS_ARGS + ): + return False + + # Check dtype + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 + and not is_mkldnn_bf16_supported("cpu") + for POS_ARG in POS_ARGS + ): + return False + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 + and not is_mkldnn_fp16_supported("cpu") + for POS_ARG in POS_ARGS + ): + return False + + return True + + def _is_packable_convolution(match): + """ + Check if the node is supported for MKLDNN convolution. + """ + conv_node = match.output_node() + device_type = conv_node.meta.get("val").device.type + # The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device. + if match.kwargs["is_transposed"] and device_type == "xpu": + return False + + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + input_size = input_meta_value.shape + if conv_node.args[1].op != "get_attr": + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type not in SUPPORTED_MKLDNN_DEVICES + or (meta_value.dim() != 4 and meta_value.dim() != 5) + ): + return False + + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not is_mkldnn_bf16_supported(device_type): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not is_mkldnn_fp16_supported(device_type): + return False + is_transposed = conv_node.args[-3] + if is_transposed: + # TODO: Support dynamic shape case for MKLDNN conv transpose. + if has_free_symbols(input_size): + return False + groups = conv_node.args[-1] + in_channels = weight_meta_value.size(0) + # doesn't support group_depthwise_conv_transpose. + if groups > 1 and groups == in_channels: + return False + # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big + output_paddings = conv_node.args[-2] + strides = conv_node.args[3] + if any( + output_padding >= stride + for output_padding, stride in zip(output_paddings, strides) + ): + return False + return True + + def _is_packable_linear(match): + """ + Check if the node is supported for MKLDNN linear. + """ + + def is_const_or_cat_by_const(weight): + if weight.op == "get_attr": + return True + if weight.target != aten.cat.default: + return False + return all(arg.op == "get_attr" for arg in weight.args[0]) + + linear_node = match.output_node() + # mkldnn linear only supports beta=1or0 and alpha=1 + if linear_node.target is aten.addmm.default: + alpha = linear_node.kwargs.get("alpha", 1.0) + beta = linear_node.kwargs.get("beta", 1.0) + if (beta != 0.0 and beta != 1.0) or alpha != 1.0: + return False + # weight_idx is 1 for aten.mm and is 2 for aten.addmm + weight_idx = 2 if linear_node.target is aten.addmm.default else 1 + if not is_const_or_cat_by_const(linear_node.args[weight_idx]): + return False + input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") + weight_meta_value = linear_node.args[weight_idx].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + if ( + input_meta_value.dtype == torch.float64 + or weight_meta_value.dtype == torch.float64 + ): + return False + is_lp_weight = weight_meta_value.dtype in ( + torch.bfloat16, + torch.float16, + ) + reduced_f32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision in [ # type: ignore[attr-defined] + "bf16", + "tf32", + ] + use_reduced_f32_for_fp32_weight = ( + reduced_f32_matmul_enabled and weight_meta_value.dtype == torch.float32 + ) + compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight + # on x86, for fp32, mkl should be enabled. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not compute_with_lp + and not mkldnn._is_mkldnn_acl_supported() + and not torch._C.has_mkl + ): + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 2 + ): + return False + if weight_idx == 2: + bias_meta_value = linear_node.args[0].meta.get("val") + if ( + bias_meta_value is None + or meta_value.device.type != "cpu" + or bias_meta_value.dim() != 1 + or bias_meta_value.size(0) != weight_meta_value.size(1) + ): + return False + + device_type = input_meta_value.device.type + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not is_mkldnn_bf16_supported(device_type): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not is_mkldnn_fp16_supported(device_type): + return False + return True + + _aten_conv_args = ( + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + KeywordArg("is_transposed"), + Arg(), + Arg(), + ) + + _aten_mkldnn_rnn_layer_args = ( + Arg(), # input + Arg(), # weight0 + Arg(), # weight1 + Arg(), # weight2 + Arg(), # weight3 + Arg(), # hx_ + Arg(), # cx_ + KeywordArg("reverse"), # reverse + Arg(), # batch_sizes + Arg(), # mode + Arg(), # hidden_size + Arg(), # num_layers + Arg(), # has_biases + Arg(), # bidirectional + Arg(), # batch_first + Arg(), # train + ) + + def _register_weight_pack_pass(): + @register_freezing_graph_pattern( + CallFunction(aten.convolution.default, *_aten_conv_args), + extra_check=_is_packable_convolution, + ) + def convolution(match, *args, **kwargs): + is_transposed = kwargs.get("is_transposed") + assert isinstance(is_transposed, bool) + graph = match.graph + conv_node = match.output_node() + device_type = conv_node.args[0].meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + input_size = conv_node.args[0].meta.get("val").shape + with graph.inserting_before(conv_node): + constant_args = [args[4], args[3], args[5], args[-1]] + packed_conv_op = mkldnn._convolution_pointwise.default + if is_transposed: + constant_args.insert(1, args[-2]) # output_padding + packed_conv_op = mkldnn._convolution_transpose_pointwise.default + + if not has_free_symbols(input_size): + packed_weight_node = mkldnn_device_op.pack_conv_weight( + graph, + is_transposed, + args[1], + constant_args, + input_size, + ) + else: + assert not is_transposed + # For dynamic shape case, we need to pack weight in runtime. + packed_weight_node = args[1] + + packed_conv_inputs = ( + (args[0], packed_weight_node, args[2]) + + tuple(constant_args) + + ("none", [], "") + ) + packed_conv_node = graph.create_node( + "call_function", packed_conv_op, tuple(packed_conv_inputs) + ) + conv_node.replace_all_uses_with(packed_conv_node) + packed_conv_node.meta.update(conv_node.meta) + graph.erase_node(conv_node) + counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + @register_freezing_graph_pattern( + CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), + extra_check=_is_packable_mkldnn_rnn_layer, + ) + def mkldnn_rnn_layer(match, *args, **kwargs): + def get_item(graph, node, index): + return graph.call_function(operator.getitem, (node, index)) + + graph = match.graph + lstm_node = match.output_node() + weight0, weight1 = args[1:3] + reverse = kwargs.get("reverse") + packed_lstm_op = aten.mkldnn_rnn_layer.default + hidden_size = args[9] + has_biases = args[11] + batch_first = args[13] + with graph.inserting_before(lstm_node): + packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default + packed_weight_inputs = ( + weight0, + weight1, + hidden_size, + reverse, + has_biases, + batch_first, + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, packed_weight_inputs, {}, "name" + ) + packed_weight_items = [ + get_item(graph, packed_weight_node, i) for i in range(2) + ] + pack_lstm_inputs = ( + args[0], + *packed_weight_items, + args[3], + args[4], + args[5], + args[6], + reverse, + *args[7:], + ) + + packed_lstm_node = graph.create_node( + "call_function", packed_lstm_op, args=pack_lstm_inputs + ) + lstm_node.replace_all_uses_with(packed_lstm_node) + packed_lstm_node.meta.update(lstm_node.meta) + graph.erase_node(lstm_node) + counters["inductor"]["mkldnn_rnn_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_rnn_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + @register_freezing_graph_pattern( + CallFunction( + aten.addmm.default, + Arg(), + Arg(), + Arg(), + beta=KeywordArg("beta"), + alpha=KeywordArg("alpha"), + ), + extra_check=_is_packable_linear, + pass_number=1, + ) + @register_freezing_graph_pattern( + CallFunction(aten.mm.default, Arg(), Arg()), + extra_check=_is_packable_linear, + pass_number=1, + ) + def linear(match, *args, **kwargs): + graph = match.graph + linear_node = match.output_node() + input = args[0] if linear_node.target is aten.mm.default else args[1] + bias = ( + None + if linear_node.target is aten.mm.default + or ( + linear_node.target is aten.addmm.default + and linear_node.kwargs.get("beta", 1.0) == 0.0 + ) + else args[0] + ) + weight = args[1] if linear_node.target is aten.mm.default else args[2] + device_type = input.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + with graph.inserting_before(linear_node): + transpose_weight_node = graph.create_node( + "call_function", aten.permute.default, (weight, (1, 0)) + ) + weight_dtype = weight.meta.get("val").dtype + is_lp_weight = weight_dtype in ( + torch.bfloat16, + torch.float16, + ) + reduced_f32_matmul_enabled = ( + torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"] # type: ignore[attr-defined] + ) + use_reduced_f32_for_fp32_weight = ( + reduced_f32_matmul_enabled and weight_dtype == torch.float32 + ) + compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight + batch_size = input.meta.get("val").shape[0] + packed_weight_node = mkldnn_device_op.pack_linear_weight( + graph, compute_with_lp, transpose_weight_node, batch_size + ) + packed_linear_node = mkldnn_device_op.pack_linear( + graph, compute_with_lp, batch_size, input, packed_weight_node, bias + ) + + linear_node.replace_all_uses_with(packed_linear_node) + packed_linear_node.meta.update(linear_node.meta) + graph.erase_node(linear_node) + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + def _eliminate_duplicate_packed_nodes(gm): + """ + Combine packed weight nodes with the same inputs to reduce memory usage. + for example: + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(32, 32, bias=True) + + def forward(self, x): + return self.linear(self.linear(x)) + + the above's packed weight nodes are duplicate if two linear calls have same input size. + """ + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + + packed_weight_ops = [ + torch._C._nn.mkldnn_reorder_conv2d_weight, + torch._C._nn.mkldnn_reorder_conv3d_weight, + mkldnn._reorder_convolution_transpose_weight, + mkldnn._reorder_linear_weight, + mkldnn._reorder_mkldnn_rnn_layer_weight, + ] + if torch._C.has_mkl: + packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) + + for node in gm.graph.nodes: + if node.target in packed_weight_ops and len(node.args[0].users) > 1: + for user_node in list(node.args[0].users.keys()): + if ( + user_node.target == node.target + and user_node != node + and user_node.args == node.args + ): + user_node.replace_all_uses_with(node) + gm.graph.erase_node(user_node) + + @functools.cache + def _mkldnn_fusion_init(): + # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. + # Otherwise even the matmul or innerproduct can not be accelerated with acl + if ( + not torch.backends.mkldnn.enabled + or not torch.backends.mkldnn.is_available() + ): + return + + if not torch.ops.mkldnn._is_mkldnn_acl_supported(): + _register_unary_fusion() + _register_inplace_fusion() + _register_binary_unary_fusion() + _register_binary_fusion() + _register_quantization_lowerings() + + _register_woq_lowerings() + + @functools.cache + def _mkldnn_weight_pack_init(): + if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): + _register_weight_pack_pass() + _recover_linear() + _register_quantization_weight_pack_pass() + _register_int8_woq_concat_linear_pattern() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/node_runtime_estimation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/node_runtime_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..2e3e3ebf084ad7b785450a453c483ad4ae01895b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/node_runtime_estimation.py @@ -0,0 +1,325 @@ +""" +Collective runtime estimation using CUDA events and power-of-2 rounding. +""" + +from __future__ import annotations + +import itertools +from functools import lru_cache +from typing import Any, Optional + +import torch +import torch.fx as fx +from torch._inductor.utils import clear_on_fresh_cache, tabulate_2d +from torch._logging import getArtifactLogger, trace_structured +from torch.fx.operator_schemas import normalize_function + + +# Setup logger for artifact logging +log = getArtifactLogger(__name__, "node_runtime_estimation") + + +# TODO: Consider using a distributed-aware cache or rank-local disk cache +# not using local cache because different ranks might write to it concurrently. +# solvable in future, potentially with workflow to seed cache +@clear_on_fresh_cache +@lru_cache +def _get_collective_cache() -> dict[str, float]: + """Get process-local cache for collective benchmarks.""" + return {} + + +def get_cached_runtime(key: str) -> Optional[float]: + """Get cached runtime from process-local cache.""" + return _get_collective_cache().get(key) + + +def set_cached_runtime(key: str, value: float) -> None: + """Set cached runtime in process-local cache.""" + _get_collective_cache()[key] = value + + +def get_hint(x: int | torch.SymInt) -> Optional[int]: + if isinstance(x, int): + return x + assert isinstance(x, torch.SymInt) + return x.node.hint if x.node.has_hint() else None + + +def can_benchmark_collective() -> bool: + """Check if we can benchmark collectives (not fake process group).""" + import torch.distributed as c10d + + if not c10d.is_initialized(): + return False + + pg = c10d.distributed_c10d._get_default_group() + if torch.distributed.distributed_c10d.get_backend(pg) == "fake": + return False + + return True + + +def _median(lst): + assert len(lst) > 0 + return torch.median(torch.tensor(lst)).item() + + +def _benchmark_collective_with_cuda_events_impl( + n: torch.fx.Node, + args: tuple[Any, ...], + kwargs: dict[str, Any], + nruns: int, +) -> float | None: + """ + Core benchmarking logic using CUDA events and barriers. + Returns runtime in ms or None on failure. + """ + from torch._dynamo.testing import rand_strided + + # Convert FakeTensors to real tensors before benchmarking + def to_real(t: torch.Tensor) -> torch.Tensor: + shape = [get_hint(dim) for dim in t.shape] + stride = [get_hint(s) for s in t.stride()] + + if any(s is None for s in itertools.chain(shape, stride)): + # This should not happen, as can_benhcmark_collective checks for unbacked + raise ValueError("Cannot convert tensor with symbolic dimensions") + + return rand_strided(shape, stride, device=t.device, dtype=t.dtype) # type: ignore[arg-type] + + args, kwargs = torch.utils._pytree.tree_map_only( + torch.Tensor, + to_real, + (args, kwargs), + ) + + # Warmup: call collective once and wait + torch.cuda.synchronize() + result = n.target(*args, **kwargs) # type: ignore[operator] + torch.ops._c10d_functional.wait_tensor(result) + torch.cuda.synchronize() + + # Benchmark with CUDA events + comm_times = [] + for _ in range(nruns): + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + result = n.target(*args, **kwargs) # type: ignore[operator] + torch.ops._c10d_functional.wait_tensor(result) + end_evt.record() + end_evt.synchronize() + + comm_times.append(start_evt.elapsed_time(end_evt)) + + return _median(comm_times) + + +def benchmark_collective_with_cuda_events( + n: torch.fx.Node, + nruns: int = 2, +) -> tuple[float | None, str]: + """ + Benchmark collective with CUDA events. Returns (runtime_ms, cache_key) or (None, "") on failure. + """ + # context manager not allowed with profiler. + with torch.utils._python_dispatch._disable_current_modes(): + return benchmark_collective_with_cuda_events_impl(n, nruns) + + +def benchmark_collective_with_cuda_events_impl( + n: torch.fx.Node, + nruns: int = 3, +) -> tuple[float | None, str]: + """ + Benchmark collective with CUDA events. Returns (runtime_ms, cache_key) or (None, "") on failure. + """ + from torch._inductor import fx_utils + from torch.distributed.distributed_c10d import _get_group_size_by_name + + # Early check: can we actually run collectives? + if not can_benchmark_collective(): + return None, "" + + success, args, kwargs = fx_utils.get_fake_args_kwargs(n) + + opt_args_kwargs = normalize_function( + n.target, # type: ignore[arg-type] + args=n.args, + kwargs=n.kwargs, + normalize_to_only_use_kwargs=True, + ) + assert opt_args_kwargs is not None + group_name = opt_args_kwargs[1]["group_name"] + group_size = _get_group_size_by_name(group_name) + + if not success: + return None, "" + + # Extract actual input size in BYTES (first tensor argument) + actual_bytes: Optional[int] = None + + def extract_tensor_info(t: torch.Tensor) -> torch.Tensor: + nonlocal actual_bytes + if actual_bytes is None: + shape = [get_hint(dim) for dim in t.shape] + if any(s is None for s in shape): + return t + + total_elems = 1 + for dim in shape: + assert dim is not None + total_elems *= dim + + actual_bytes = total_elems * t.dtype.itemsize + else: + raise RuntimeError(f"should only be one input tensor to collective {n}") + return t + + torch.utils._pytree.tree_map_only(torch.Tensor, extract_tensor_info, (args, kwargs)) + + if actual_bytes is None: + return None, "" + + # Cache key by BYTES (dtype-agnostic) + key = f"{n.target}: ({group_size} group size, {actual_bytes} bytes)" + + # Check cache + if (cached := get_cached_runtime(key)) is not None: + return cached, key + + # Benchmark using CUDA events with actual args/kwargs + runtime = _benchmark_collective_with_cuda_events_impl(n, args, kwargs, nruns) + + if runtime is None: + return None, key + + # Cache the result + set_cached_runtime(key, runtime) + return runtime, key + + +def _log_compute_estimations( + compute_nodes: list[fx.Node], + benchmarked_estimations: list[float], + analytical_estimations: list[float], +) -> None: + """Log compute node runtime estimations comparing benchmarked vs analytical.""" + import torch.utils._pytree as pytree + from torch._inductor.fx_utils import count_flops_fx + from torch.utils._dtype_abbrs import dtype_abbrs + + def _node_summary(n: fx.Node) -> str: + ret = str(n) + for arg in pytree.arg_tree_leaves(n.args, n.kwargs): + if not isinstance(arg, torch.fx.Node): + continue + if "val" in arg.meta: + t = arg.meta["val"] + ret += f" {dtype_abbrs[t.dtype]}{tuple(t.shape)}" + return ret + + headers = [ + "Node", + "Benchmarked Est(us)", + "Analytical Est(us)", + "Diff(%)", + "Diff(us)", + "Flops", + ] + + rows = [ + [ + _node_summary(node)[:120], + est_b * 1e3, + est_a * 1e3, + (est_a / est_b) if est_b > 0 else 0, + (est_a - est_b) * 1e3, + count_flops_fx(node), + ] + for node, est_b, est_a in zip( + compute_nodes, benchmarked_estimations, analytical_estimations + ) + ] + + log_str = tabulate_2d(rows, headers) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_compute_nodes_runtime_estimation", + "encoding": "string", + }, + payload_fn=lambda: log_str, + ) + + +def _log_collective_benchmarks( + collective_nodes: list[fx.Node], + collective_keys: list[str], + benchmarked_medians: list[float], + world_size: int, +) -> None: + """Log collective benchmarks with analytical comparisons for tlparse.""" + headers = [ + "Collective Key", + "Benchmarked(ms)", + "NCCL Est(ms)", + "Inductor Est(ms)", + "NCCL Diff(%)", + "Inductor Diff(%)", + ] + + rows = [] + collective_benchmarks = {} + for key, benchmarked_ms, coll_node in zip( + collective_keys, benchmarked_medians, collective_nodes + ): + # NCCL estimator (deterministic, no need to align) + nccl_ms = ( + torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( + coll_node, None, use_nccl_estimator=True + ) + ) + + # Inductor analytical (deterministic, no need to align) + inductor_ms = ( + torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( + coll_node, None, use_nccl_estimator=False + ) + ) + + collective_benchmarks[key] = { + "benchmarked_ms": benchmarked_ms, + "analytical_nccl_ms": nccl_ms, + "analytical_inductor_ms": inductor_ms, + } + + # Compute percentage differences + nccl_diff_pct = (nccl_ms / benchmarked_ms) if benchmarked_ms > 0 else 0 + inductor_diff_pct = (inductor_ms / benchmarked_ms) if benchmarked_ms > 0 else 0 + + rows.append( + [ + key[:80], + f"{benchmarked_ms:.4f}", + f"{nccl_ms:.4f}", + f"{inductor_ms:.4f}", + f"{nccl_diff_pct:.2f}", + f"{inductor_diff_pct:.2f}", + ] + ) + + log_str = f"World size: {world_size}\n" + log_str += tabulate_2d(rows, headers) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_collectives_node_runtime_estimation", + "encoding": "string", + }, + payload_fn=lambda: log_str, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1db82f21f7ec6a37e1f260b02d2fcd77622c058 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py @@ -0,0 +1,213 @@ +# mypy: allow-untyped-defs +import gc +import logging +import os +import random +import traceback + +import numpy + +import torch +import torch.optim as optim +from torch.utils._ordered_set import OrderedSet + +from .. import config + + +logger: logging.Logger = logging.getLogger(__name__) + +MAIN_RANDOM_SEED = 1337 + +# Set the CUBLAS_WORKSPACE_CONFIG environment variable +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +# If the two forward functions involve any non-deterministic operations, +# such as certain types of parallelism or asynchronous execution, +# this can also lead to different outputs. +def set_deterministic() -> None: + """Make torch manual seed deterministic.""" + + torch.manual_seed(MAIN_RANDOM_SEED) + random.seed(MAIN_RANDOM_SEED) + numpy.random.seed(MAIN_RANDOM_SEED) + torch.use_deterministic_algorithms(True) + + +def clean_memory() -> None: + """Clean memory to avoid OOM.""" + gc.collect() + torch.cuda.empty_cache() + + +# We compare the numerical results before and after pre/post grad fx passes +# transformation to make sure the numerical results are the same. +def compare_dict_tensors(dict_base, dict_control, precision): + if len(OrderedSet(dict_base.keys())) != len(OrderedSet(dict_control.keys())): + logger.warning("Mismatch keys found before and after pre/post grad fx passes.") + logger.debug("keys before pre/post grad fx passes %s", dict_base.keys()) + logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) + return False + is_allclose = True + for key in dict_base: + if key not in dict_control: + logger.warning( + "Mismatch parameter name %s does not exist after pre/post grad fx passes", + key, + ) + # Some parameters have `None`, and not every param has a valid .grad field, we skip them + if dict_base[key] is None or dict_control[key] is None: + continue + if not torch.allclose( + dict_base[key], + dict_control[key], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.warning( + "Mismatch parameter values found before and after pre/post grad fx passes." + ) + logger.debug("value before pre/post grad fx passes %s", dict_base[key]) + logger.debug("value after pre/post grad fx passes %s", dict_control[key]) + is_allclose = False + return is_allclose + + +def compare_tuple_tensors(tuple_base, tuple_control, precision): + if len(tuple_base) != len(tuple_control): + logger.warning( + "Mismatch fw output length. before transformation: %s, after transformation: %s", + len(tuple_base), + len(tuple_control), + ) + return False + is_allclose = True + for i in range(len(tuple_base)): + # Some parameters have `None`, we skip them + if tuple_base[i] is None or tuple_control[i] is None: + continue + if not torch.allclose( + tuple_base[i], + tuple_control[i], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.debug( + "forward output before pre/post grad fx passes %s", tuple_base[i] + ) + logger.debug( + "forward output after pre/post grad fx passes %s", tuple_control[i] + ) + is_allclose = False + return is_allclose + + +def compare_parameters(model_base, model_control, precision): + return compare_dict_tensors( + dict(model_base.named_parameters()), + dict(model_control.named_parameters()), + precision, + ) + + +def compare_forward_output(pred_base, pred_control, precision): + return compare_tuple_tensors( + pred_base, + pred_control, + precision, + ) + + +def compare_gradients(model_base, model_control, precision): + grad_base = {key: param.grad for key, param in model_base.named_parameters()} + grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()} + return compare_dict_tensors( + grad_base, + grad_pt2, + precision, + ) + + +def run_model( + model_base, model_control, model_input, num_iterations=10, precision=1e-4 +): + clean_memory() + for i in range(num_iterations): + logger.info("start %s iteration", i) + set_deterministic() + pred_base = model_base(*model_input) + set_deterministic() + pred_control = model_control(*model_input) + + res = compare_parameters(model_base, model_control, precision) + logger.info("compare parameters. Numerical result : %s", res) + + res = compare_forward_output(pred_base, pred_control, precision) + logger.info("compare loss/predict. Numerical result : %s", res) + # tensor may not have a grad_fn + try: + _ = pred_base[0].sum().backward(retain_graph=True) + _ = pred_control[0].sum().backward(retain_graph=True) + res = compare_gradients(model_base, model_control, precision) + logger.info("compare param grad. Numerical result : %s", res) + except Exception: + logger.exception("Exception when comparing gradients") + traceback.print_exc() + + if config.fx_passes_numeric_check["requires_optimizer"]: + try: + optimizer_base = optim.SGD( + [param for name, param in model_base.named_parameters()], lr=0.01 + ) + optimizer_base.step() + + optimizer_control = optim.SGD( + [param for name, param in model_control.named_parameters()], lr=0.01 + ) + optimizer_control.step() + + res = compare_parameters(model_base, model_control, precision) + logger.info( + "compare parameters with optimizer added. Numerical result : %s", + res, + ) + except Exception: + logger.exception( + "Exception when optimizer is added to check parameter names" + ) + traceback.print_exc() + else: + logger.warning( + "no parameter with optimizer to compare with length %s before transformation" + " and the length %s after transformation", + len(dict(model_base.named_parameters())), + len(dict(model_control.named_parameters())), + ) + + +def numeric_check_if_enabled( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations, + precision, +): + # need to topo-sort graphmodule before we run the model, + # otherwise it may fail as refer before def + # fail silently in order not to block the model run + try: + with torch.autograd.set_detect_anomaly(True): + run_model( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations=num_iterations, + precision=precision, + ) + except Exception as e: + logger.warning( # noqa: G200 + "Runtime numeric check failed in pre grad fx passes with error: %s", e + ) + traceback.print_exc() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_manual_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..540e73166ba45be7d9fd6eb12e627f795bae94dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +import heapq +from collections import Counter, defaultdict +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.fx as fx +from torch._dynamo.graph_deduplication import _stable_topological_sort +from torch._inductor.fx_passes.bucketing import ( + _schedulable_wait_node, + is_all_gather_into_tensor as is_all_gather, + is_reduce_scatter_tensor as is_reduce_scatter, + merge_all_gather_bucket, + merge_reduce_scatter_bucket, +) +from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + bucket_key, + OverlapPreservingBucketer, +) +from torch._inductor.fx_passes.overlap_scheduling import ( + CollectiveInfo, + is_compute_node, + OverlapScheduler, +) +from torch.utils._ordered_set import OrderedSet + +from .graph_view import get_subgraph_by_path, GraphView, make_graph_view + + +if TYPE_CHECKING: + from collections.abc import Callable + + +class ManualOverlapPreservingBucketer(OverlapPreservingBucketer): + """ + Buckets collective operations based on user specifications. + The actual bucket happens in bucket_collectives, where all-gathers/reduce-scatters in + `nodes` will be buckted one single all-gather/reduce-scatter. + """ + + def __init__( + self, + node_users: dict[fx.Node, OrderedSet[fx.Node]], + *args: Any, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + self.node_users = node_users + self.wait_to_node_map: dict[fx.Node, fx.Node] = defaultdict() + + def _check_recursive_dep( + self, + node: fx.Node, + target_op: str, + dep_dict: dict[torch.fx.Node, OrderedSet[torch.fx.Node]], + ) -> bool: + """ + Check if the node is directly used for fetch parameters/gradients + + TODO (ruisizhang123): currently, we assume the node only pre-fetch/update one parameter/gradient + We should handle multiple parameters/gradients update case by checking if there are non closure + computes along the path from primal/output to coll_node + """ + deps: OrderedSet[fx.Node] = dep_dict[node] + seen_target_op = 0 + for d in deps: + if d.op == target_op: + seen_target_op += 1 + + return seen_target_op == 1 + + def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: + assert len(coll_nodes) > 0, "bucketed coll_nodes should have nonzero node" + + waits = [self.collective_info[n].wait_node for n in coll_nodes] + # Use earliest wait insertion point + first_wait = min(waits, key=lambda w: self.node_idx[w]) + # Find insertion location + first = coll_nodes[0] + next_node = first + while next_node in coll_nodes: + next_node = next_node.next + + if is_all_gather(first): + new_nodes, replacements = merge_all_gather_bucket( + self.graph, + coll_nodes, + wait_insertion_point=first_wait, + insert_before=next_node, + mode="custom_ops", + ) + elif is_reduce_scatter(first): + new_nodes, replacements = merge_reduce_scatter_bucket( + self.graph, + coll_nodes, + wait_insertion_point=first_wait, + insert_before=next_node, + mode="custom_ops", + ) + else: + raise ValueError( + "bucket non all_gather/reduce_scatter node is not supported" + ) + + # Identify the new wait and start + new_waits = [n for n in new_nodes if _schedulable_wait_node(n)] + assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}" + new_wait = new_waits[0] + new_start = new_wait.args[0] + assert isinstance(new_start, fx.Node) + + # Set manual bucketing-specific metadata + # Note: Generic metadata (nn_module_stack, fwd_nn_module_stack, custom, stack_trace) + # is now preserved automatically by the bucketing functions in bucketing.py + node_type = ( + "bucketed_all_gather" if is_all_gather(first) else "bucketed_reduce_scatter" + ) + for n in new_nodes: + if n == new_wait: + node_type = node_type + "_wait" + n.meta["manual_bucket_node_type"] = node_type + if "wait" in node_type: + self.wait_to_node_map[n] = new_wait + + def manual_bucket_collectives(self, nodes: list[fx.Node]) -> None: + """ + Bucket all all-gather/reduce-scatter nodes from nodes into one all-gather/reduce-scatter. + """ + # Filter out valid collectives + collectives = [n for n in nodes if n in self.collective_info] + if collectives == []: + return + grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in collectives: + key = bucket_key(node) + if not (is_all_gather(node) or is_reduce_scatter(node)): + continue + # We only want to bucket all-gather/reduce-scatter that + # 1. all_gather that have ancestors dependent only on input placeholder(parameters) + # 2. reduce scatter that the wait user node is returned as output(gradients) + if is_all_gather(node) and not self._check_recursive_dep( + node, "placeholder", self.node_ancestors + ): + continue + if is_reduce_scatter(node) and not self._check_recursive_dep( + self.collective_info[node].wait_node, "output", self.node_users + ): + continue + if key is not None: + grouped_collectives[key].add(node) + + for key, nodes in grouped_collectives.items(): # type: ignore[arg-type] + self._bucket_group(list(nodes)) + + +class ManualOverlapScheduler(OverlapScheduler): + """ + Scheduler that manual buckets and reorders collective nodes based on module_bucket_plans + """ + + def __init__( + self, + gm: fx.GraphModule, + module_bucket_plans: list[list[str] | str], + insert_overlap_deps: bool, + module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None, + ): + super().__init__( + gm, + max_in_flight_gb=0.0, + max_compute_pre_fetch=0, + collective_bucketing=True, + insert_overlap_deps=insert_overlap_deps, + compute_overlap_multipler=0.0, + max_coll_distance=0, + custom_runtime_estimation=None, + collective_estimator="analytical", + max_memory_increase_gb=None, + max_memory_increase_ratio=None, + ) + self.module_bucket_plans = module_bucket_plans + self.nodes_in_subgraph: list[list[fx.Node]] = [] + + self.node_users: dict[fx.Node, OrderedSet[fx.Node]] = self._collect_node_users() + self.bucketer = ManualOverlapPreservingBucketer( + graph=self.graph, + collective_info=self.collective_info, + node_users=self.node_users, + scheduled=OrderedSet(self.graph.nodes), + ) + self.insert_overlap_deps = insert_overlap_deps + + self.module_stack_fn = module_stack_fn + + def _identify_collectives(self) -> None: + """Identify all collective operations.""" + for node in self.nodes: + if _schedulable_wait_node(node): + start = node.args[0] + info = CollectiveInfo( + start_node=start, + wait_node=node, + size_bytes=0, + estimated_time_ms=0, + exposed_time_ms=0, + ) + self.collective_info[start] = info + self.wait_to_start[node] = start + self.unscheduled_collectives.add(start) + + def run(self) -> torch.fx.GraphModule: + """Entry point to run the manual bucket algorithm""" + # Bucket collectives in each bucket_module + self._manual_bucket_collectives() + + # Reorder collectives with last/next bucket_module + self._manual_reorder_graph() + + return self.gm + + def _manual_reorder_graph(self) -> None: + """ + Reorder nodes in the FX graph to enforce manual overlap dependencies. + + Enforce: + - all_gather_start_i depends on all_gather_wait_(i-1) + - reduce_scatter_wait_i must happen before reduce_scatter_start_(i+1) + """ + delayed_rs_nodes: list[fx.Node] = [] + overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + + # schedule reduce scatter normally in self._schedule + while self.ready: + _, node = heapq.heappop(self.ready) + node_type = node.meta.get("manual_bucket_node_type", "") + + if node in self.scheduled: + continue + + if node_type == "bucketed_reduce_scatter": + # Ensure all delayed waits execute before this reduce_scatter + for delayed in delayed_rs_nodes: + self._schedule(delayed) + overlap_deps[delayed].add(node) + delayed_rs_nodes.clear() + + elif node_type == "bucketed_reduce_scatter_wait": + # Defer until next reduce_scatter + delayed_rs_nodes.append(node) + continue + self._schedule(node) + + for delayed in delayed_rs_nodes: + self._schedule(delayed) + + self.scheduled = OrderedSet(reversed(list(self.scheduled))) + picked_ag: list[fx.Node] = [] + last_compute: Optional[fx.Node] = None + + for node in self.scheduled: + node_type = node.meta.get("manual_bucket_node_type", "") + if node_type == "bucketed_all_gather": + picked_ag.append(node) + continue + + if node_type == "bucketed_all_gather_wait": + # Connect corresponding all_gather_wait -> all_gather edges + if picked_ag: + for ag in picked_ag: + overlap_deps[self.bucketer.wait_to_node_map[node]].add(ag) + picked_ag.clear() + if is_compute_node(node): + last_compute = node + + if last_compute is not None and not bool( + OrderedSet(picked_ag) & OrderedSet(self.node_ancestors[last_compute]) + ): + for ag in picked_ag: + overlap_deps[last_compute].add(ag) + + _stable_topological_sort(self.graph, overlap_deps) + self.graph.lint() + + if self.insert_overlap_deps: + from torch._inductor.fx_passes.control_dependencies import ( + preserve_node_ordering, + ) + + preserve_node_ordering(self.graph, overlap_deps) + + def _manual_bucket_collectives(self) -> None: + """Bucket nodes in each module_bucket from module_bucket_plans.""" + self._obtain_nodes_in_subgraph() + for i, nodes in enumerate(self.nodes_in_subgraph): + self.bucketer.manual_bucket_collectives(nodes=nodes) + + _stable_topological_sort(self.graph, {}) + self.graph.lint() + self.nodes = list(self.graph.nodes) + self.in_degree = Counter(user for node in self.nodes for user in node.users) + + def _collect_node_users(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """Collect all users for each node.""" + node_users: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in self.nodes: + for output_node in list(node.users.keys()): + node_users[node].add(output_node) + node_users[node] |= node_users[output_node] + return node_users + + def _schedule(self, node: fx.Node) -> None: + """Schedule a node.""" + assert node not in self.scheduled + assert all(n in self.scheduled for n in node.all_input_nodes) + self.scheduled.add(node) + for user in node.users: + self.in_degree[user] -= 1 + if self.in_degree[user] == 0: + heapq.heappush(self.ready, ((), user)) + + def _obtain_nodes_in_subgraph(self) -> None: + """ + Obtain nodes in each subgraph from module_bucket_plans + """ + graph_view: GraphView | None = make_graph_view(self.graph, self.module_stack_fn) + if graph_view is None: + return + + for module in self.module_bucket_plans: + subgraph_view = get_subgraph_by_path(graph_view, module) + self.nodes_in_subgraph.append(subgraph_view) + + all_subgraph_nodes = [ + node for sublist in self.nodes_in_subgraph for node in sublist + ] + unique_subgraph_nodes = list(OrderedSet(all_subgraph_nodes)) + assert len(all_subgraph_nodes) <= len(unique_subgraph_nodes), ( + f"Overlapping FX nodes detected across subgraphs in `module_bucket_plans`. " + f"Expected disjoint node sets but found " + f"{len(all_subgraph_nodes) - len(unique_subgraph_nodes)} duplicated node(s)." + ) + + +def manual_overlap_bucketing( + gm: torch.fx.GraphModule, + module_bucket_plans: list[list[str] | str], + insert_overlap_deps: bool = False, + module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None, +) -> torch.fx.GraphModule: + """Schedule nodes based on user specifications in module_bucket_plans + The manual overlapping consists of two steps: + Step 1: bucket all-gather/reduce-scatter in each module in module_bucket_plans + Step 2: reorder all-gather to overlap with last module_bucket & + reorder reduce-scatter to overlap with next module_bucket + TODO(ruisizhang123): allow users to explicitly specify which + module_bucket they want to overlap. + + Args: + gm: input graph module to optimize. + module_bucket_plans: user specified FQNs + module_stack_fn: Optional callable for extracting module hierarchy from nodes. + Used to construct a GraphView for identifying nodes in module_bucket_plans. + The module_class component of the returned tuples is not used by this pass. + + See the `module_stack_fn` parameter in `make_graph_view` (graph_view.py) for + detailed documentation on signature, return format, and usage examples. + """ + # decode abbreviated FQNs to actual FQNs + overlapped_gm = ManualOverlapScheduler( + gm, module_bucket_plans, insert_overlap_deps, module_stack_fn + ).run() + overlapped_gm.recompile() + return overlapped_gm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_preserving_bucketer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c819f37a1a83ecff13c4b18ceb2753b61087c29 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -0,0 +1,912 @@ +import itertools +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Literal, Optional + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters +from torch._inductor.augmented_graph_helper import AugmentedGraphHelper +from torch._inductor.fx_passes.bucketing import ( + _schedulable_wait_node, + bucket_key, + BucketMode, + has_mergeable_all_gather_convert_dtype, + is_all_gather_into_tensor as is_all_gather, + is_reduce_scatter_tensor as is_reduce_scatter, +) +from torch._inductor.fx_passes.overlap_scheduling import ( + CollBucket, + CollectiveInfo, + get_group_name, + is_compute_node, +) +from torch.utils._ordered_set import OrderedSet + + +bucket_log = logging.getLogger(__name__) + + +@dataclass +class WhyNoBucket: + name1: str + name2: str + reason: str + args: tuple[Any, ...] + + def __init__(self, node1: fx.Node, node2: fx.Node) -> None: + self.name1 = node1.name + self.name2 = node2.name + self.reason = "" + self.args = () + + def __call__(self, reason: str, *args: Any) -> None: + if bucket_log.isEnabledFor(logging.DEBUG): + bucket_log.debug( + "cannot bucket %s with %s: " + reason, # noqa: G003 + self.name1, + self.name2, + *args, + ) + + +def is_collective_or_wait(n: fx.Node) -> bool: + """Check if node is a collective start or wait.""" + if _schedulable_wait_node(n): + return True + # Collective starts have exactly one use: the wait_tensor + if len(n.users) == 1: + user = next(iter(n.users.keys())) + if _schedulable_wait_node(user): + return True + return False + + +@dataclass +class PGEvent: + """ + Represents an important event in a process group timeline. Either + a collective start, wait, or hiding compute. Each node is linked + to its prev and next and these dependencies are reflected + in the augmented graph. + + We want to enforce a sequential ordering of collective starts and waits + because NCCL collectives on the same process group execute on the same CUDA + stream, creating implicit dependencies between all operations on that PG. + + A wait of a particular collective will implicitly force realization of all collectives + enqueued prior to that collective. + """ + + node: fx.Node + event_type: Literal["compute", "starts", "waits"] + position: int + prev: Optional["PGEvent"] = None + next: Optional["PGEvent"] = None + + @property + def is_start(self) -> bool: + return self.event_type == "starts" + + @property + def is_wait(self) -> bool: + return self.event_type == "waits" + + @property + def is_compute(self) -> bool: + return self.event_type == "compute" + + def unlink(self) -> tuple[Optional["PGEvent"], Optional["PGEvent"]]: + """Remove this event from the linked list, return (prev, next).""" + prev_event, next_event = self.prev, self.next + if self.prev: + self.prev.next = self.next + if self.next: + self.next.prev = self.prev + self.prev = None + self.next = None + return prev_event, next_event + + def insert_between( + self, prev_event: Optional["PGEvent"], next_event: Optional["PGEvent"] + ) -> None: + """Insert this event between prev_event and next_event in the linked list.""" + if prev_event: + prev_event.next = self + self.prev = prev_event + + if next_event: + next_event.prev = self + self.next = next_event + + +class OverlapPreservingBucketer: + """ + Buckets collective operations while preserving compute-collective overlap relationships. + Uses an augmented graph to track dependencies between compute and collective operations. + """ + + def __init__( + self, + graph: fx.Graph, + collective_info: dict[fx.Node, CollectiveInfo], + scheduled: OrderedSet[fx.Node], + max_bucket_memory_gb: float = 1.0, + max_coll_distance: int = 1000, + insert_overlap_deps: bool = False, + bucket_mode: BucketMode = "custom_ops_multidtype", + ): + self.graph = graph + self.collective_info = collective_info + self.scheduled = scheduled + self.max_bucket_memory_gb = max_bucket_memory_gb + self.node_idx = {n: i for i, n in enumerate(scheduled)} + self.max_coll_distance = max_coll_distance + self.insert_overlap_deps = insert_overlap_deps + self.bucket_mode = bucket_mode + self.node_to_event: dict[fx.Node, PGEvent] = {} + self.all_hiding_nodes: OrderedSet[fx.Node] = OrderedSet() + + # Compute ancestors including original graph edges and hiding interval dependencies + self.node_ancestors = self._compute_node_ancestors() + self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) + + # Build timelines and add constraints to aug_graph + self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() + self._add_hiding_interval_constraints() + + def _compute_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """ + Compute ancestor sets for all nodes including: + 1. Original graph edges + 2. Hiding interval deps: collective_start -> hiding_node -> wait + """ + augmented_inputs: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for start, info in self.collective_info.items(): + if info.is_exposed: + continue + for hiding_node in info.hiding_nodes: + augmented_inputs[hiding_node].add(start) + augmented_inputs[info.wait_node].add(hiding_node) + + node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in self.scheduled: + for input_node in itertools.chain( + augmented_inputs[node], node.all_input_nodes + ): + node_ancestors[node].add(input_node) + node_ancestors[node] |= node_ancestors[input_node] + + return node_ancestors + + def build_timelines(self) -> dict[str, Optional[PGEvent]]: + "Construct each process groups ordered series of event" + all_pgs: OrderedSet[str] = OrderedSet() + for start in self.collective_info: + pg = get_group_name(start) + all_pgs.add(pg) + + pg_timeline: dict[str, Optional[PGEvent]] = {} + for pg in all_pgs: + pg_timeline[pg] = self.build_timeline(pg) + + return pg_timeline + + def build_timeline(self, pg: str) -> Optional[PGEvent]: + """ + Build a timeline of important events (starts, waits, hiding compute) for this process group + and constrain this ordering in the augmented graph. + + Sequential dependencies are added between all events because NCCL collectives on the same + process group execute on the same CUDA stream, enforcing LIFO semantics where later-issued + collectives must complete before earlier ones can finish. + """ + + head = None + prev_event = None + position = 0 + hiding_nodes = OrderedSet() + + for node in self.scheduled: + node_type = None + + # Determine if this node is relevant for this PG + if node in self.collective_info and get_group_name(node) == pg: + node_type = "starts" + hiding_nodes |= self.collective_info[node].hiding_nodes + elif _schedulable_wait_node(node): + wait_input = node.args[0] + if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg: + node_type = "waits" + # Wait for a different PG but hiding a collective on this PG + elif node in hiding_nodes: + node_type = "compute" + elif is_compute_node(node) or node in hiding_nodes: + node_type = "compute" + + if node_type is None: + continue + + event = PGEvent(node=node, event_type=node_type, position=position) # type: ignore[arg-type] + + event.insert_between(prev_event, None) + + # Add sequential dependency to augmented graph + if prev_event: + self.aug_graph.add_extra_dep(n=event.node, dep=prev_event.node) + else: + head = event + + prev_event = event + position += 1 + + return head + + def _populate_node_to_event(self, pg: str) -> None: + """Populate node_to_event mapping for a specific PG's timeline.""" + self.node_to_event.clear() + head = self.pg_to_timeline_head[pg] + curr = head + while curr is not None: + self.node_to_event[curr.node] = curr + curr = curr.next + + def _add_hiding_interval_constraints(self) -> None: + """ + Add hiding interval constraints: start -> compute -> wait. + """ + for start, info in self.collective_info.items(): + if info.is_exposed: + continue + for hn in info.hiding_nodes: + # Enforce: start -> compute -> wait + self.aug_graph.add_extra_dep(n=hn, dep=start) + self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn) + + self.all_hiding_nodes |= info.hiding_nodes + + def bucket_collectives(self) -> None: + # Group collectives by PG first + pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for start in self.collective_info: + pg = get_group_name(start) + pg_collectives[pg].add(start) + + all_buckets: list[CollBucket] = [] + for pg, collectives in pg_collectives.items(): + # Populate node_to_event for this PG's timeline + self._populate_node_to_event(pg) + + # Group by bucket key within this PG + grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict( + OrderedSet + ) + for start in collectives: + key = bucket_key(start, self.bucket_mode) + if key is not None: + grouped_collectives[key].add(start) + + # Find buckets for this PG + for key, collective_group in grouped_collectives.items(): + bucket_log.debug( + "bucketing collective group with key %s: %s", + key, + [n.name for n in collective_group], + ) + buckets = self._find_buckets(collective_group) + all_buckets.extend(buckets) + + # Apply bucketing transformations + # Dependencies are tracked in aug_graph.extra_deps during bucketing + for coll_bucket in all_buckets: + if len(coll_bucket.collectives) <= 1: + continue + + counters["inductor"]["collective_buckets"] += 1 + self._apply_bucket(coll_bucket) + + # Extract all dependencies from augmented graph + # This includes: + # - Sequential timeline deps (added during build_timeline) + # - Hiding interval deps (added during _add_hiding_interval_constraints) + # - All transferred deps from bucketing (transferred during _apply_bucket) + additional_deps = self.aug_graph.get_all_extra_deps() + + # Apply topological sort with all dependencies + from torch._dynamo.graph_deduplication import _stable_topological_sort + + for n, deps in additional_deps.items(): + torch._check( + not n._erased, lambda: f"Erased node deps not transferred: {n}" + ) + for d in deps: + torch._check( + not d._erased, lambda: f"Erased node deps not transferred: {d}" + ) + + _stable_topological_sort(self.graph, additional_deps) + + # After topological sort, preserve dependencies using effect tokens + # Only preserve edges where NOT both nodes are collective starts or waits + if self.insert_overlap_deps: + filtered_deps: dict[fx.Node, OrderedSet[fx.Node]] = {} + for node, deps in additional_deps.items(): + filtered_node_deps: OrderedSet[fx.Node] = OrderedSet() + + # only preserve comm-comptue overlap for now, although we could more + # generally constrain + for dep in deps: + if not (is_collective_or_wait(node) and is_collective_or_wait(dep)): + filtered_node_deps.add(dep) + + if filtered_node_deps: + filtered_deps[node] = filtered_node_deps + + self._preserve_dependencies_with_tokens(filtered_deps) + + self.graph.lint() + + def _find_buckets( + self, + collective_group: OrderedSet[fx.Node], + ) -> list[CollBucket]: + """Find valid buckets within a group of similar collectives.""" + max_bucket_bytes = int(self.max_bucket_memory_gb * 1024 * 1024 * 1024) + buckets = [] + processed: OrderedSet[fx.Node] = OrderedSet() + + # Sort collectives by node index for efficient distance checking + sorted_collectives = sorted(collective_group, key=lambda n: self.node_idx[n]) + + for i, start_node in enumerate(sorted_collectives): + if start_node in processed: + continue + + if ( + start_node in self.all_hiding_nodes + or self.collective_info[start_node].wait_node in self.all_hiding_nodes + ): + continue + + # Initialize bucket with first collective + bucket_info = CollBucket( + collectives=[start_node], + total_bytes=self.collective_info[start_node].size_bytes, + ) + processed.add(start_node) + + # Greedy optimization: stop after consecutive failures + consecutive_failures = 0 + max_consecutive_failures = 20 + + # Check candidates in sorted order, break when beyond max distance + for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]: + candidate_bytes = self.collective_info[candidate].size_bytes + # proxy on memory use, if we see a too large bucket, + # dont look for another, later bucket + if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: + break + + if candidate in processed: + continue + + if self._can_add_to_bucket(bucket_info, candidate): + bucket_info.collectives.append(candidate) + bucket_info.total_bytes += candidate_bytes + processed.add(candidate) + consecutive_failures = 0 # Reset on success + else: + consecutive_failures += 1 + if consecutive_failures >= max_consecutive_failures: + break + + if len(bucket_info.collectives) > 1: + buckets.append(bucket_info) + + return buckets + + def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool: + """Check if there's an ancestor relationship between two nodes.""" + return n1 in self.node_ancestors[n2] or n2 in self.node_ancestors[n1] + + def _get_intervals( + self, event: PGEvent + ) -> tuple[Optional[tuple[int, int]], list[tuple[int, int]]]: + """Get (execution_interval, hiding_intervals) for a collective event. + + Returns: + (execution_interval, hiding_intervals) where: + - execution_interval is (start_pos, wait_pos) or None + - hiding_intervals is a list of (start_pos, compute_pos) tuples, one for each hiding node + + Works for both start and wait events by looking up the collective info. + """ + # For start events, directly use the node + if event.is_start: + coll = event.node + # For wait events, look up the start node from the event's args + elif event.is_wait: + wait_input = event.node.args[0] + if not isinstance(wait_input, fx.Node): + return None, [] + coll = wait_input + else: + return None, [] + + if coll not in self.collective_info: + return None, [] + + info = self.collective_info[coll] + start_event = self.node_to_event[coll] + wait_event = self.node_to_event[info.wait_node] + + execution_interval = (start_event.position, wait_event.position) + + hiding_intervals = [] + if info.hiding_nodes: + for hiding_node in info.hiding_nodes: + hiding_intervals.append( + ( + start_event.position, + self.node_to_event[hiding_node].position, + ) + ) + + return execution_interval, hiding_intervals + + def _preserves_hiding_intervals( + self, + bucket_info: CollBucket, + candidate: fx.Node, + start_pos: fx.Node, + wait_pos: fx.Node, + why: WhyNoBucket, + ) -> bool: + """ + Check that (start_pos, wait_pos) doesn't violate any hiding intervals or collectives. + + Collects all execution and hiding intervals in the affected timeline regions, + then checks: + 1. All bucket hiding compute stays between new start/wait + 2. No other collective's compute interval is enclosed by bucket execution interval + 3. No other collective's execution interval encloses bucket compute intervals + """ + # Collect all collectives being bucketed + all_bucketed_colls = [candidate] + list(bucket_info.collectives) + all_bucketed_waits = [ + self.collective_info[coll].wait_node for coll in all_bucketed_colls + ] + + # Collect hiding compute positions for the bucket + bucket_hiding_compute_positions = [] + for coll in all_bucketed_colls: + for coll_hiding_node in self.collective_info[coll].hiding_nodes: + bucket_hiding_compute_positions.append( + self.node_to_event[coll_hiding_node].position + ) + + # Get new positions + new_start_event = self.node_to_event[start_pos] + new_wait_event = self.node_to_event[wait_pos] + + # Check 1: All bucket hiding compute must be between new start and wait + for compute_pos in bucket_hiding_compute_positions: + if not (new_start_event.position < compute_pos < new_wait_event.position): + why( + "hiding compute at pos %d not between start %d and wait %d", + compute_pos, + new_start_event.position, + new_wait_event.position, + ) + return False + + def get_wait(n: fx.Node) -> fx.Node: + return self.collective_info[n].wait_node + + def get_pos(n: fx.Node) -> int: + return self.node_to_event[n].position + + latest_start_pos = max(get_pos(candidate), get_pos(bucket_info.collectives[0])) + earliest_wait_pos = min( + get_pos(get_wait(candidate)), get_pos(get_wait(bucket_info.collectives[0])) + ) + + # Bucket execution interval + bucket_execution_interval = (new_start_event.position, new_wait_event.position) + + # Because collectives on the same PG operate under LIFO semantics, + # it's only possible for us to force an early realization of an unrelated collective + # by delaying a start or raising a wait. + # We search in the interval from old_start -> new_start, to see if would be + # forcing another collective to be realized prior to its hiding nodes. + # Similarly, we search from old_wait -> new_wait, in the reverse direction, + # to check the same thing. + + execution_intervals = [bucket_execution_interval] + hiding_intervals = [ + (bucket_execution_interval[0], pos) + for pos in bucket_hiding_compute_positions + ] + + curr_event = new_start_event.next + while curr_event is not None and curr_event.position < latest_start_pos: + if ( + curr_event.node not in all_bucketed_colls + and curr_event.node not in all_bucketed_waits + ): + exec_interval, hiding_interval_list = self._get_intervals(curr_event) + if exec_interval: + execution_intervals.append(exec_interval) + hiding_intervals.extend(hiding_interval_list) + curr_event = curr_event.next + + curr_event = new_wait_event.prev + while curr_event is not None and curr_event.position > earliest_wait_pos: + if ( + curr_event.node not in all_bucketed_colls + and curr_event.node not in all_bucketed_waits + ): + exec_interval, hiding_interval_list = self._get_intervals(curr_event) + if exec_interval: + execution_intervals.append(exec_interval) + hiding_intervals.extend(hiding_interval_list) + curr_event = curr_event.prev + + # Check: no hiding interval should be enclosed by any execution interval + def enclosed_interval(inner: tuple[int, int], outer: tuple[int, int]) -> bool: + return outer[0] < inner[0] and inner[1] < outer[1] + + for hiding_interval in hiding_intervals: + for execution_interval in execution_intervals: + if enclosed_interval(hiding_interval, execution_interval): + why( + "hiding interval %s enclosed by execution interval %s", + hiding_interval, + execution_interval, + ) + return False + + return True + + def remove_from_event( + self, node: fx.Node + ) -> tuple[Optional[PGEvent], Optional[PGEvent]]: + """Remove node from timeline and return (prev_event, next_event).""" + event = self.node_to_event[node] + assert not event.is_compute, "Cannot remove compute events from timeline" + + prev_event, next_event = event.unlink() + + # Remove augmented graph dependency + if prev_event: + self.aug_graph.remove_extra_dep(n=node, dep=prev_event.node) + if next_event: + self.aug_graph.remove_extra_dep(n=next_event.node, dep=node) + + # Add bypass dependency + if prev_event and next_event: + self.aug_graph.add_extra_dep(n=next_event.node, dep=prev_event.node) + + return prev_event, next_event + + def restore_to_event( + self, + node: fx.Node, + prev_event: Optional[PGEvent], + next_event: Optional[PGEvent], + ) -> None: + """Restore node to timeline after failed merge attempt.""" + event = self.node_to_event[node] + + # Reinsert into linked list + event.insert_between(prev_event, next_event) + if prev_event: + self.aug_graph.add_extra_dep(n=node, dep=prev_event.node) + if next_event and not prev_event: + self.aug_graph.add_extra_dep(n=next_event.node, dep=node) + + # Remove bypass dependency + if prev_event and next_event: + self.aug_graph.remove_extra_dep(n=next_event.node, dep=prev_event.node) + + def _try_timeline_position( + self, + bucket_info: CollBucket, + candidate: fx.Node, + start_pos: fx.Node, + wait_pos: fx.Node, + why: WhyNoBucket, + ) -> bool: + """ + Try a specific timeline position for the candidate. + Returns True if valid and merges are successful. + """ + candidate_info = self.collective_info[candidate] + candidate_wait = candidate_info.wait_node + + # Quick check: does this violate hiding intervals? + if not self._preserves_hiding_intervals( + bucket_info, candidate, start_pos, wait_pos, why + ): + return False + + # Determine which start needs to move + existing_coll = bucket_info.collectives[0] + if start_pos == existing_coll: + start_to_move = candidate + else: + assert start_pos == candidate + start_to_move = existing_coll + + # Remove start from timeline + start_prev, start_next = self.remove_from_event(start_to_move) + + # Check if starts can be merged + if self.aug_graph.has_path(existing_coll, candidate) or self.aug_graph.has_path( + candidate, existing_coll + ): + # Restore start constraints + self.restore_to_event(start_to_move, start_prev, start_next) + why("path exists between starts") + return False + + # Merge starts + self.aug_graph.merge_to_set(existing_coll, candidate) + + # Determine which wait needs to move + existing_wait = self.collective_info[existing_coll].wait_node + candidate_wait = self.collective_info[candidate].wait_node + + if wait_pos == existing_wait: + wait_to_move = candidate_wait + else: + wait_to_move = existing_wait + + # Remove wait from timeline + wait_prev, wait_next = self.remove_from_event(wait_to_move) + + # Check if waits can be merged + if self.aug_graph.has_path( + existing_wait, candidate_wait + ) or self.aug_graph.has_path(candidate_wait, existing_wait): + # Restore wait constraints + self.restore_to_event(wait_to_move, wait_prev, wait_next) + # Unmerge the start we just merged + self.aug_graph.unmerge_node(candidate) + # Restore start constraints + self.restore_to_event(start_to_move, start_prev, start_next) + why("path exists between waits") + return False + + # Merge waits - success! + self.aug_graph.merge_to_set(existing_wait, candidate_wait) + + # Update node_to_event for moved nodes + target_start_event = self.node_to_event[start_pos] + target_wait_event = self.node_to_event[wait_pos] + + self.node_to_event[candidate] = target_start_event + self.node_to_event[candidate_wait] = target_wait_event + self.node_to_event[existing_coll] = target_start_event + self.node_to_event[existing_wait] = target_wait_event + + return True + + def _has_ancestor_conflicts( + self, bucket_info: CollBucket, candidate: fx.Node + ) -> bool: + """ + Check if candidate has ancestor conflicts with bucket collectives. + Returns True if there are conflicts. + """ + candidate_info = self.collective_info[candidate] + candidate_wait = candidate_info.wait_node + + for coll in bucket_info.collectives: + if ( + coll in self.node_ancestors[candidate] + or candidate in self.node_ancestors[coll] + ): + return True + + # Check if waits are ancestors of each other + coll_wait = self.collective_info[coll].wait_node + if ( + coll_wait in self.node_ancestors[candidate_wait] + or candidate_wait in self.node_ancestors[coll_wait] + ): + return True + + # Check if existing hiding node conflicts with candidate wait + for old_hiding_node in self.collective_info[coll].hiding_nodes: + if candidate_wait in self.node_ancestors[old_hiding_node]: + return True + + # Check if candidate hiding node conflicts with existing wait + for new_hiding_node in candidate_info.hiding_nodes: + if coll_wait in self.node_ancestors[new_hiding_node]: + return True + + return False + + def _can_add_to_bucket( + self, + bucket_info: CollBucket, + candidate: fx.Node, + ) -> bool: + """ + Check if candidate can be added to bucket without breaking comm/compute overlap. + + Strategy: Try all timeline positions - combinations of [existing_start, candidate_start] + x [existing_wait, candidate_wait]. For each position, verify: + 1. Hiding intervals preserved - for any (start, hiding_compute, wait) interval, no other + collective's (start, wait) pair falls between start and hiding_compute, which would + force realization and break overlap due to LIFO semantics + 2. Topologically valid (no dependency cycles) + + Return True if any timeline position satisfies both constraints. + """ + existing_coll = bucket_info.collectives[0] + why = WhyNoBucket(existing_coll, candidate) + + candidate_info = self.collective_info[candidate] + + if ( + candidate in self.all_hiding_nodes + or candidate_info.wait_node in self.all_hiding_nodes + ): + why("nyi: bucketing collective used for overlap") + return False + + # Step 1: Quick check using precomputed ancestors + # These ancestors are computed prior to adding augmented dependencies and not updated, + # so if any of these checks fail then the merge will not be topologically valid + # even ignoring comm/compute overlap + if self._has_ancestor_conflicts(bucket_info, candidate): + why("has ancestor conflicts") + return False + + # Step 2: Try different rail positions + existing_wait = self.collective_info[existing_coll].wait_node + + candidate_start = candidate + candidate_wait = candidate_info.wait_node + + # Try combinations in order of likelihood to succeed + # (early start, later wait is most likely to work) + combinations = [ + ( + existing_coll, + candidate_wait, + ), # Move candidate start early, keep wait late + ( + existing_coll, + existing_wait, + ), # Move candidate start early, move wait early + (candidate_start, candidate_wait), # Keep both in place + (candidate_start, existing_wait), # Keep start in place, move wait early + ] + + for i, (start_pos, wait_pos) in enumerate(combinations): + if self._try_timeline_position( + bucket_info, candidate, start_pos, wait_pos, why + ): + bucket_log.debug( + "bucketed %s with %s using timeline position %d: (start=%s, wait=%s)", + candidate.name, + existing_coll.name, + i + 1, + start_pos.name, + wait_pos.name, + ) + return True + + why("all timeline positions failed") + return False + + def _apply_bucket(self, bucket_info: CollBucket) -> None: + """ + Apply bucketing transformation. + + Dependencies are added to aug_graph.extra_deps and transferred from old nodes. + """ + + from torch._inductor.fx_passes.bucketing import ( + is_all_reduce_tensor, + merge_all_gather_bucket, + merge_all_reduce_bucket, + merge_reduce_scatter_bucket, + ) + + bucket = bucket_info.collectives + + # Collect old nodes BEFORE they're erased + old_starts = list(bucket) + old_waits = [self.collective_info[n].wait_node for n in bucket] + + fused_convert_dtypes = [] + for n in old_starts: + if has_mergeable_all_gather_convert_dtype(n): + fused_convert_dtypes.append(n.args[0]) + + # Find where to place the bucketed operations + next_node = bucket[0] + while next_node in bucket: + next_node = next_node.next + + # Don't use wait_insertion_point - let merge functions place waits naturally + # The wait_insertion_point feature tries to move waits to a specific location, + # but this can cause issues when that location is one of the nodes being erased + # Create bucketed collective (this will erase old nodes) + if is_all_gather(bucket[0]): + new_nodes, replacements = merge_all_gather_bucket( + self.graph, + bucket, + insert_before=next_node, + mode="custom_ops", + ) + elif is_all_reduce_tensor(bucket[0]): + new_nodes, replacements = merge_all_reduce_bucket( + self.graph, + bucket, + mode="custom_ops", + insert_before=next_node, + ) + else: + assert is_reduce_scatter(bucket[0]) + new_nodes, replacements = merge_reduce_scatter_bucket( + self.graph, + bucket, + insert_before=next_node, + mode="custom_ops", + ) + + # Get new nodes + new_waits = [n for n in new_nodes if _schedulable_wait_node(n)] + assert len(new_waits) == 1 + + new_wait = new_waits[0] + new_start = new_wait.args[0] + assert isinstance(new_start, fx.Node) + + # Create mapping of all erased nodes to their replacements + erased_to_new = {} + for old_start in old_starts: + erased_to_new[old_start] = new_start + for old_wait in old_waits: + erased_to_new[old_wait] = new_wait + + # Handle convert_element_type nodes that were fused and erased + # The bucketed operation may have a _pre_bucket op that handles dtype conversion + if fused_convert_dtypes: + # all gather bucketing may fuse in dtype conversion into the bucketing + # if so, we need to transfer hiding deps from the old dtype conversion + # to the new bucketing node + new_convert_dtypes_node = new_start.kwargs["out"] + assert isinstance(new_convert_dtypes_node, fx.Node) + assert ( + new_convert_dtypes_node.target + == torch.ops.bucketing._pre_bucket_all_gather.default + ) + + for n in fused_convert_dtypes: + erased_to_new[n] = new_convert_dtypes_node + + # Transfer all dependencies from old nodes to new nodes + self.aug_graph.transfer_erased_node_deps(erased_to_new) + + def _preserve_dependencies_with_tokens( + self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]] + ) -> None: + """ + Preserve dependencies using effect tokens and with_effects higher-order op. + + Uses the standalone token_dependencies utility for consistent behavior + across different overlap scheduling approaches. + """ + from torch._inductor.fx_passes.control_dependencies import ( + preserve_node_ordering, + ) + + preserve_node_ordering(self.graph, additional_deps) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_scheduling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..5770991dc233ef3dac26a8027f990c048acf3ce9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/overlap_scheduling.py @@ -0,0 +1,1324 @@ +import functools +import heapq +import itertools +import logging +import sys +from collections import Counter, defaultdict +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field +from typing import Any, Literal + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint +from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor +from torch._inductor.fx_passes.memory_estimator import MemoryTracker +from torch.fx.operator_schemas import normalize_function +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import _disable_current_modes + + +log = logging.getLogger(__name__) + +from torch._inductor.fx_passes.bucketing import bucket_key + +from ..pattern_matcher import stable_topological_sort + + +def estimate_runtime_analytical(n: torch.fx.Node) -> float: + """Estimate runtime using analytical roofline model for mm operations.""" + if n.target != torch.ops.aten.mm.default: + return 0.0 + import torch.utils._pytree as pytree + from torch.distributed._tools import RuntimeEstimator + + def _val(node: Any) -> Any: + if not isinstance(node, torch.fx.Node): + return node + return node.meta["val"] + + args = pytree.tree_map(_val, n.args) + kwargs = pytree.tree_map(_val, n.kwargs) + _, ms = RuntimeEstimator._roofline_estimate(n.target, args, kwargs) + return ms + + +@dataclass +class WhyNoOverlap: + """Track reasons why a collective cannot overlap with compute.""" + + compute_name: str + collective_name: str + + def __init__(self, compute_node: fx.Node, collective_node: fx.Node) -> None: + self.compute_name = compute_node.name + self.collective_name = collective_node.name + + def __call__(self, reason: str, *args: Any) -> None: + if log.isEnabledFor(logging.DEBUG): + log.debug( + "cannot overlap %s with %s: " + reason, # noqa: G003 + self.collective_name, + self.compute_name, + *args, + ) + + +def get_group_name(n: fx.Node) -> str: + """Extract the group name from a collective operation node.""" + opt_args_kwargs = normalize_function( + n.target, # type: ignore[arg-type] + args=n.args, + kwargs=n.kwargs, + normalize_to_only_use_kwargs=True, + ) + assert opt_args_kwargs is not None + _, kwargs = opt_args_kwargs + return kwargs["group_name"] + + +def get_custom_estimation( + n: fx.Node, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + override_size: int | None = None, +) -> float | None: + if custom_runtime_estimation is None: + return None + + return custom_runtime_estimation(n, override_size) + + +def estimate_collective_time( + n: fx.Node, + override_size: int | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, +) -> float: + """Estimate the runtime of a collective operation, optionally with an overridden size.""" + if ( + est := get_custom_estimation(n, custom_runtime_estimation, override_size) + ) is not None: + return est + + # Use analytical model (benchmarking is handled separately in alignment) + return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( + n, override_size + ) + + +def is_compute_node(n: fx.Node) -> bool: + """ + Should we consider this node computationally expensive ? + Currently uses flop registration, but we could expand more generally. + """ + return ( + getattr(n.target, "overloadpacket", None) + in torch.utils.flop_counter.flop_registry + ) + + +def is_reduce_scatter(n: fx.Node) -> bool: + """Check if node is a reduce_scatter collective.""" + return "reduce_scatter" in str(n.target).lower() + + +def get_hint(x: int | torch.SymInt) -> int | None: + if isinstance(x, int): + return x + assert isinstance(x, torch.SymInt) + if not x.node.has_hint(): + return None + return x.node.hint + + +def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]: + with dynamo_timed("collective_compute_do_bench"): + return functools.partial( + # pyrefly: ignore [bad-argument-type] + torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, + warmup=5, + ) + + +def benchmark_node_with_cache_key( + n: fx.Node, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, +) -> tuple[float, str | None]: + """Benchmark a compute node and return (runtime, cache_key).""" + assert is_compute_node(n) + + from torch._dynamo.testing import rand_strided + + # todo - skip unbacked, symbolic + success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(n) + + if not success: + return 0, None + + unbacked_tensor = False + + key = f"{str(n.target)}: " + + def to_real(t: torch.Tensor) -> torch.Tensor | None: + shape = [get_hint(dim) for dim in t.shape] + stride = [get_hint(s) for s in t.stride()] + + if any(s is None for s in itertools.chain(shape, stride)): + nonlocal unbacked_tensor + unbacked_tensor = True + return None + + nonlocal key + key += f"T: {shape, stride, t.dtype} " + return rand_strided(shape, stride, device=t.device, dtype=t.dtype) # type: ignore[arg-type] + + with _disable_current_modes(): + args, kwargs = torch.utils._pytree.tree_map_only( + torch.Tensor, + lambda t: to_real(t), + (args, kwargs), + ) + + if val := get_cached_node_time(key): + return val, key + + if unbacked_tensor: + return 0, key + + if ( + est := get_custom_estimation(n, custom_runtime_estimation, None) + ) is not None: + set_cached_node_time(key, est) + return est, key + + bench = get_collective_do_bench() + out = bench(lambda: n.target(*args, **kwargs)) # type: ignore[operator] + set_cached_node_time(key, out) + return out, key + + +def benchmark_node( + n: fx.Node, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, +) -> float: + return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0] + + +@functools.cache +def get_benchmark_cache() -> torch._inductor.codecache.LocalCache: + return torch._inductor.codecache.LocalCache() + + +def get_cached_node_time(key: str) -> float: + return get_benchmark_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_node_time(key: str, value: float) -> None: + return get_benchmark_cache().set_value(key, value=value) + + +@dataclass +class CollectiveInfo: + """Track info about a collective operation""" + + start_node: fx.Node + wait_node: fx.Node + size_bytes: int + estimated_time_ms: float + exposed_time_ms: float # How much of this collective is still exposed + hiding_nodes: OrderedSet[fx.Node] = field(default_factory=OrderedSet) + + @property + def is_exposed(self) -> bool: + return self.exposed_time_ms != 0 + + +@dataclass +class CollBucket: + """Track information about a bucket of collectives.""" + + collectives: list[fx.Node] # Original collective starts + bucketed_start: fx.Node | None = None # After bucketing + bucketed_wait: fx.Node | None = None # After bucketing + total_bytes: int = 0 + + +def gb_to_bytes(gb: float) -> int: + """Convert gigabytes to bytes.""" + return int(gb * 1024 * 1024 * 1024) + + +class OverlapScheduler: + """ + Scheduler that reorders operations to maximize compute-collective overlap. + + The reordering is done as a scheduling pass. We maintain a priority queue of + schedulable nodes. The nodes are ranked by: + + 1) the compute node index they dominate. this allows reordering locally, such as with + parallel mms, and also allows overlapping reduce scatter nodes outputs in the backward + with compute by deferring their waits. + + 2) whether the current node is a collective or wait that is currently exposed but has a compute + node which it could be overlapped with. + + 3) original order in the graph for stability. + + When we schedule compute nodes, we first overlap exposed in-flight collectives, then look for unscheduled + collectives that can be scheduled concurrently. + + TODO: + - experiment with other priority scores / allow other mechanisms of reorder / more strict adherence to original graph + - memory limit for deferred scheduling of reduce_scatter nodes. + """ + + def __init__( + self, + gm: torch.fx.GraphModule, + max_in_flight_gb: float, + max_compute_pre_fetch: int, + collective_bucketing: bool, + insert_overlap_deps: bool, + compute_overlap_multipler: float, + max_coll_distance: int, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None, + collective_estimator: Literal["analytical", "benchmark"], + max_memory_increase_gb: float | None = 1.0, + max_memory_increase_ratio: float | None = 0.05, + ): + self.gm = gm + self.graph = gm.graph + self.compute_overlap_multipler = compute_overlap_multipler + self.max_node_distance = max_coll_distance + self.max_in_flight_bytes: int = gb_to_bytes(max_in_flight_gb) + self.custom_runtime_estimation = custom_runtime_estimation + self.collective_bucketing = collective_bucketing + self.insert_overlap_deps = insert_overlap_deps + self.max_compute_pre_fetch = max_compute_pre_fetch + self.collective_estimator = collective_estimator + + # Build structures + stable_topological_sort(self.graph) + self.nodes = list(self.graph.nodes) + self.node_idx = {n: i for i, n in enumerate(self.nodes)} + self.node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = ( + self._collect_node_ancestors() + ) + + # Identify collectives and compute nodes + self.collective_info: dict[fx.Node, CollectiveInfo] = {} + self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet() + + # Identify compute nodes early (needed for baseline memory computation) + self.compute_nodes = [n for n in self.nodes if is_compute_node(n)] + self.current_compute_index = 0 + + # Compute baseline memory profile from original schedule + self.original_mem_before_compute_index: list[int] = [] + self.original_peak_memory = self._compute_baseline_memory() + + # Maximum allowed peak memory = baseline + max(absolute, ratio * baseline) + # When both limits are specified, use the more permissive one + memory_increase_bytes = None + if max_memory_increase_gb is not None: + memory_increase_bytes = gb_to_bytes(max_memory_increase_gb) + if max_memory_increase_ratio is not None: + ratio_increase = int(self.original_peak_memory * max_memory_increase_ratio) + memory_increase_bytes = ( + max(memory_increase_bytes, ratio_increase) + if memory_increase_bytes is not None + else ratio_increase + ) + if memory_increase_bytes is None: + memory_increase_bytes = 0 + + self.allowed_peak_memory_bytes = ( + self.original_peak_memory + memory_increase_bytes + ) + + # Track cumulative prefetch memory at each compute index + # When we prefetch a collective at compute index i that will be used at index j, + # it adds memory from i to j, so we need to track this cumulative effect + self.cumulative_prefetch_mem_by_compute_index: list[int] = [ + 0 for _ in range(len(self.compute_nodes)) + ] + + self.memory_tracker = MemoryTracker(self.graph) + + self.wait_to_start: dict[fx.Node, fx.Node] = {} + self._identify_collectives() + self.wasted_compute = 0.0 + + self.compute_index_domination = self._calculate_compute_node_domination_index() + + # Scheduling state + self.potentially_hidden_collectives = ( + self.compute_potential_hidden_collectives() + ) + self.potentially_hidden_waits = self.compute_potential_hidden_waits() + self.in_degree = Counter(user for node in self.nodes for user in node.users) + self.ready: list[tuple[object, fx.Node]] = [] + + for node in self.nodes: + if self.in_degree[node] == 0: + heapq.heappush(self.ready, (self._compute_score(node), node)) + + self.in_flight: dict[fx.Node, CollectiveInfo] = {} # start -> info + self.in_flight_bytes = 0 + self.scheduled: OrderedSet[fx.Node] = OrderedSet() + self.max_compute_pre_fetch = max_compute_pre_fetch + + def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """Collect all ancestors for each node.""" + ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in self.nodes: + for input_node in node.all_input_nodes: + ancestors[node].add(input_node) + ancestors[node] |= ancestors[input_node] + + return ancestors + + def _compute_baseline_memory(self) -> int: + """ + Simulate the original schedule to compute baseline memory profile. + Returns the peak memory observed during simulation. + """ + baseline_tracker = MemoryTracker(self.graph) + + last_compute_max_memory = 0 + peak_memory = 0 + + for node in self.nodes: + baseline_tracker.schedule_node(node) + current_mem = baseline_tracker.current_memory_bytes + + # Record the max memory between this and previous compute node + last_compute_max_memory = max(last_compute_max_memory, current_mem) + + if is_compute_node(node): + self.original_mem_before_compute_index.append(last_compute_max_memory) + last_compute_max_memory = current_mem + + peak_memory = max(peak_memory, current_mem) + + return peak_memory + + def _prefetch_would_exceed_memory_budget(self, start_node: fx.Node) -> bool: + """ + Check if prefetching this collective would exceed memory budget at ANY compute node + between now and when it's used. + """ + info = self.collective_info[start_node] + size = info.size_bytes + + domination_index = self.compute_index_domination[start_node] + + # If off-path, assume it doesn't increase memory + if domination_index == sys.maxsize: + return False + + # check current mem + if ( + self.memory_tracker.current_memory_bytes + size + > self.allowed_peak_memory_bytes + ): + return True + + start_index = self.current_compute_index + + # then, check future mem + for compute_idx in range(start_index, domination_index): + cumulative_prefetch = self.cumulative_prefetch_mem_by_compute_index[ + compute_idx + ] + + # Check 1: Would cumulative prefetch exceed in-flight limit? + if (cumulative_prefetch + size) > self.max_in_flight_bytes: + return True + + # Check 2: Would total memory (baseline + cumulative prefetch) exceed budget? + baseline_mem = self.original_mem_before_compute_index[compute_idx] + projected = baseline_mem + cumulative_prefetch + size + + if projected > self.allowed_peak_memory_bytes: + return True + + return False + + def _update_cumulative_prefetch_memory( + self, collective: fx.Node, info: CollectiveInfo + ) -> None: + """ + Update cumulative prefetch memory for all compute indices this collective will be live. + """ + domination_index = self.compute_index_domination[collective] + if domination_index == sys.maxsize: + return + + for compute_idx in range(self.current_compute_index, domination_index): + self.cumulative_prefetch_mem_by_compute_index[compute_idx] += ( + info.size_bytes + ) + + def off_compute_path(self, n: fx.Node) -> bool: + """Check if a node is off the compute path (doesn't block any compute).""" + return self.compute_index_domination[n] == sys.maxsize + + def _identify_collectives(self) -> None: + """Identify all collective operations and process groups.""" + self.all_pgs: OrderedSet[str] = OrderedSet() + + for node in self.nodes: + if _schedulable_wait_node(node): + start = node.args[0] + coll_time_ms = estimate_collective_time( + start, custom_runtime_estimation=self.custom_runtime_estimation + ) + + info = CollectiveInfo( + start_node=start, + wait_node=node, + size_bytes=estimate_fx_collective_memory_footprint(start), + estimated_time_ms=coll_time_ms, + exposed_time_ms=coll_time_ms, # Initially fully exposed + ) + self.collective_info[start] = info + self.wait_to_start[node] = start + self.unscheduled_collectives.add(start) + self.all_pgs.add(get_group_name(start)) + + def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]: + """ + Compute the topological index of the earliest compute node each node dominates. + + Compute nodes are assigned indices based on their topological order (0, 1, 2, ...). + For each node, returns the minimum index of compute nodes it blocks/dominates. + Returns sys.maxsize if the node doesn't block any compute nodes. + """ + compute_node_index: dict[fx.Node, int] = {} + for node in self.graph.nodes: + if is_compute_node(node): + compute_node_index[node] = len(compute_node_index) + + domination_index: dict[fx.Node, int] = {} + for node in reversed(self.graph.nodes): + if node in compute_node_index: + # Compute nodes dominate themselves (return their own index) + domination_index[node] = compute_node_index[node] + else: + domination_index[node] = min( + (domination_index[succ] for succ in node.users), default=sys.maxsize + ) + + return domination_index + + def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( + self, + ) -> None: + """Align runtime estimations across ranks (compute + collectives).""" + log.info( + "Overlap scheduling: Aligning runtime estimations across all distributed ranks" + ) + + # Benchmark compute nodes + runtime_estimations_keys: list[str | None] = [] + runtime_estimations: list[float] = [] + compute_key_count = 0 + + # Also collect analytical estimations for logging + runtime_estimations_analytical: list[float] = [] + + for n in self.compute_nodes: + val, key = benchmark_node_with_cache_key(n, self.custom_runtime_estimation) + + # Analytical estimations + val_analytical = estimate_runtime_analytical(n) + runtime_estimations_analytical.append(val_analytical) + + runtime_estimations.append(val) + runtime_estimations_keys.append(key) + compute_key_count += 1 + + # Log compute estimations + from torch._inductor.fx_passes.node_runtime_estimation import ( + _log_compute_estimations, + ) + + _log_compute_estimations( + self.compute_nodes, + runtime_estimations, + runtime_estimations_analytical, + ) + + # Benchmark collectives if enabled (only CUDA events - others are deterministic) + # Skip if custom estimation is provided for collectives + collective_nodes: list[fx.Node] = [] + benchmarked_collective_nodes: list[ + fx.Node + ] = [] # Track which were actually benchmarked + if self.collective_estimator == "benchmark": + from torch._inductor.fx_passes.node_runtime_estimation import ( + benchmark_collective_with_cuda_events, + ) + + collective_nodes = [ + info.start_node for info in self.collective_info.values() + ] + + # Benchmark CUDA events (non-deterministic, needs alignment) + # Skip collectives with custom estimation + for n in collective_nodes: + if ( + get_custom_estimation(n, self.custom_runtime_estimation, None) + is not None + ): + continue + + # Benchmark actual size + cuda_val, cuda_key = benchmark_collective_with_cuda_events(n, nruns=5) + if cuda_val is not None: + runtime_estimations.append(cuda_val) + runtime_estimations_keys.append(cuda_key) + benchmarked_collective_nodes.append(n) + + # Single all_gather and compute medians + import torch.distributed as dist + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch.distributed.distributed_c10d import _get_default_group + + world_size = dist.get_world_size() + pg = _get_default_group() + + with unset_fake_temporarily(): + gathered_runtime_estimations: list[list[float]] = [ + [] for _ in range(world_size) + ] + dist.all_gather_object( + gathered_runtime_estimations, runtime_estimations, pg + ) + median_runtime_estimations = torch.median( + torch.tensor(gathered_runtime_estimations), dim=0 + ).values.tolist() + + # Cache medians + collective_keys = [] + collective_medians = [] + for idx, (key, median_runtime_estimation) in enumerate( + zip(runtime_estimations_keys, median_runtime_estimations) + ): + if key is None: + continue + if idx < compute_key_count: + # Compute node + set_cached_node_time(key, median_runtime_estimation) + else: + # Collective CUDA event benchmark + from torch._inductor.fx_passes.node_runtime_estimation import ( + set_cached_runtime, + ) + + set_cached_runtime(key, median_runtime_estimation) + + # Update CollectiveInfo with aligned benchmark + coll_idx = idx - compute_key_count + coll_node = benchmarked_collective_nodes[coll_idx] + info = self.collective_info[coll_node] + info.estimated_time_ms = median_runtime_estimation + info.exposed_time_ms = median_runtime_estimation + + collective_keys.append(key) + collective_medians.append(median_runtime_estimation) + + # Log benchmarks with analytical comparisons + if collective_keys: + from torch._inductor.fx_passes.node_runtime_estimation import ( + _log_collective_benchmarks, + ) + + _log_collective_benchmarks( + benchmarked_collective_nodes, + collective_keys, + collective_medians, + world_size, + ) + + log.info("Overlap scheduling: Runtime estimations aligned") + + def run(self) -> torch.fx.GraphModule: + """Run the scheduling algorithm.""" + # All ranks must make identical decisions on overlap reordering, + # Thus we must have identical runtime estimations across ranks. + # For now we do benchmarking only for compute nodes. + self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks() + + while self.ready: + if self._should_force_wait_for_memory(): + self._force_oldest_wait() + continue + + _, node = heapq.heappop(self.ready) + + # we don't always remove nodes from the heap when we schedule them + if node in self.scheduled: + continue + + if node.op == "placeholder": + self._schedule(node) + elif node in self.collective_info: + self._handle_collective_start(node) + elif _schedulable_wait_node(node): + self._handle_wait(node) + else: + self._handle_compute_or_other(node) + + self._reorder_graph() + + if self.collective_bucketing: + self._bucket_collectives() + elif self.insert_overlap_deps: + # If not bucketing, add effect tokens to preserve hiding dependencies + self._add_effect_tokens_for_overlap() + + return self.gm + + def _add_effect_tokens_for_overlap(self) -> None: + """ + Add effect tokens to preserve hiding dependency relationships when not bucketing. + + This ensures that communication-compute overlap is preserved through effect tokens + when overlap preserving bucketing is not enabled. + """ + from torch._inductor.fx_passes.control_dependencies import ( + preserve_node_ordering, + ) + + # Collect hiding dependencies: hiding_node -> collective_start, wait -> hiding_node + additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + + for start_node, info in self.collective_info.items(): + if info.is_exposed: + continue + for hn in info.hiding_nodes: + # Compute depends on collective start (compute must wait for collective to start) + additional_deps[hn].add(start_node) + # Wait depends on compute (wait must wait for compute to finish) + additional_deps[info.wait_node].add(hn) + + # Apply effect tokens to preserve these dependencies + if additional_deps: + preserve_node_ordering(self.graph, additional_deps) + + def get_non_collective_runtime_estimate(self, node: fx.Node) -> float | None: + """Get runtime estimation for a node in ms. Returns None if no estimation is available.""" + + # TODO: non custom estimation of aten nodes, potentially requires notion of fusion group + if is_compute_node(node): + return benchmark_node(node, self.custom_runtime_estimation) + + if self.custom_runtime_estimation is None: + return None + + return self.custom_runtime_estimation(node, None) + + def _reduce_exposed_time_of_in_flight_collectives( + self, + node: fx.Node, + available_compute: float, + exclude_pg: str | None = None, + ) -> dict[str, float]: + """ + Reduce exposed time of in-flight collectives using available compute time. + + Collectives on different process groups can overlap simultaneously with the same + compute, so we track remaining time separately per PG. + """ + # Initialize all PGs with full available compute (except excluded) + remaining_time_per_pg: dict[str, float] = { + pg: available_compute for pg in self.all_pgs if pg != exclude_pg + } + + for start_node, info in self.in_flight.items(): + if info.exposed_time_ms == 0: + continue + + pg_name = get_group_name(start_node) + if pg_name == exclude_pg: + continue + + pg_remaining = remaining_time_per_pg[pg_name] + if pg_remaining <= 0: + continue + + overlap_amount = min(info.exposed_time_ms, pg_remaining) + info.exposed_time_ms -= overlap_amount + remaining_time_per_pg[pg_name] -= overlap_amount + info.hiding_nodes.add(node) + + return remaining_time_per_pg + + def _handle_compute_or_other(self, node: fx.Node) -> None: + """Handle scheduling compute or other nodes and attempt to overlap with collectives.""" + runtime_estimate = self.get_non_collective_runtime_estimate(node) + + # TODO: we could consider skipping overlapping for overlapable, unary chains to collectives. + # using these nodes for overlap prevents bucketing. potentially if chain time < latency + if runtime_estimate is None: + assert not is_compute_node(node), "should have estimate for compute nodes" + self._schedule(node) + return + + available_compute = runtime_estimate * self.compute_overlap_multipler + + # First, reduce exposed time of in-flight collectives (per PG) + remaining_time_per_pg = self._reduce_exposed_time_of_in_flight_collectives( + node, available_compute + ) + # Then, schedule new collectives for overlap + self._schedule_collectives_for_overlap(node, remaining_time_per_pg) + self._schedule(node) + + if is_compute_node(node): + self.current_compute_index += 1 + + def _schedule(self, node: fx.Node) -> None: + """Schedule a node.""" + assert node not in self.scheduled + assert all(n in self.scheduled for n in node.all_input_nodes) + self.scheduled.add(node) + self.memory_tracker.schedule_node(node) + + log.debug( + "Scheduled node %s: current_memory=%d bytes, total_scheduled=%d", + node.name, + self.memory_tracker.get_current_memory_bytes(), + len(self.scheduled), + ) + + for user in node.users: + self.in_degree[user] -= 1 + if self.in_degree[user] == 0: + heapq.heappush(self.ready, (self._compute_score(user), user)) + + def _compute_score(self, node: fx.Node) -> object: + """Compute priority score for a node""" + + if _schedulable_wait_node(node): + info = self.collective_info[self.wait_to_start[node]] + # defer waits locally if they are exposed. + compute_local_priority = int(info.is_exposed) + else: + # if we're scheduling this collective via its queue, then it was not + # pre-fetched. we might as well maximize overlap for the + # local, non-mm nodes prior to the next compute node. + if self.in_overlappable_collective_unary_chain(node): + compute_local_priority = -1 + else: + compute_local_priority = 0 + + return ( + self.compute_index_domination[node], # what index compute it blocks + compute_local_priority, # collective_start=-1, wait=1, or neither=0 + self.node_idx[node], # Original order for stability + ) + + @staticmethod + def is_cheap_fn(node: fx.Node) -> bool: + return getattr(node.target, "is_view", False) or torch.Tag.pointwise in getattr( + node.target, "tags", () + ) + + def in_overlappable_collective_unary_chain(self, curr: fx.Node) -> bool: + while True: + if len(curr.users) != 1: + return False + + user = next(iter(curr.users)) + if len(user.all_input_nodes) != 1: + return False + + if user in self.unscheduled_collectives: + return True + + if not self.is_cheap_fn(user): + return False + + curr = user + + return False + + def _should_force_wait_for_memory(self) -> bool: + """Check if we need to force a wait due to memory pressure""" + if not self.in_flight: + return False + + return self.in_flight_bytes >= self.max_in_flight_bytes + + def _force_oldest_wait(self) -> None: + """Schedule the oldest in flight wait""" + self._handle_wait(self._get_oldest_wait()) + + def _handle_collective_start(self, node: fx.Node) -> None: + """Handle scheduling a collective start.""" + info = self.collective_info[node] + + if self.should_assume_bucketed(node): + latency = estimate_collective_time( + node, 0, custom_runtime_estimation=self.custom_runtime_estimation + ) + assert latency <= info.exposed_time_ms + info.exposed_time_ms = info.exposed_time_ms - latency + + self.in_flight[node] = info + self.in_flight_bytes += info.size_bytes + self.unscheduled_collectives.discard(node) + self._schedule(node) + + def _handle_wait(self, node: fx.Node) -> None: + """Handle scheduling a wait.""" + assert node in self.wait_to_start + coll_start = self.wait_to_start[node] + assert coll_start in self.in_flight + + # Scheduling a wait of a collective also forces the wait + # of every node enqueued prior to the collective on the + # same process group + group_name = get_group_name(coll_start) + to_schedule: list[fx.Node] = [] + for in_flight_coll in self.in_flight: + if in_flight_coll == coll_start: + break + if get_group_name(in_flight_coll) == group_name: + to_schedule.append(in_flight_coll) + + for coll_to_schedule in to_schedule: + self._handle_wait(self.collective_info[coll_to_schedule].wait_node) + + # If we are waiting on an exposed collective, use this time to + # overlap on other PGs. + info = self.collective_info[coll_start] + if info.exposed_time_ms > 0: + exposed_time = info.exposed_time_ms + exclude_pg = group_name + + remaining_time_per_pg = self._reduce_exposed_time_of_in_flight_collectives( + node, exposed_time, exclude_pg=exclude_pg + ) + self._schedule_collectives_for_overlap( + node, remaining_time_per_pg, exclude_pg=exclude_pg + ) + + self.in_flight_bytes -= self.in_flight[coll_start].size_bytes + del self.in_flight[coll_start] + self._schedule(node) + + def _schedule_collectives_for_overlap( + self, + overlap_node: fx.Node, + remaining_time_per_pg: dict[str, float], + exclude_pg: str | None = None, + ) -> None: + """Opportunistically schedule collectives that can be hidden by available overlap time.""" + if not remaining_time_per_pg or all( + t <= 0 for t in remaining_time_per_pg.values() + ): + return + + overlap_node_ancestors = self.node_ancestors[overlap_node] + + # Compile candidates - limit by distance to bound compile time + candidates = [] + for i, collective in enumerate(self.unscheduled_collectives): + if i > self.max_node_distance: + break + + pg_name = get_group_name(collective) + if pg_name == exclude_pg: + continue + + if ( + not self.off_compute_path(collective) + and self.compute_index_domination[collective] + - self.current_compute_index + > self.max_compute_pre_fetch + ): + continue + + candidates.append(collective) + + # Sort candidates prioritizing: + # 1. reduce_scatter operations (reduce memory pressure) + # 2. Earlier domination index + # 3. Original order for stability + candidates.sort( + key=lambda n: ( + not is_reduce_scatter(n), # reduce_scatter first + self.compute_index_domination[n], + self.node_idx[n], + ), + ) + + for collective in candidates: + pg_name = get_group_name(collective) + pg_available_time = remaining_time_per_pg[pg_name] + + if pg_available_time <= 0: + continue + + why = WhyNoOverlap(overlap_node, collective) + info = self.collective_info[collective] + + if ( + collective in overlap_node_ancestors + or overlap_node in self.node_ancestors[collective] + ): + why("dependency conflict") + continue + + # Check if prefetching would exceed memory budget + if self._prefetch_would_exceed_memory_budget(collective): + why("prefetch would exceed memory budget") + continue + + # Try to free memory by forcing hidden waits + while ( + self.in_flight + and (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes + and self._wait_is_hidden(self._get_oldest_wait(), overlap_node) + ): + self._force_oldest_wait() + + if (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes: + why("in-flight memory limit") + continue + + # Check if we can reach this collective without scheduling compute, other collectives, or waits + path = self._find_schedulable_path(collective, overlap_node, why) + if path is None: + continue + + log.debug( + "Overlapping collective %s with node %s: coll_domination=%d, current_depth=%d", + collective.name, + overlap_node.name, + self.compute_index_domination[collective], + self.current_compute_index, + ) + + # TODO: We previously tracked path compute time and added it back to available + # overlap time. With per-PG tracking this is complex: if there were in-flight + # collectives on one PG but not another, we can't add path time back to the PG + # that wasn't in-flight + + # Schedule path and collective + self._schedule_path_to_collective(path, overlap_node) + self._handle_collective_start(collective) + self._update_cumulative_prefetch_memory(collective, info) + + # Update exposed time for this collective + overlap_amount = min(pg_available_time, info.exposed_time_ms) + info.exposed_time_ms -= overlap_amount + info.hiding_nodes.add(overlap_node) + + # Update available time for this PG + remaining_time_per_pg[pg_name] -= overlap_amount + + if sum(remaining_time_per_pg.values()) == 0: + break + + if remaining_time_per_pg: + self.wasted_compute += min(remaining_time_per_pg.values()) + + def _find_schedulable_path( + self, target: fx.Node, curr_overlap_node: fx.Node | None, why: WhyNoOverlap + ) -> OrderedSet[fx.Node] | None: + """Find path to target by collecting unscheduled dependencies.""" + # Get unscheduled ancestors + unscheduled_ancestors = self.node_ancestors[target] - self.scheduled + + # only schedule non distributed, non compute nodes + for node in unscheduled_ancestors: + if is_compute_node(node): + why("path blocked by compute node %s", node.name) + return None + + if node in self.unscheduled_collectives: + why("path blocked by unscheduled collective %s", node.name) + return None + + # if we schedule a wait tensor whose start collective is hidden by the + # current compute node we are scheduling, then we are effectively exposing it. + # similarly, dont schedule a wait of a collective that could be otherwise hidden, + # thus forcing it to be exposed. + # however, if it is already hidden it's fine to schedule it + if _schedulable_wait_node(node): + info = self.collective_info[self.wait_to_start[node]] + # Allow if fully hidden by other nodes + if not info.is_exposed and curr_overlap_node not in info.hiding_nodes: + continue + + why( + "path blocked by wait node %s (exposed=%s, hiding_nodes=%s)", + node.name, + info.is_exposed, + curr_overlap_node in info.hiding_nodes, + ) + + # Skip c10 ops and dtensor shard ops - they should be scheduled via main loop + target_str = str(node.target) + if "c10" in target_str or "_dtensor" in target_str: + log.debug( + "Skipping c10/dtensor op %s in path to collective", + node.name, + ) + return None + + return unscheduled_ancestors + + def should_assume_bucketed(self, node: fx.Node) -> bool: + """ + Check if there's an in-flight collective that can be bucketed with the given node. If so, assume they will bucket. + This is a optimistic heuristic to account for latency reduction with bucketing. The two nodes may not get bucketed. + """ + if not torch._inductor.config.test_configs.assume_bucketing_reduces_latency: + return False + + key = bucket_key(node, mode="custom_ops_multidtype") + if key is None: + return False + + for in_flight_coll in self.in_flight: + if bucket_key(in_flight_coll, mode="custom_ops_multidtype") == key: + return True + + return False + + def _get_oldest_wait(self) -> fx.Node: + oldest_start = next(iter(self.in_flight)) + return self.collective_info[oldest_start].wait_node + + def _wait_is_hidden( + self, wait_node: fx.Node, overlap_node: fx.Node | None = None + ) -> bool: + assert is_wait_tensor(wait_node) + info = self.collective_info[self.wait_to_start[wait_node]] + return not info.is_exposed and overlap_node not in info.hiding_nodes + + def _schedule_path_to_collective( + self, path: OrderedSet[fx.Node], curr_overlap_node: fx.Node + ) -> None: + """Schedule all nodes needed to reach a collective.""" + + assert all(n not in self.scheduled for n in path) + for node in sorted(path, key=lambda n: self.node_idx[n]): + assert not (is_compute_node(node) or node in self.unscheduled_collectives) + if _schedulable_wait_node(node): + # When we schedule wait tensors, we also force realization of all + # collectives enqueued prior to their corresponding collective. + # It's possible the scheduling of one wait tensor here has forced + # another in the path. If so, skip scheduling it. + if node in self.scheduled: + continue + + info = self.collective_info[self.wait_to_start[node]] + assert curr_overlap_node not in info.hiding_nodes + self._handle_wait(node) + continue + + self._schedule(node) + + def reorder_graph(self) -> None: + output_node = self.graph.output_node() + for node in self.scheduled: + if node.op == "placeholder": + continue + output_node.prepend(node) + self.graph.lint() + + def _reorder_graph(self) -> None: + """Reorder graph based on schedule.""" + exposed = [ + c + for c in self.collective_info.values() + if c.exposed_time_ms == c.estimated_time_ms + ] + + potentially_hidden_collectives = self.compute_potential_hidden_collectives() + bad_exposed = [ + c for c in exposed if c.start_node in potentially_hidden_collectives + ] + + # Compute total exposed and potential exposed time + total_exposed = sum(c.exposed_time_ms for c in self.collective_info.values()) + hideable_exposed_ms = sum( + self.collective_info[c].exposed_time_ms + for c in potentially_hidden_collectives + ) + total_potential_exposed = sum( + c.estimated_time_ms for c in self.collective_info.values() + ) + + counters["inductor"]["overlap_scheduling_exposed"] += len(exposed) + counters["inductor"]["overlap_scheduling_bad_exposed"] += len(bad_exposed) + counters["inductor"]["overlap_scheduling_potentially_hidden"] += len( + potentially_hidden_collectives + ) + counters["inductor"]["overlap_original_mem"] = self.original_peak_memory + counters["inductor"]["rescheduled_mem"] = self.memory_tracker.peak_memory + + log.info( + "Overlap scheduling results: exposed=%d, bad_exposed=%d, potentially_hidden=%d, " + "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes, " + "total_exposed_ms=%.2f, hideable_exposed_ms=%.2f, total_potential_exposed_ms=%.2f, " + "wasted_compute_ms=%.2f", + len(exposed), + len(bad_exposed), + len(potentially_hidden_collectives), + self.original_peak_memory, + self.memory_tracker.peak_memory, + total_exposed, + hideable_exposed_ms, + total_potential_exposed, + self.wasted_compute, + ) + + self.reorder_graph() + + def _bucket_collectives(self) -> None: + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + graph=self.graph, + collective_info=self.collective_info, + scheduled=self.scheduled, + max_bucket_memory_gb=2.0, # Could make this configurable + max_coll_distance=self.max_node_distance, + insert_overlap_deps=self.insert_overlap_deps, + ) + bucketer.bucket_collectives() + + def compute_potential_hidden_nodes( + self, nodes_to_check: Iterable[fx.Node] + ) -> dict[fx.Node, fx.Node]: + """ + Returns a dict containing a mapping of nodes which could potentially be hidden to their hiding node + """ + + def could_be_hidden(start: fx.Node) -> fx.Node | None: + for compute_node in self.compute_nodes: + if ( + start not in self.node_ancestors[compute_node] + and compute_node not in self.node_ancestors[start] + ): + return compute_node + + return None + + # TODO: We could potentially limit compute nodes per overlap time, + # today, this is optimistic, and just serves to avoid deferring + # collectives/waits that have no possible overlap as well as for analysis of how + # successfully we hid compute + potentially_hidden = {} + for node in nodes_to_check: + if mm := could_be_hidden(node): + potentially_hidden[node] = mm + + return potentially_hidden + + def compute_potential_hidden_collectives(self) -> dict[fx.Node, fx.Node]: + """Compute which collective operations could be hidden by compute.""" + return self.compute_potential_hidden_nodes(self.collective_info.keys()) + + def compute_potential_hidden_waits(self) -> dict[fx.Node, fx.Node]: + """Compute which wait operations could be hidden by compte.""" + wait_nodes = [info.wait_node for info in self.collective_info.values()] + return self.compute_potential_hidden_nodes(wait_nodes) + + +def schedule_overlap_bucketing( + gm: torch.fx.GraphModule, + max_in_flight_gb: float = 5, + max_compute_pre_fetch: int = 200, + collective_bucketing: bool = False, + insert_overlap_deps: bool = False, + compute_overlap_multipler: float = 1.0, + max_coll_distance: int = 200, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + collective_estimator: Literal["analytical", "benchmark"] = "analytical", + max_memory_increase_gb: float | None = 1.0, + max_memory_increase_ratio: float | None = 0.05, +) -> torch.fx.GraphModule: + """Schedule nodes to maximize compute-collective overlap. + + Args: + gm: Input graph module to optimize. + max_in_flight_gb: Maximum GB of concurrent collective data. Too much in flight memory + can cause memory fragmentation within the CUDA Caching Allocator. + max_compute_pre_fetch: Maximum mm nodes to pre fetch. Note: should already be limited by max_in_flight_gb and + max_memory_increase_gb + collective_bucketing: Enable overlap-preserving collective bucketing. + insert_overlap_deps: Insert overlap dependencies using control deps operator. This should only be used if + compiling with inductor, or for subsequent passes before removing the ops prior to execution. + compute_overlap_multipler: Scale factor for compute time used to hide collectives. This can be used + to address over or under aggressive overlapping. + max_coll_distance: Maximum pre fetch or bucketing candidates. Mainly intended for compile time + custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node. + If None, uses default estimations. This is currently limited to collectives and compute nodes. + collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas, + "benchmark" uses CUDA events with power-of-2 rounding and interpolation. + max_memory_increase_gb: Maximum GB increase above baseline memory (absolute cap). If None, no absolute limit. + max_memory_increase_ratio: Maximum increase as ratio of baseline peak memory. If None, no ratio limit. + Uses minimum of absolute and ratio limits when both are specified. + """ + return OverlapScheduler( + gm, + compute_overlap_multipler=compute_overlap_multipler, + max_in_flight_gb=max_in_flight_gb, + max_coll_distance=max_coll_distance, + max_compute_pre_fetch=max_compute_pre_fetch, + custom_runtime_estimation=custom_runtime_estimation, + collective_bucketing=collective_bucketing, + insert_overlap_deps=insert_overlap_deps, + collective_estimator=collective_estimator, + max_memory_increase_gb=max_memory_increase_gb, + max_memory_increase_ratio=max_memory_increase_ratio, + ).run() + + +def schedule_overlap_bucketing_from_inductor_configs( + gm: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """Schedule nodes to maximize compute-collective overlap using inductor configs. + + Reads configuration from torch._inductor.config.aten_distributed_optimizations + and calls schedule_overlap_bucketing with those settings. + """ + from torch._inductor import config + + dist_opts = config.aten_distributed_optimizations + + kwargs: dict[str, object] = {} + + config_keys = ( + "collective_bucketing", + "max_compute_pre_fetch", + "custom_runtime_estimation", + "insert_overlap_deps", + "collective_estimator", + "max_memory_increase_gb", + "max_memory_increase_ratio", + "compute_overlap_multipler", + "max_in_flight_gb", + "max_coll_distance", + ) + for key in config_keys: + if (val := getattr(dist_opts, key, None)) is not None: + kwargs[key] = val + + return schedule_overlap_bucketing(gm, **kwargs) # type: ignore[arg-type] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..556b32562dcd5533e526aa02c273ac7aca87b2e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py @@ -0,0 +1,945 @@ +import functools +import itertools +import operator +import typing +from collections.abc import Callable, Sequence +from typing import Any + +import torch +import torch._inductor.runtime.runtime_utils +from torch import Tensor +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor import utils +from torch._inductor.autoheuristic.autoheuristic import ( + AHContext, + AutoHeuristic, + LocalFeedback, +) +from torch._inductor.autoheuristic.autoheuristic_utils import ( + context_add_strides, + context_add_using_tf32, + pad_mm_operations, + pad_mm_precondition, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._mode_utils import no_dispatch + +from ...utils._triton import has_triton +from ..pattern_matcher import ( + fwd_only, + gen_register_replacement, + joint_fwd_bwd, + Match, + ReplaceFn, + SearchFn, +) + + +aten = torch.ops.aten + + +# This flag is only used for testing purpose. +# Changing it to True will ignore comparing do_bench times +# between original pattern and padded one. +_skip_do_bench_times = False + + +def fetch_fake_tensors(match: Match, kwarg_names: Sequence[str]) -> list[Tensor]: + kwargs = match.kwargs + return [kwargs[name].meta["val"] for name in kwarg_names] + + +def unwrap_fake_args( + *arg_names: str, +) -> Callable[[Callable[..., Any]], Callable[[Match], Any]]: + def decorator(func: Callable[..., Any]) -> Callable[[Match], Any]: + def wrapper(match: Match) -> Any: + fake_tensors = fetch_fake_tensors(match, arg_names) + return func(*fake_tensors) + + return wrapper + + return decorator + + +def get_alignment_size(x: Tensor) -> int: + return get_alignment_size_dtype(x.dtype) + + +def get_alignment_size_dtype(dtype: torch.dtype) -> int: + if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16: + return 8 + elif dtype == torch.float32 or dtype == torch.float: + return 4 + else: + return 0 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return (a.is_cuda and b.is_cuda) or (a.is_xpu and b.is_xpu) + + +def check_dtype(a: Tensor, b: Tensor) -> bool: + return a.is_floating_point() and b.is_floating_point() + + +def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -> bool: + # It's fine we have symbolic shapes or strides as long as they + # have hints. Later, we will make sure we only pad non-symbolic dimensions. + def valid_shape_and_stride(t: Tensor | None) -> bool: + if t is None: + return True + + symbolic_cnt = 0 + for x in t.size(): + if isinstance(x, int): + continue + elif utils.is_symbolic(x): + # pyrefly: ignore [missing-attribute] + if not x.node.has_hint(): + return False + symbolic_cnt += 1 + else: + return False + # filter out cases where all dimensions are symbolic + if symbolic_cnt == len(t.size()): + return False + return all( + # pyrefly: ignore [missing-attribute] + isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) + for x in t.stride() + ) + + return ( + torch._inductor.config.shape_padding + and check_device(mat1, mat2) + and check_dtype(mat1, mat2) + and all(valid_shape_and_stride(t) for t in (mat1, mat2, input)) + ) + + +def get_padded_length(x: int | torch.SymInt, alignment_size: int) -> int: + # we don't pad x if it is symbolic + if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0: + return 0 + + # ignore dim that can be squeezed away + if x == 1: + return 0 + + return int((x // alignment_size + 1) * alignment_size) - x + + +def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: + if padded_length == 0: + return x + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def addmm_pattern( + input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float +) -> Tensor: + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def should_pad_addmm(match: Match) -> bool: + mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input")) + return should_pad_common(mat1, mat2, input) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.addmm, input=input + ) + + +def pad_addmm( + input: Tensor | None, + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + beta: float = 1.0, + alpha: float = 1.0, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + # for paddings, dim order is reversed for some reasons + # and for every dim, we need to specify left and right padding + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + + # the add broadcasts, so we only pad if the dimension != 1 + if input is not None: + if n_padded_length != 0: + if input.dim() == 2 and input.shape[1] != 1: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1 and input.shape[0] != 1: + input = pad_dim(input, n_padded_length, 0) + if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1: + input = pad_dim(input, m_padded_length, 0) + + res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def addmm_replace( + input: Tensor | None, + mat1: Tensor, + mat2: Tensor, + beta: float = 1.0, + alpha: float = 1.0, +) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + return pad_addmm( + input, + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + beta, + alpha, + ) + + +def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: + denominator = M * K + N * K + M * N + if denominator == 0: + return False + arithmetic_intensity = (M * N * K) / denominator + + # we have experienced some large perf hits in this case, even in bandwidth bound regimes + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and (torch.xpu.is_available() or torch.cuda.get_device_capability() < (9, 0)) + ): # doesn't repro on h100s: + return True + + # Fails with AMD + try: + machine_balance = ( + 1000 * utils.get_device_tflops(dtype) + ) / utils.get_gpu_dram_gbps() + except Exception: + return True + + # dram_gbps might be underestimating bandwidth because of cache. + # if we estimate machine balance too low we might miss some speedups, + # if we estimate too high there will be unnecessary compilation time increase. + # TODO - finetune coefficient here. As a reference point, Triton mm model assumes + # 80% of reads are in cache and cache is 4x faster than dram_gbps + machine_balance = machine_balance * 0.5 + + return arithmetic_intensity > machine_balance + + +@functools.cache +def get_pad_cache() -> torch._inductor.codecache.LocalCache: + return torch._inductor.codecache.LocalCache() + + +def get_cached_should_pad(key: str) -> bool: + return get_pad_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_should_pad(key: str, value: bool) -> None: + return get_pad_cache().set_value(key, value=value) + + +def get_cached_base_mm_benchmark_time(key: str) -> float: + return get_pad_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_base_mm_benchmark_time(key: str, value: float) -> None: + return get_pad_cache().set_value(key, value=value) + + +def should_pad_bench_key( + match: Match, + mat1: Tensor, + mat2: Tensor, + op: torch._ops.OpOverloadPacket, + input: Tensor | None = None, + is_base_time_key: bool = False, +) -> str: + def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]: + return (t.shape, t.stride(), t.dtype) + + tf32_key = ( + None + if mat1.dtype != torch.float32 + else torch.backends.cuda.matmul.allow_tf32 or torch.backends.mkldnn.allow_tf32 + ) + + def fmt_pad(name: str) -> str | None: + if is_base_time_key: + return None + return f"exclude_pad:{should_exclude_padding_time(match, name)}" + + key = ( + tensor_key(mat1), + tensor_key(mat2), + fmt_pad("mat1"), + fmt_pad("mat2"), + op, + input if input is None else tensor_key(input), + tf32_key, + ) + + key = str(key) + if is_base_time_key: + key = f"base mm time: {key}" + return key + + +def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node: + if node.op is operator.getitem: + return get_non_view_def(node.args[0]) # type: ignore[arg-type] + + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and utils.is_view(node.target) + ): + return get_non_view_def(node.all_input_nodes[0]) + + return node + + +def should_exclude_padding_time(match: Match, arg_name: str) -> bool: + node_def = get_non_view_def(match.kwargs[arg_name]) + + # constant padding converts tensors to contiguous so even if the input tensor + # can be planned layout transform is not free. TODO - way to pad and preserve layout ? + if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous(): + return False + + # TODO - see issue https://github.com/pytorch/pytorch/issues/128889 + # We would only able to completely plan these out if we were only doing + # first dimension padding. non-first we would still need a copy + # because these outputs are fixed dense. + cannot_plan_output = [ + aten.mm.default, + aten.convolution.default, + aten.convolution_backward.default, + aten.bmm.default, + aten.addmm.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_dot_product_efficient_attention.default, + ] + + if node_def.target in cannot_plan_output: + return False + + if ( + node_def.target is aten.cat.default + and len(node_def.all_input_nodes) + > torch._inductor.config.max_pointwise_cat_inputs + ): + return False + + # optimistically assume we should be able to memory plan away + # all non inputs + return node_def.op != "placeholder" + + +def should_pad(key: str, ori_time: float, pad_time: float) -> bool: + multiplier = 1.1 + # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable + # tradeoff between performance improvement from shape padding and overhead from additional memory ops + # TODO: Build a learned model which would be better than this heuristic + if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options: + multiplier = torch._inductor.config.post_grad_fusion_options[ + "shape_padding_multiplier" + ].get("value", 1.1) + counters["inductor"]["shape_padding_multiplier"] += 1 + should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier + set_cached_should_pad(key, should_pad) + return should_pad + + +def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool: + # always force pad for mm with bf16 when the following are satisfied to avoid perf regression + large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ + "pad_aten_mm_pass" + ].get("k_threshold_to_pad", 8388608) + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and N % 2 == 1 + and K >= large_k_threshold_to_pad + and (torch.xpu.is_available() or torch.cuda.get_device_capability() < (9, 0)) + ): # doesn't repro on h100s: + return True + return False + + +def should_pad_bench(*args: Any, **kwargs: Any) -> bool: + with dynamo_timed( + "pad_mm_benchmark", + log_pt2_compile_event=False, + dynamo_compile_column_us="compile_time_autotune_time_us", + ): + return _should_pad_bench(*args, **kwargs) + + +def get_do_bench() -> Callable[[Callable[[], Any]], float]: + with dynamo_timed("pad_mm_benchmark_get_do_bench"): + return functools.partial( + # pyrefly: ignore [bad-argument-type] + torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, + warmup=5, + ) + + +def _should_pad_bench( + match: Match, + mat1: Tensor, + mat2: Tensor, + op: torch._ops.OpOverloadPacket, + input: Tensor | None = None, +) -> bool: + do_bench = get_do_bench() + + m_padded_length = 0 + n_padded_length = 0 + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m = mat1.shape[0] + k = mat1.shape[1] + n = mat2.shape[1] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + elif op is torch.ops.aten.bmm: + m = mat1.shape[1] + k = mat1.shape[2] + n = mat2.shape[2] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + def realize_symbols( + ds: torch.Size | tuple[torch.SymInt, ...], + ) -> list[int]: + return [d if isinstance(d, int) else d.node.hint for d in ds] + + if any( + dim == 0 + for dim in itertools.chain( + realize_symbols(mat1.shape), realize_symbols(mat2.shape) + ) + ): + return False + + if torch._inductor.config.force_shape_pad: + return True + + if torch._inductor.config.deterministic: + # In deterministic mode, don't benchmark for pad-mm and assumes + # no padding. + # + # Check the deterministic mode after 'force_shape_pad' + # so unit test relying on force_shape_pad should still pass + return False + + if ( + "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options + and should_pad_mm_bf16(mat1.dtype, m, n, k) + ): + return True + + if not has_triton(): + return False + + if not is_mm_compute_bound(m, k, n, mat1.dtype): + return False + + # We don't want to look up the cache for cases that are trivially false + # since it does file io + key = should_pad_bench_key(match, mat1, mat2, op, input) + + cached_pad = get_cached_should_pad(key) + if cached_pad is not None: + return cached_pad + + def realize_tensor(t): + if isinstance(t, FakeTensor): + size_hints = realize_symbols(t.size()) + # pyrefly: ignore [bad-argument-type] + stride_hint = realize_symbols(t.stride()) + real_size = ( + sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 + ) + real_t = torch.randn(real_size, dtype=t.dtype, device=t.device) + return torch.as_strided(real_t, size_hints, stride_hint) + else: + return torch.randn_like(t) + + mat1 = realize_tensor(mat1) + mat2 = realize_tensor(mat2) + + # since we key on whether or not the inputs can be memory planned, set cache for the + # original time which is unaffected by whether or not the input can be planned + ori_time_key = should_pad_bench_key( + match, mat1, mat2, op, input, is_base_time_key=True + ) + ori_time = get_cached_base_mm_benchmark_time(ori_time_key) + if ori_time is None and op is torch.ops.aten.addmm and input is not None: + # realize bias for addmm + input = realize_tensor(input) + + mat1_pad = mat1 + mat2_pad = mat2 + + is_bmm = op is torch.ops.aten.bmm + + mat1_pre_padded = should_exclude_padding_time(match, "mat1") + fns = [] + if mat1_pre_padded and (m_padded_length or k_padded_length): + mat1_pad = pad_mat1( + mat1_pad, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0) + else: + mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0) + + fns.append(write_pad) + + mat2_pre_padded = should_exclude_padding_time(match, "mat2") + if mat2_pre_padded and (k_padded_length or n_padded_length): + mat2_pad = pad_mat2( + mat2_pad, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0) + else: + mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0) + + fns.append(write_pad) + + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and (input.is_cuda or input.is_xpu): + input_pad = torch.randn_like(input) + fns.append( + lambda: pad_addmm( + input_pad, + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + elif op is torch.ops.aten.mm: + fns.append( + lambda: pad_mm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + else: + fns.append( + lambda: pad_bmm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + + def orig_bench_fn(): + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + op(mat1, mat2) + else: + op(input, mat1, mat2) + + def pad_bench_fn(): + for fn in fns: + fn() + + if ( + torch._inductor.config.run_autoheuristic("pad_mm") + and op is torch.ops.aten.mm + ): + ah_should_pad = run_autoheuristic( + mat1, + mat2, + orig_bench_fn, + pad_bench_fn, + m_padded_length, + k_padded_length, + n_padded_length, + do_bench, + mat1_pre_padded, + mat2_pre_padded, + ori_time, + ori_time_key, + key, + ) + if ah_should_pad is not None: + return ah_should_pad + + if ori_time is None: + ori_time = do_bench(orig_bench_fn) + set_cached_base_mm_benchmark_time(ori_time_key, ori_time) + + pad_time = do_bench(pad_bench_fn) + + counters["inductor"]["pad_mm_bench"] += 1 + return should_pad(key, ori_time, pad_time) + + +def get_context( + mat1: Tensor, + mat2: Tensor, + mat1_pre_padded: bool, + mat2_pre_padded: bool, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> AHContext: + context = AHContext() + + context.add_feature("m", mat1.shape[0]) + context.add_feature("k", mat1.shape[1]) + context.add_feature("n", mat2.shape[1]) + + context_add_strides(context, "mat1", mat1.stride()) + context_add_strides(context, "mat2", mat2.stride()) + + context.add_feature("m_padded_length", m_padded_length) + context.add_feature("k_padded_length", k_padded_length) + context.add_feature("n_padded_length", n_padded_length) + + context.add_feature("mat1_align_size", get_alignment_size(mat1)) + context.add_feature("mat2_align_size", get_alignment_size(mat2)) + + context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True) + context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True) + + context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True) + context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True) + + context_add_using_tf32(context, mat1.dtype) + return context + + +def run_autoheuristic( + mat1: Tensor, + mat2: Tensor, + orig_bench_fn: Callable[[], None], + pad_bench_fn: Callable[[], None], + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + do_bench: Callable[[Callable[[], Any]], float], + mat1_pre_padded: bool, + mat2_pre_padded: bool, + ori_time: float, + ori_time_key: str, + key: str, +) -> bool | None: + def feedback_fn( + choice: str, + ) -> float | None: + if choice == orig_choice: + return do_bench(orig_bench_fn) + elif choice == pad_choice: + return do_bench(pad_bench_fn) + return None + + def fallback() -> str: + return "autotune" + + orig_choice = "orig" + pad_choice = "pad" + choices = [orig_choice, pad_choice] + feedback = LocalFeedback(feedback_fn) # type: ignore[arg-type] + context = get_context( + mat1, + mat2, + mat1_pre_padded, + mat2_pre_padded, + m_padded_length, + k_padded_length, + n_padded_length, + ) + name = "pad_mm" + autoheuristic = AutoHeuristic( + fallback=fallback, + choices=choices, + feedback=feedback, + context=context, + name=name, + augment_context=pad_mm_operations(), + precondition=pad_mm_precondition, + ) + choice = autoheuristic.get_choice() + choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None} + ah_should_pad = choice2should_pad.get(choice) + + if torch._inductor.config.collect_autoheuristic(name): + ah_ori_time = autoheuristic.get_collected_feedback(orig_choice) + ah_pad_time = autoheuristic.get_collected_feedback(pad_choice) + + # if precondition is not satisfied, autoheuristic does not collect data + if ah_ori_time is not None and ah_pad_time is not None: + if ori_time is None: + set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time) + return should_pad(key, ah_ori_time, ah_pad_time) + if ah_should_pad is not None: + set_cached_should_pad(key, ah_should_pad) + return ah_should_pad + + +def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.mm(mat1, mat2) + + +def should_pad_mm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.mm + ) + + +def pad_mat1( + mat1: Tensor, *, m_padded_length: int, k_padded_length: int, is_bmm: bool = False +) -> Tensor: + if k_padded_length != 0 or m_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, k_padded_length, 0, m_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat1, pad_arg) + else: + return mat1 + + +def pad_mat2( + mat2: Tensor, *, k_padded_length: int, n_padded_length: int, is_bmm: bool = False +) -> Tensor: + if k_padded_length != 0 or n_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, n_padded_length, 0, k_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat2, pad_arg) + else: + return mat2 + + +def pad_mm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + res = aten.mm(mat1, mat2) + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + return pad_mm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.bmm(mat1, mat2) + + +def should_pad_bmm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.bmm + ) + + +def pad_bmm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=True, + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=True, + ) + res = aten.bmm(mat1, mat2) + if m_padded_length != 0: + res = res[:, :-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :, :-n_padded_length] + return res + + +def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + return pad_bmm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +@functools.cache +def _pad_mm_init() -> None: + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + elif torch.xpu.is_available(): + device = "xpu" + else: + device = "cpu" + + # sizes/values dont actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + + dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + + dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + + dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True) + + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + rep = {"beta": 0.213377, "alpha": 0.113377} + + for pattern, replacement, args, workaround, extra_check in [ + ( + typing.cast(SearchFn, mm_pattern), + typing.cast(ReplaceFn, mm_replace), + [dim2a(), dim2b()], + {}, + should_pad_mm, + ), + ( + typing.cast(SearchFn, bmm_pattern), + typing.cast(ReplaceFn, bmm_replace), + [dim3a(), dim3b()], + {}, + should_pad_bmm, + ), + ( + typing.cast(SearchFn, addmm_pattern), + typing.cast(ReplaceFn, addmm_replace), + [dim1a(), dim2a(), dim2b()], + rep, + should_pad_addmm, + ), + ]: + assert isinstance(workaround, dict) # mypy is unable to infer the type properly + name = pattern.__name__ + + gen_register_replacement( + f"{name}_training", + pattern, + replacement, + args, + # pyrefly: ignore [bad-argument-type] + joint_fwd_bwd, + # pyrefly: ignore [bad-argument-type] + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) + + gen_register_replacement( + f"{name}_inference", + pattern, + replacement, + args, + # pyrefly: ignore [bad-argument-type] + fwd_only, + # pyrefly: ignore [bad-argument-type] + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..4a350b81bbecb47b044c3805bd8af82b04531d45 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py @@ -0,0 +1,1923 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +from collections import Counter, defaultdict +from collections.abc import Callable +from typing import Any, TypeVar +from typing_extensions import ParamSpec + +import torch +import torch._inductor as inductor +import torch.utils._pytree as pytree +from torch import fx +from torch._decomp import register_decomposition +from torch._dynamo.utils import counters +from torch._inductor import comms +from torch._inductor.virtualized import ops # noqa: F401 +from torch._logging import trace_structured +from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.utils._ordered_set import OrderedSet + +from .. import config, ir, pattern_matcher # noqa: F401 +from ..codegen.common import custom_backend_passes +from ..comms import remove_fsdp2_unsharded_param_graph_input_usage +from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage +from ..lowering import lowerings as L +from ..pattern_matcher import ( + _return_true, + Arg, + CallFunction, + CallFunctionVarArgs, + filter_nodes, + fwd_only, + get_arg_value, + get_mutation_region_id, + Ignored, + init_once_fakemode, + KeywordArg, + ListOf, + Match, + MultiOutputPattern, + MULTIPLE, + PatternMatcherPass as PatternMatcherPassBase, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) +from ..utils import ( + decode_device, + get_all_devices, + get_gpu_type, + is_gpu, + is_pointwise_use, + OPTIMUS_EXCLUDE_POST_GRAD, +) +from ..virtualized import V +from .b2b_gemm import B2B_GEMM_PASS +from .ddp_fusion import fuse_ddp_communication +from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS +from .micro_pipeline_tp import micro_pipeline_tp_pass +from .pre_grad import is_same_dict, save_inductor_dict +from .reinplace import reinplace_inplaceable_ops +from .split_cat import POST_GRAD_PATTERNS + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +PatternMatcherPass = functools.partial( + PatternMatcherPassBase, subsystem="post_grad_passes" +) + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + + +def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): + """ + Passes that run on after grad. This is called once on the forwards + graph and once on the backwards graph. + + The IR here has been normalized and functionalized. + """ + GraphTransformObserver = functools.partial( + torch.fx.passes.graph_transform_observer.GraphTransformObserver, + subsystem="post_grad_passes", + ) + + if not torch._dynamo.config.skip_fsdp_hooks: + remove_fsdp2_unsharded_param_graph_input_usage(gm.graph) + + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if is_inference and config.reorder_for_locality: + GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass( + reorder_for_locality + ) + + fake_tensor_updater = FakeTensorUpdater(gm.graph) + + if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass: + GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass( + post_grad_custom_pre_pass + ) + + if torch._C._has_mkldnn: + if ( + config.cpp.enable_grouped_gemm_template + and config.max_autotune + and "CPP" in config.max_autotune_gemm_backends + ): + from .mkldnn_fusion import grouped_gemm_pass + + grouped_gemm_pass(gm.graph) + + if config.cpp.enable_concat_linear: + from .quantization import concat_linear_woq_int4 + + # Concat linear optimization for WOQ int4 + concat_linear_woq_int4(gm) + + if config.pattern_matcher: + lazy_init() + GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass( + functools.partial(group_batch_fusion_passes, pre_grad=False) + ) + GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops) + GraphTransformObserver(gm, "remove_assert_ops").apply_graph_pass( + remove_assert_ops + ) + for i, patterns in enumerate(pass_patterns): + GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass( + patterns.apply + ) + for pass_name in config.post_grad_fusion_options: + # skip all patterns for group batch fusions or quantization patterns + if pass_name in POST_GRAD_FUSIONS or pass_name in OPTIMUS_EXCLUDE_POST_GRAD: + continue + pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + GraphTransformObserver(gm, pass_name).apply_graph_pass( + pattern_matcher_pass.apply + ) + if not is_same_dict(counters["inductor"], inductor_before_change): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"{pattern_matcher_pass.pass_name}_post_grad", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + if config.b2b_gemm_pass: + B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type] + + if config._micro_pipeline_tp: + micro_pipeline_tp_pass(gm.graph) + + if config._fuse_ddp_communication: + GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass( + lambda graph: fuse_ddp_communication( + graph, + config._fuse_ddp_communication_passes, + config._fuse_ddp_bucket_size, + ) + ) + + if post_grad_custom_post_pass := config.post_grad_custom_post_pass: + GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass( + post_grad_custom_post_pass + ) + + GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort) + + GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass( + move_constructors_to_gpu + ) + + fake_tensor_updater.incremental_update() + + for device, custom_backend_pass in custom_backend_passes.items(): + if custom_backend_pass is not None: + gm_devices = [d.type for d in get_all_devices(gm)] + if device in gm_devices: + pass_name = "custom_backend_passes_" + device + GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) + + collectives_bucketing: bool = False + + if config.bucket_reduce_scatters_fx != "none": + from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter + from torch._inductor.fx_passes.fsdp import bucket_fsdp_reduce_scatter + + p = ( + bucket_fsdp_reduce_scatter + if "fsdp" in config.bucket_reduce_scatters_fx + else bucket_reduce_scatter + ) + GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass( + lambda graph: p( + graph.owning_module, + config.bucket_reduce_scatters_fx_bucket_size_determinator, + config.bucket_reduce_scatters_fx, # type: ignore[arg-type] + ) + ) + collectives_bucketing = True + + if config.bucket_all_reduces_fx != "none": + from torch._inductor.fx_passes.bucketing import bucket_all_reduce + + GraphTransformObserver(gm, "bucket_all_reduce").apply_graph_pass( + lambda graph: bucket_all_reduce( + graph.owning_module, + config.bucket_all_reduces_fx_bucket_size_determinator, + config.bucket_all_reduces_fx, # type: ignore[arg-type] + ) + ) + collectives_bucketing = True + + # Fx all_gather bucketing introduces mutation op + # Keeping it in the end to keep invariant of functional graph for previous passes. + if config.bucket_all_gathers_fx != "none": + from torch._inductor.fx_passes.bucketing import bucket_all_gather + from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather + + p = ( + bucket_fsdp_all_gather # type: ignore[assignment] + if "fsdp" in config.bucket_all_gathers_fx + else bucket_all_gather + ) + GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass( + lambda graph: p( + graph.owning_module, + config.bucket_all_gathers_fx_bucket_size_determinator, + config.bucket_all_gathers_fx, # type: ignore[arg-type] + ) + ) + collectives_bucketing = True + + if collectives_bucketing: + # Fx collectives bucketing passes require topological sort for the cases: + # when bucketed collectives have users before the last collective in the bucket + # AND when inputs of bucketed collective have ancestors after the first collective in the bucket. + # + # In this case we can not manually pick the place for bucketed collective insertion. + # But we are guaranteed by the bucketing (independent collectives in the bucket), + # that it is possible to reorder nodes to satisfy all ordering requirements. + # + # --- before bucketing --- + # in0 = ... + # wait_ag0 = ag(in0) + # user0(wait_ag0) + # ... + # pre_in1 = ... + # in1 = transform(pre_in1) + # wait_ag1 = ag(in1) + # user1(wait_ag1) + # + # --- after bucketing --- + # + # in0 = ... + # user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective. + # + # pre_in1 = ... + # in1 = transform(pre_in1) + # ag_bucket(in0+in1) + # wait_bucket + # wait_ag0 = wait_bucket[0] + # wait_ag1 = wait_bucket[1] + # user1(wait_ag1) + stable_topological_sort(gm.graph) + + # Apply overlap scheduling if enabled + if config.aten_distributed_optimizations.enable_overlap_scheduling: + from torch._inductor.fx_passes.overlap_scheduling import ( + schedule_overlap_bucketing_from_inductor_configs, + ) + + overlap_deps = config.aten_distributed_optimizations.insert_overlap_deps + + # by default, insert overlap deps within inductor + with config.patch( + "aten_distributed_optimizations.insert_overlap_deps", + True if overlap_deps is None else overlap_deps, + ): + GraphTransformObserver(gm, "overlap_scheduling").apply_graph_pass( + lambda graph: schedule_overlap_bucketing_from_inductor_configs( + graph.owning_module + ) + ) + + # Keep these last, since they introduce mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( + functools.partial(reinplace_inplaceable_ops, fake_tensor_updater), + ) + GraphTransformObserver( + gm, "decompose_triton_kernel_wrapper_functional" + ).apply_graph_pass(decompose_triton_kernel_wrapper_functional) + GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass( + decompose_auto_functionalized + ) + if not torch._dynamo.config.skip_fsdp_hooks: + GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( + comms.reinplace_fsdp_all_gather + ) + GraphTransformObserver(gm, "decompose_scan_to_while_loop").apply_gm_pass( + decompose_scan_to_while_loop + ) + GraphTransformObserver(gm, "decompose_map_to_while_loop").apply_gm_pass( + decompose_map_to_while_loop + ) + + gm.recompile() + gm.graph.lint() + + +def prepare_softmax_pattern(x, dim): + xmax = x.amax(dim=dim, keepdim=True) + xsub = x - xmax + xexp = xsub.exp() + xsum = xexp.sum(dim=dim, keepdim=True) + return xmax, xsum, xsub, xexp + + +def prepare_softmax_replacement(x, dim): + """ + Return xsub since otherwise log-softmax can not be matched + due to a use of this intermediate node. Same reason to return + xsub.exp() for softmax. + """ + from torch._inductor.inductor_prims import prepare_softmax_online + + xmax, xsum = prepare_softmax_online(x, dim) + xsub = x - xmax + return xmax, xsum, xsub, xsub.exp() + + +def prepare_softmax_extra_check(match): + """ + We only have triton online softmax kernels currently. + """ + device_type = match.kwargs["x"].meta["val"].device.type + return ( + config.online_softmax + and device_type in ["cuda", "xpu"] + and getattr(config, f"{device_type}_backend") == "triton" + ) + + +def decompose_map_to_while_loop(gm: torch.fx.GraphModule): + """This is similar to decompose_scan_to_while_loop.""" + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.map_impl), + # pyrefly: ignore [bad-argument-type] + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + assert len(kwargs) == 0, ( + "kwargs of map are not merged into args before entering decompose_map_to_while_loop_pass" + ) + subgraph, fx_xs, fx_additional_inputs = args + sub_gm: torch.fx.GraphModule = getattr(gm, subgraph.target) + cur_node = match.nodes[0] + mapped_outputs = cur_node.meta["val"] + + def lower_to_while_loop(*args, **kwargs): + assert len(kwargs) == 0 + xs, additional_inputs = pytree.tree_unflatten(args, tree_spec) + assert isinstance(xs, (tuple, list)) and isinstance( + additional_inputs, (tuple, list) + ), (xs, additional_inputs) + map_length = xs[0].size(0) + loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu")) + + # Similar to NOTE [Pre-allocate scan's output buffer] + bound_symbols = { + arg.node.expr: arg + for arg in pytree.tree_leaves((args, map_length)) + if isinstance(arg, torch.SymInt) + } + out_buffers = [ + torch.empty_strided( + resolve_shape_to_proxy(out.size(), bound_symbols), + resolve_shape_to_proxy(out.stride(), bound_symbols), + device=out.device, + dtype=out.dtype, + layout=out.layout, + requires_grad=out.requires_grad, + ) + for out in mapped_outputs + ] + + while_loop_operands = (loop_idx, out_buffers, xs) + while_loop_flat_operands, operands_spec = pytree.tree_flatten( + while_loop_operands + ) + while_loop_additional_inputs = additional_inputs + _, operands_and_additional_inputs_spec = pytree.tree_flatten( + (*while_loop_operands, additional_inputs) + ) + + def cond_fn(*flat_args): + loop_idx, _, _, _ = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, + ) + return loop_idx < map_length + + def body_fn(*flat_args): + loop_idx, out_bufs, xs, additional_inputs = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, + ) + + idx_int = loop_idx.item() + torch.ops.aten._assert_scalar.default(idx_int >= 0, "") + torch.ops.aten._assert_scalar.default(idx_int < map_length, "") + sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs] + outs = sub_gm(*sub_xs, *additional_inputs) + + for out, buffer in zip(outs, out_bufs): + buffer_slice = torch.ops.aten.select.int(buffer, 0, idx_int) + buffer_slice.copy_(out) + return loop_idx + 1, *out_bufs, *xs + + _, final_out, _ = pytree.tree_unflatten( + torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + tuple(while_loop_flat_operands), + tuple(while_loop_additional_inputs), + ), + operands_spec, + ) + return (final_out,) + + lower_to_while_loop_args, tree_spec = pytree.tree_flatten( + (fx_xs, fx_additional_inputs) + ) + match.replace_by_example( + lower_to_while_loop, lower_to_while_loop_args, run_functional_passes=False + ) + + graph_pass.apply(gm) + + for _node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.map_impl + ): + raise AssertionError("map is not lowered to while_loop") + + +def resolve_shape_to_proxy( + shape: list[int | torch.SymInt], bound_symbols: dict[Any, Any] +): + """ + Given a list of symints/ints, this function returns a calculated expression of bound_symbols' values. + When we trace this function, we'll get a graph with call_function nodes that describes how the shape expr is + computed from bound_symbols' values. + + Suppose shape = (s1*s2, s1+s2) and bound_symbols = {s1: arg0, s2: arg1}, the result will be + (arg0 * arg1, arg0 + arg1). + """ + from torch.utils._sympy.interp import sympy_interp + from torch.utils._sympy.reference import PythonReferenceAnalysis + + ret = [] + for s in shape: + if isinstance(s, torch.SymInt): + ret.append( + sympy_interp( + PythonReferenceAnalysis, + bound_symbols, + s.node.expr, + ), + ) + else: + assert isinstance(s, int) + ret.append(s) + return ret + + +def decompose_scan_to_while_loop(gm: torch.fx.GraphModule): + """ + NOTE [decompose scan to while_loop] + This pass decomposes `scan` to `while_loop` by replacing the scan fx_node with a while_loop hop. + + Suppose we have a function f: + + def f(): + init = torch.zeros([]) + xs = torch.arange(4) + ys = [] + for i in range(xs.size(0)): + init = xs[i] + init + ys.append(init) + + # Return the final carry and stack the intermediates + return init, torch.stack(ys) + + We could rewrite it with a scan with the benefits of reducing compilation time/binary size, reducing + memory usage, supporting loops over unbacked shapes and cudagraph etc. + + def g(): + def step_fn(init: torch.Tensor, x: torch.Tensor): + next_init = x + init + return next_init, next_init + + init = torch.zeros([]) + xs = torch.arange(4) + final_carry, ys = torch._higher_order.scan(step_fn, init, xs) + return final_carry, ys + + This pass will rewrite scan into: + + def k(): + init = torch.zeros([]) + xs = torch.arange(4) + + # we create a loop_idx and loop through xs.shape[0] + loop_idx = torch.zeros([]) + ys = torch.empty_strided(_shape_stride_of_ys) + def cond_fn(loop_idx, ys, init, xs): + return loop_idx < xs.shape[0] + + # we pre-allocate the output buffer ys and inplace + # copy the y of each intermediate into a slice. + # NOTE [Pre-allocate scan's output buffer]. + def body_fn(loop_idx, ys, init, xs): + int_idx = loop_idx.item() + next_init, y = step_fn(init, xs[int_idx]) + ys[int_idx].copy_(y) + return loop_idx + 1, ys, next_init, xs + + final_carry, _, _, ys = torch._higher_order.while_loop(cond_fn, body_fn, (loop_idx, ys, init, xs)) + return final_carry, ys + """ + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.scan), + # pyrefly: ignore [bad-argument-type] + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.scan import _extract_carry_and_out + + assert len(kwargs) == 0, ( + "kwargs of scan are not merged into args before entering decompose_scan_to_while_loop_pass" + ) + + combine_subgraph, fx_init, fx_xs, fx_additional_inputs = args + assert combine_subgraph.op == "get_attr", "first arg is not combine_subgraph" + sub_gm: torch.fx.GraphModule = getattr(gm, combine_subgraph.target) + cur_node = match.nodes[0] + num_init_leaves = len(fx_init) + _, ys_outputs = _extract_carry_and_out(cur_node.meta["val"], num_init_leaves) + + def lower_to_while_loop(*args, **kwargs): + """ + The traced graph of this function will be used to replace the original scan fx_node. + """ + assert len(kwargs) == 0 + + # Step 1: construct necessary inputs to while_loop based on scan's input. + ( + init, + xs, + additional_inputs, + ) = pytree.tree_unflatten(args, tree_spec) + scan_length = xs[0].size(0) + loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu")) + + # NOTE [Pre-allocate scan's output buffer] + # In order to pre-allocate the output buffer for ys, we rely on the meta of scan's fx_node. + # However, the meta consists of concrete symints, we need to bind those symints with + # proxies in order to trace the torch.empty_strided call correctly. + # + # Also note that basic free symbols of tensor's shapes are guaranteed to be lifted as subgraph inputs + # in dynamo so we can always re-construct the sym expression from placeholders. + # See Note [Auto lift basic free symbols when create_graph_input] for how this is done. + bound_symbols = { + arg.node.expr: arg + for arg in pytree.tree_leaves((args, scan_length)) + if isinstance(arg, torch.SymInt) + } + ys_outs = [ + torch.empty_strided( + resolve_shape_to_proxy(ys_out.size(), bound_symbols), + resolve_shape_to_proxy(ys_out.stride(), bound_symbols), + device=ys_out.device, + dtype=ys_out.dtype, + layout=ys_out.layout, + requires_grad=ys_out.requires_grad, + ) + for ys_out in ys_outputs + ] + + while_loop_operands = (loop_idx, ys_outs, init, xs) + flat_operands, operands_spec = pytree.tree_flatten(while_loop_operands) + _, operands_and_additional_inputs_spec = pytree.tree_flatten( + (*while_loop_operands, additional_inputs) + ) + + # Step 2: create the cond_fn and body_fn for while_loop + def cond_fn(*flat_args): + loop_idx, _, _, _, _ = pytree.tree_unflatten( + flat_args, operands_and_additional_inputs_spec + ) # type: ignore[has-type] + return loop_idx < scan_length # type: ignore[has-type] + + def body_fn(*flat_args): + loop_idx, ys_outs, carry, xs, additional_inputs = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, # type: ignore[has-type] + ) + + idx_int = loop_idx.item() + torch.ops.aten._assert_scalar.default(idx_int >= 0, "") + torch.ops.aten._assert_scalar.default(idx_int < scan_length, "") + sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs] + next_carry, ys = _extract_carry_and_out( + sub_gm(*(list(carry) + sub_xs + list(additional_inputs))), + num_init_leaves, + ) + for y, y_out in zip(ys, ys_outs): + y_out_slice = torch.ops.aten.select.int(y_out, 0, idx_int) + y_out_slice.copy_(y) + return loop_idx + 1, *ys_outs, *next_carry, *xs + + # Step 3: call the while_loop operator + _, ys_outs, last_carry, _ = pytree.tree_unflatten( + torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + tuple(flat_operands), + tuple(additional_inputs), + ), + operands_spec, + ) + return list(last_carry) + list(ys_outs) + + lower_to_while_loop_args, tree_spec = pytree.tree_flatten( + ( + fx_init, + fx_xs, + fx_additional_inputs, + ) + ) + match.replace_by_example( + lower_to_while_loop, + lower_to_while_loop_args, + run_functional_passes=False, + ) + + graph_pass.apply(gm) + + for _node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.scan + ): + raise AssertionError("scan is not lowered to while_loop") + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn: + from . import decompose_mem_bound_mm # noqa: F401 + from .mkldnn_fusion import _mkldnn_fusion_init + + _mkldnn_fusion_init() + + # Put this patterns in post-grad pass rather than joint-graph + # pass since otherwise there will be perf/peak-memory regression: + # https://github.com/pytorch/pytorch/issues/148141 + register_replacement( + # pyrefly: ignore [bad-argument-type] + prepare_softmax_pattern, + # pyrefly: ignore [bad-argument-type] + prepare_softmax_replacement, + [torch.empty(4, 8)], + scalar_workaround=dict(dim=-1), + # pyrefly: ignore [bad-argument-type] + trace_fn=fwd_only, + # pyrefly: ignore [bad-argument-type] + pass_dicts=pass_patterns[1], + extra_check=prepare_softmax_extra_check, + ) + + +def reorder_for_locality(graph: torch.fx.Graph): + if torch.distributed.is_available(): + + def check(): + # This is a wait node, and `other_node`` is some collective node. + # Eager semantics allow waits to be issued in a different order than + # the collectives. Reordering this wait node might reorder collectives + # which cause hangs. Once we have SPMD mode, we can safely reorder them. + # However, increasing the locality between a collective and its wait node + # is generally worse for performance. + return node.target != torch.ops._c10d_functional.wait_tensor.default + else: + + def check(): + return True + + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + and get_mutation_region_id(graph, node) + == get_mutation_region_id(graph, other_node) + and check() + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = OrderedSet[torch.fx.Node]() + + # only reorder nodes before the first copy_ in the graph. + # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, + # and this reordering doesn't work well with mutation + first_copy = next( + iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)), + None, + ) + past_mutating_epilogue = first_copy is None + + for node in reversed(graph.nodes): + seen_nodes.add(node) + if not past_mutating_epilogue: + past_mutating_epilogue = node is first_copy + continue + + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def register_lowering_pattern( + pattern, extra_check=_return_true, pass_number=1 +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + Register an aten to inductor IR replacement pattern + """ + return pattern_matcher.register_lowering_pattern( + pattern, + extra_check, + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[pass_number], + ) + + +################################################################################ +# Actual patterns below this point. +# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +def is_valid_mm_plus_mm(match: Match): + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Check if all required values exist + mat1_val = match.kwargs["mat1"].meta.get("val") + mat2_val = match.kwargs["mat2"].meta.get("val") + mat3_val = match.kwargs["mat3"].meta.get("val") + mat4_val = match.kwargs["mat4"].meta.get("val") + + if mat1_val is None or mat2_val is None or mat3_val is None or mat4_val is None: + return False + + *_b1, m1, k1 = mat1_val.shape + *_b2, k2, n1 = mat2_val.shape + if k1 != k2: + return False + + *_b1, m2, k3 = mat3_val.shape + *_b2, k4, n2 = mat4_val.shape + if k3 != k4: + return False + + if m1 != m2 or n1 != n2: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")), + CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")), + ), + extra_check=is_valid_mm_plus_mm, +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +@register_graph_pattern( + CallFunction( + aten.cumsum.default, + CallFunction( + torch.ops.aten.full.default, + KeywordArg("shape"), + KeywordArg("fill_value"), + dtype=KeywordArg("dtype"), + layout=Ignored(), + device=KeywordArg("device"), + pin_memory=False, + _users=MULTIPLE, + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[1], +) +def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): + """Based on a pattern in OPTForCausalLM""" + + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + # cumsum promotes all integral types to int64 + dtype = torch.int64 + + def repl(*shape): + dim_size = shape[dim] + idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype) + + inter_shape = [1] * len(shape) + inter_shape[dim] = dim_size + return (idx * fill_value).view(inter_shape).expand(shape) + + # only replace the output node, not all nodes + match.nodes = [match.output_node()] + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, list(shape)) + + +_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) + + +@register_lowering_pattern( + CallFunction( + aten.cat, + [ + _cat_1, + CallFunction( + aten.slice, + _cat_1, + 1, + 0, + KeywordArg("size"), + ), + ], + 1, + ) +) +def cat_slice_cat(match, cat_input, size, dim=1): + """ + This is an example of a more complex pattern where cat_1 is used + multiple times inside the pattern. We fold 2 calls to cat into one. + + Matches: + cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1) + slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807) + slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19) + cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1) + + + Rewrite to: + slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19) + cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1) + """ + first, *rest = cat_input + # Optimization is optional, because we can just not fold the cat + # size should be within first.get_size()[dim] such that the optimization is valid. + # For negative `end`, we currently fallback to not optimizing. + if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]): + # fold 2 cats into 1 cat + return L[aten.cat]( + [ + first, + *rest, + L[aten.slice](first, dim, 0, size), + ], + dim, + ) + else: + # don't expect to hit this case, just fall back + tmp = L[aten.cat](cat_input, dim) + return L[aten.cat]( + [ + tmp, + L[aten.slice](tmp, dim, 0, size), + ], + dim, + ) + + +def is_valid_splitwithsizes_cat(match): + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + cat_nodes = filter_nodes(match.nodes, aten.cat) + get_item_nodes = filter_nodes(match.nodes, operator.getitem) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + # The dim of split and cat should match for passthrough + if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"): + return False + get_item_args = OrderedSet( + get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes + ) + assert None not in get_item_args + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # All parts of split should be included in the cat + if get_item_args != OrderedSet(range(len(split_sizes))): + return False + # The order of get_item_args should same with cat_node used. + # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1), + # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1). + cat_items_args_order = [ + get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0) + ] + if cat_items_args_order != list(range(len(split_sizes))): + return False + + return True + + +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" + val1 = node1.meta.get("val") + val2 = node2.meta.get("val") + return ( + val1 is not None + and val2 is not None + and statically_known_true(sym_eq(val1.size(), val2.size())) + and val1.layout == val2.layout + and val1.dtype == val2.dtype + and val1.device == val2.device + and ( + val1.layout != torch.strided + or statically_known_true(sym_eq(val1.stride(), val2.stride())) + ) + ) + + +noop_registry: dict[Any, Any] = {} + + +def register_noop_decomp(targets, nop_arg=0): + def register_fun(cond): + register_decomposition(targets, registry=noop_registry, unsafe=True)( + (cond, nop_arg) # type: ignore[arg-type] + ) + return cond + + return register_fun + + +@register_noop_decomp(aten.slice) +def slice_noop(self, dim=0, start=None, end=None, step=1): + if start is None or end is None: + return False + + slice_dim_size = self.shape[dim] + if ( + statically_known_true(sym_eq(start, 0)) + and ( + statically_known_true(end >= 2**63 - 1) + or statically_known_true(end >= slice_dim_size) + ) + and statically_known_true(sym_eq(step, 1)) + ): + return True + return False + + +@register_noop_decomp(aten.slice_scatter, 1) +def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1): + if start is None: + start = 0 + if end is None: + end = 2**63 - 1 + slice_scatter_dim_size = self.shape[dim] + if ( + self.shape == src.shape + and start == 0 + and ( + statically_known_true(end >= 2**63 - 1) + or statically_known_true(end >= slice_scatter_dim_size) + ) + and step == 1 + ): + return True + return False + + +@register_noop_decomp(aten.repeat) +def repeat_noop(self, repeats): + return all(r == 1 for r in repeats) + + +@register_noop_decomp(aten.constant_pad_nd) +def constant_pad_nd(x, padding, fill_value=0): + return all(p == 0 for p in padding) + + +@register_noop_decomp(torch.ops.prims.convert_element_type) +def convert_element_type_noop(x, dtype: torch.dtype): + return x.dtype == dtype + + +@register_noop_decomp(torch.ops.prims.device_put) +def device_put_noop(x, device, non_blocking=True): + return x.device == decode_device(device) + + +@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc]) +def int_noop(x): + return is_integer_dtype(x.dtype) + + +@register_noop_decomp([aten.pow]) +def pow_noop(a, b): + return isinstance(b, int) and b == 1 + + +@register_noop_decomp([aten.cat], lambda args: args[0][0]) +def cat_noop(inputs, dim=0): + return len(inputs) == 1 + + +@register_noop_decomp(aten.view.default) +def view_default_noop(arg, size): + return statically_known_true(sym_eq(arg.shape, tuple(size))) + + +@register_noop_decomp(aten.view.dtype) +def view_dtype_noop(arg, dtype): + return arg.dtype == dtype + + +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) +def true_noop(*args, **kwargs): + return True + + +def remove_noop_ops(graph: torch.fx.Graph): + """ + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. + """ + inputs = OrderedSet[torch.fx.Node]() + input_storages = OrderedSet[int | None]() + output_storages = OrderedSet[int | None]() + + for node in graph.find_nodes(op="placeholder"): + inputs.add(node) + input_storages.add(get_node_storage(node)) + + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + # nested subgraphs can have singleton outputs + outputs = (outputs,) + for out in outputs: + if isinstance(out, torch.fx.Node): + output_storages.add(get_node_storage(out)) + + for node in graph.nodes: + if node.target in noop_registry: + cond, src_index = noop_registry[node.target] + if isinstance(src_index, int): + src = node.args[src_index] + else: + src = src_index(node.args) + if not isinstance(src, torch.fx.Node): + continue + # Don't introduce new aliasing between inputs and outputs. + # See fx_passes/README.md for a discussion of why this is + # necessary. + node_storage = get_node_storage(node) + src_storage = get_node_storage(src) + node_is_view = node_storage == src_storage + if ( + not node_is_view + and node_storage in output_storages + and (src_storage in input_storages or src_storage in output_storages) + ): + continue + + # Even if input and outputs are expected to alias, + # don't make "node is src" True + if ( + node_is_view + and node in output_node.args + and (src in inputs or src in output_node.args) + ): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + if same_meta(node, src) and cond(*args, **kwargs): + node.replace_all_uses_with(src) + graph.erase_node(node) + + +def remove_assert_ops(graph: torch.fx.Graph): + """ + Removes aten._assert_tensor_metadata.default op because + 1) it will be lowered to a no-op in inductor + 2) it can block fusion, such as unfuse_bias_add_to_pointwise fusion. + + This op could come from aten.to functionalization in export. + + For example, if we have a graph like below + + %addmm = aten.addmm.default(%linear_bias, %arg3_1, %permute) + %_assert_tensor_metadata = aten._assert_tensor_metadata.default(%addmm, None, None, torch.float16) + %convert_element_type_3 = prims.convert_element_type.default(%addmm, torch.float32) + %pow_1 = aten.pow.Tensor_Scalar(%convert_element_type_3, 2) + + We still want to fuse add from addmm with pow, instead of fusing add with mm, according to unfuse_bias_add_to_pointwise fusion. + + However, aten._assert_tensor_metadata.default is not a pointwise op, and would fail the should_prefer_unfused_addmm check. + + We remove this op so it doesn't block fusion decisions. It's safe because this op is lowered to a no-op with @register_lowering. + + """ + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten._assert_tensor_metadata.default + ): + graph.erase_node(node) + + +def decompose_triton_kernel_wrapper_functional(graph): + """Decomposes triton_kernel_wrapper_functional nodes into clones and the underlying + mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + # pyrefly: ignore [bad-argument-type] + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional_dense, + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + for _ in graph.find_nodes( + op="call_function", + target=torch.ops.higher_order.triton_kernel_wrapper_functional, + ): + raise AssertionError("triton_kernel_wrapper_functional was not removed") + + +def decompose_auto_functionalized(graph): + """Decomposes auto_functionalized nodes into clones and the underlying + mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + # pyrefly: ignore [bad-argument-type] + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense + + only_clone_these_tensors = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + assert len(args) == 1 + mode = args[0] + return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs) + + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2), + # pyrefly: ignore [bad-argument-type] + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized_v2_dense, + ) + + only_clone_these_bases = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def _maybe_resolve_constant_get_attr(node): + # Resolve getattr node to its value because they don't always have meta["val"] + if ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and "val" not in node.meta + ): + const_attr = getattr(graph.owning_module, node.target) # type: ignore[arg-type] + assert isinstance( + const_attr, (torch.fx.GraphModule, pytree.TreeSpec) + ), (type(const_attr), const_attr) + return const_attr + return node + + flat_args = [_maybe_resolve_constant_get_attr(arg) for arg in flat_args] + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + assert len(args) == 1 + mutable_op = args[0] + return auto_functionalized_v2_dense( + mutable_op, only_clone_these_bases, **kwargs + ) + + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + # Remove unused get_attr nodes and their corresponding attributes from the graph module. + # When auto_functionalizing a hop, we need to clean up get_attr nodes for _constant_schema + # and the auto_functionalized graph module that are no longer referenced. + unused_get_attr_nodes = [] + removable_attrs: OrderedSet[torch.fx.node.Target] = OrderedSet() + protected_attrs: OrderedSet[torch.fx.node.Target] = OrderedSet() + + # First pass: identify unused get_attr nodes and track attribute usage + for node in graph.nodes: + if node.op != "get_attr": + continue + + if len(node.users) == 0: + # Node is unused, mark for removal + unused_get_attr_nodes.append(node) + + # Check if the attribute can be removed from the module + if ( + hasattr(graph.owning_module, node.target) + and isinstance( + getattr(graph.owning_module, node.target), torch.fx.GraphModule + ) + and node.target not in protected_attrs + ): + removable_attrs.add(node.target) + else: + # Node is used, protect its attribute from removal + if node.target in removable_attrs: + removable_attrs.remove(node.target) + protected_attrs.add(node.target) + + # Second pass: clean up unused nodes and attributes + for node in unused_get_attr_nodes: + graph.erase_node(node) + + for attr_name in removable_attrs: + assert isinstance(attr_name, str) + delattr(graph.owning_module, attr_name) + + graph.lint() + + for _ in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized + ): + raise AssertionError("auto_functionalized was not removed") + + for _ in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized_v2 + ): + raise AssertionError("auto_functionalized_v2 was not removed") + + +@register_lowering_pattern( + CallFunction( + aten.cat, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split_with_sizes, + KeywordArg("input_"), + Ignored(), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_splitwithsizes_cat, +) +def splitwithsizes_cat_replace(match, input_): + return input_ + + +def is_valid_cat_splitwithsizes(match): + cat_nodes = filter_nodes(match.nodes, aten.cat) + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + + # the cat node has other users: can't eliminate + if len(cat_node.users) > 1: + return False + + # the dim of the cat and split should match + dim = get_arg_value(split_node, 2, "dim") + if dim != get_arg_value(cat_node, 1, "dim"): + return False + + cat_inputs = list(get_arg_value(cat_node, 0)) + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # the number of input tensors in cat and the + # length of the split sizes should match + if len(cat_inputs) != len(split_sizes): + return False + + for cat_input, split_size in zip(cat_inputs, split_sizes): + # each cat input tensor's size along dim + # should match the corresponding split size + if "val" not in cat_input.meta: + return False + cat_input_size = cat_input.meta["val"].size(dim) + if cat_input_size != split_size: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.split_with_sizes, + CallFunction( + aten.cat, + KeywordArg("input_"), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_cat_splitwithsizes, +) +def cat_splitwithsizes_replace(match, input_): + return input_ + + +def view_to_reshape(gm): + """ + Replace view ops in the GraphModule to reshape ops. + """ + subgraph_names: OrderedSet[str] = OrderedSet( + x.target for x in gm.graph.find_nodes(op="get_attr") + ) + + for child_name, child_mod in gm.named_children(): + if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule): + view_to_reshape(child_mod) + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.view.default + ): + nd.target = torch.ops.aten.reshape.default + + +def should_prefer_unfused_addmm(match): + inp = match.kwargs["inp"] + if not is_gpu(inp.meta["val"].device.type): + return False + + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) + + +@register_graph_pattern( + CallFunction( + aten.addmm, + KeywordArg("inp"), + Arg(), + Arg(), + beta=KeywordArg("beta"), + alpha=KeywordArg("alpha"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[2], + extra_check=should_prefer_unfused_addmm, +) +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta): + def repl(inp, x1, x2, alpha, beta): + mm_result = x1 @ x2 + if alpha != 1: + mm_result = alpha * mm_result + if beta != 1: + inp = beta * inp + return inp + mm_result + + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta]) + + +def is_valid_addmm_fusion(match): + mat1, mat2 = match.args + inp = match.kwargs["inp"] + + if not ( + isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor) + ): + return False # Input is a number + + in_shape = inp.meta["val"].shape + mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1] + matched = is_expandable_to(in_shape, mm_shape) + if not matched: + return False # Shape mismatch + + inp_dtype = inp.meta["val"].dtype + + # aten cublas integration assumes equal dtypes + if inp_dtype != mat1.meta["val"].dtype or inp_dtype != mat2.meta["val"].dtype: + return False + + return not should_prefer_unfused_addmm(match) + + +@register_graph_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + KeywordArg("inp"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +@register_graph_pattern( + CallFunction( + aten.add, + KeywordArg("inp"), + CallFunction(aten.mm, Arg(), Arg()), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +def addmm(match, mat1, mat2, *, inp): + def repl(inp, mat1, mat2): + return aten.addmm(inp, mat1, mat2) + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def register_partial_reduction_pattern(): + "Reuse partial reductions in complete reductions" + + # post grad equivalents + equiv_red = { + aten.amax.default: aten.max.default, + aten.amin.default: aten.min.default, + } + + # TODO: to support other reductions like sum, would need to skip + # lower precision reductions since partial output would need to be kept at fp32. + for red_op in (aten.amax.default, aten.amin.default): + inp = KeywordArg("input") + partial_reduc = CallFunction( + red_op, inp, KeywordArg("reduced_dims"), KeywordArg("keepdim") + ) + full_reduc = CallFunction([red_op, equiv_red[red_op]], inp) + + @register_graph_pattern( + MultiOutputPattern([partial_reduc, full_reduc]), + # pyrefly: ignore [bad-argument-type] + pass_dict=pass_patterns[2], + ) + def reuse_partial(match, input, reduced_dims, keepdim): + partial_red, full_red = match.output_nodes() + + # if they're small, reuse not worth it + if not statically_known_true(input.meta["val"].numel() >= 4096): + return True + + def replacement(inp: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + partial = partial_red.target(inp, reduced_dims, keepdim) + complete = full_red.target(partial) + return (partial, complete) + + counters["inductor"]["partial_reduction_reuse"] += 1 + match.replace_by_example(replacement, [input]) + + +register_partial_reduction_pattern() + + +def check_shape_cuda_and_fused_int_mm_mul_enabled(match): + return ( + config.force_fuse_int_mm_with_mul + and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2 + and getattr(match.args[2].meta.get("val"), "is_cuda", False) + ) + + +def is_index_put_and_requires_h2d_sync_for_gpu_value(node): + from torch.fx.operator_schemas import normalize_function + + if node.target not in [ + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + ]: + return False + # Inductor falls back to aten.index_put_. + # index_put_ will will call nonzero() and perform a H2D sync if + # any of its indices are bool/byte tensors + # However, it will short-circuit this H2D sync and run mask_fill_ + # if the value we are putting is a cpu scalar. + # Therefore, when inductor sees an index_put_ with byte tensor indices, + # it should *not* convert the cpu scalar value into a gpu tensor. + args_, _kwargs = normalize_function(node.target, node.args, node.kwargs) # type: ignore[misc] + any_byte_bool_indices = False + indices = args_[1] + for i in indices: + if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]: + any_byte_bool_indices = True + + val = args_[2].meta["val"] + val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1 + # If both these conditions hold, then converting the val + # to a gpu tensor will incur a H2D sync when inductor calls aten.index_put_ + return any_byte_bool_indices and val_is_cpu_scalar + + +class ConstructorMoverPass: + def __init__( + self, target: str, allow_outputs: bool = False, allow_inputs: bool = False + ) -> None: + """ + Move constructors from cpu to the target_device. + + Sweeps through the module, looking for constructor nodes that can be moved + to the target_device. + + A constructor node can be moved to the target_device iff all of its users + can also be moved (tested by cannot_be_moved). Otherwise, all dependent + constructor nodes won't be moved. + + - target: target device type + - allow_outputs: allow outputs to be moved + - allow_inputs: allow inputs to be moved + """ + + self.target = target + self.allow_inputs = allow_inputs + self.allow_outputs = allow_outputs + + assert isinstance(target, str), ( + "target should be a string representing the device type. " + f"Got: {type(target).__name__}" + ) + + def allow_cpu_device(self, node: fx.Node) -> bool: + """ + Returns whether a node that returns a tensor on the target device may have + cpu tensors as input. + """ + return node.target in ( + torch.ops.aten.index.Tensor, + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + torch.ops.aten.copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.slice_scatter.default, + ) + + def is_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node is on the target device. + """ + node_device = self.get_node_device(node) + return node_device is not None and node_device.type == self.target + + def is_cpu_scalar_tensor(self, node: fx.Node) -> bool: + """ + Returns whether a node is a cpu scalar tensor. + """ + device = self.get_node_device(node) + is_cpu = device is not None and device.type == "cpu" + ten = node.meta.get("val") + is_scalar = isinstance(ten, torch.Tensor) and len(ten.size()) == 0 + return is_cpu and is_scalar + + def all_inputs_are_cpu_scalar_or_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node's inputs are either cpu scalar tensors or + on the target device. + """ + inputs = ( + inp + for inp in itertools.chain(node.args, node.kwargs.values()) + if isinstance(inp, fx.Node) + ) + return all( + self.is_cpu_scalar_tensor(inp) or self.is_on_target_device(inp) + for inp in inputs + ) + + def cannot_be_moved(self, node: fx.Node) -> bool: + """ + Returns whether a node can be moved to the target device. + + If this function returns False, it means that this node and all of its users + won't be moved into the target device. + """ + if node.target == "output": + return not self.allow_outputs + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + return True + + if is_index_put_and_requires_h2d_sync_for_gpu_value(node): + return True + + return False + + def get_node_device(self, node: fx.Node) -> torch.device | None: + """ + Get the device of a node. + """ + ten = node.meta.get("val") + return None if not isinstance(ten, torch.Tensor) else ten.device + + def get_cpu_indeg_count(self, graph: fx.Graph) -> dict[fx.Node, int]: + """ + Get the number of cpu inputs to a node + """ + cpu_indeg: dict[fx.Node, int] = Counter() + + for node in graph.nodes: + cpu_count = 0 + + def add_cpu_inp(node): + nonlocal cpu_count + device = self.get_node_device(node) + cpu_count += device is not None and device.type == "cpu" + + pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs)) + + # pyrefly: ignore [redundant-condition] + if cpu_count: + cpu_indeg[node] = cpu_count + + return cpu_indeg + + def __call__(self, graph: fx.Graph) -> None: + target_devices = OrderedSet[torch.device]() + constructors = [] + cpu_placeholders: OrderedSet[fx.Node] = OrderedSet() + + for node in graph.nodes: + device = self.get_node_device(node) + if device and device.type == self.target: + target_devices.add(device) + + if ( + self.allow_inputs + and node.op == "placeholder" + and self.is_cpu_scalar_tensor(node) + ): + cpu_placeholders.add(node) + constructors.append(node) + continue + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + continue + + if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target): + continue + + if node.kwargs.get("device") != torch.device("cpu"): + continue + + constructors.append(node) + + # not handling multiple target devices initially + if not constructors or len(target_devices) != 1: + return + + movable_constructors = self.find_movable_constructors(graph, constructors) + + target_device = next(iter(target_devices)) + movable_cpu_placeholders = movable_constructors & cpu_placeholders + if movable_cpu_placeholders: + node = next(iter(reversed(movable_cpu_placeholders))) + last_node = node + unsqueezed_nodes = [] + for elem in movable_cpu_placeholders: + with graph.inserting_after(last_node): + unsqueezed_nodes.append( + graph.call_function(torch.ops.aten.unsqueeze.default, (elem, 0)) + ) + last_node = unsqueezed_nodes[-1] + with graph.inserting_after(last_node): + cpu_concat = graph.call_function( + torch.ops.aten.cat.default, (unsqueezed_nodes,) + ) + last_node = cpu_concat + with graph.inserting_after(last_node): + gpu_concat = graph.call_function( + torch.ops.prims.device_put.default, + (cpu_concat, target_device, True), + ) + last_node = gpu_concat + with graph.inserting_after(last_node): + gpu_split = graph.call_function( + torch.ops.aten.unbind.int, (gpu_concat,) + ) + last_node = gpu_split + for idx, node in enumerate(movable_cpu_placeholders): + with graph.inserting_after(last_node): + gpu_node = graph.call_function(operator.getitem, (gpu_split, idx)) + node.replace_all_uses_with( + gpu_node, + lambda x: x + not in [cpu_concat, gpu_concat, gpu_split, gpu_node] + + unsqueezed_nodes + and x.target != torch.ops.aten.copy_.default, + ) + last_node = gpu_node + + # noop elimination if there are other device_put for gpu_node to + # target device. Alternatively, we could just move the other device_put + # earlier in the graph, but that is not supported in fx graph yet. + noop_device_puts = [ + user + for user in gpu_node.users + if user.target is torch.ops.prims.device_put.default + and user.args[1] == target_device + ] + for noop in noop_device_puts: + noop.replace_all_uses_with(gpu_node) + graph.erase_node(noop) + + movable_constructors -= movable_cpu_placeholders + for node in movable_constructors: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs + + def find_movable_constructors( + self, graph: fx.Graph, constructors: list[fx.Node] + ) -> OrderedSet[fx.Node]: + """ + Starting from the cpu constructors, iterate through the graph and test that all of their + downstream uses can safely be moved to cpu. + """ + cpu_indeg: dict[fx.Node, int] = self.get_cpu_indeg_count(graph) + + # which constructors cannot be moved to gpu + cannot_move_to_gpu = OrderedSet[fx.Node]() + + # For any node in the graph, which constructors does it have a dependency on + constructor_dependencies: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict( + OrderedSet + ) + + # if a cpu node has a dependency on two different cpu constructors, + # then if either constructor cannot be moved to gpu, the other cannot as well. + # In this case any node with a dependency on one will have a dependency on the other + equal_constructor_sets: dict[fx.Node, OrderedSet[fx.Node]] = { + c: OrderedSet([c]) for c in constructors + } + + def make_dependencies_equivalent( + set1: OrderedSet[fx.Node], set2: OrderedSet[fx.Node] + ) -> OrderedSet[fx.Node]: + # could use union find but not worth complexity here + set1.update(set2) + for obj in set1: + equal_constructor_sets[obj] = set1 + return set1 + + queue: list[fx.Node] = list(constructors) + + for c in queue: + constructor_dependencies[c].add(c) + + while queue: + node = queue.pop() + dependencies = constructor_dependencies[node] + + for user in node.users: + if self.cannot_be_moved(user): + cannot_move_to_gpu.update(dependencies) + break + + # this node was used on a op which takes in multiple devices and output a gpu + # tensor. we can convert its cpu input to gpu without making further changes + if self.allow_cpu_device(user) and self.is_on_target_device(user): + del cpu_indeg[user] + elif ( + self.allow_inputs + and self.all_inputs_are_cpu_scalar_or_on_target_device(user) + ): + # this node takes only cpu scalar tensors or gpu tensors as inputs + # and outputs a gpu tensor. we can convert its cpu scalar inputs to gpu + # without making further changes + del cpu_indeg[user] + else: + # otherwise, we should continue look at its downstream uses + cpu_indeg[user] -= 1 + if cpu_indeg[user] == 0: + del cpu_indeg[user] + queue.append(user) + + unioned_set = make_dependencies_equivalent( + dependencies, constructor_dependencies[user] + ) + constructor_dependencies[user] = unioned_set + + for node in cpu_indeg: + if constructor_dependencies[node]: + cannot_move_to_gpu.update(constructor_dependencies[node]) + + all_cannot_move_to_gpu = cannot_move_to_gpu.copy() + for constructor in cannot_move_to_gpu: + all_cannot_move_to_gpu.update(equal_constructor_sets[constructor]) + + return OrderedSet(constructors) - all_cannot_move_to_gpu + + +def move_constructors_to_gpu(graph: fx.Graph) -> None: + """ + Moves intermediary tensors which are constructed on the cpu to gpu when safe + """ + + # cudagraph does not support cpu tensors. In this pass, we update the graph + # by explicitly moving cpu scalar tensors to gpu when profitable, relying on + # graph partition to split off this data copy, and cudagraphifying + # the remaining gpu ops. + allow_inputs_outputs = bool( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + ConstructorMoverPass( + get_gpu_type(), + allow_inputs=allow_inputs_outputs, + allow_outputs=allow_inputs_outputs, + )(graph) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd81f9b8cd57a1384806191ff7b8338f7b36807 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py @@ -0,0 +1,877 @@ +# mypy: allow-untyped-defs +import copy +import functools +import itertools +import logging +import types +from collections.abc import Sequence + +import torch +import torch.nn as nn +from torch._dynamo.utils import counters, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.experimental.optimization import ( + matches_module_pattern, + replace_node_module, +) +from torch.fx.passes.graph_transform_observer import ( + GraphTransformObserver as GraphTransformObserverBase, +) +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn import functional as F +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights + +from .. import config +from ..fx_utils import matches_module_function_pattern +from ..pattern_matcher import ( + init_once_fakemode, + PatternMatcherPass as PatternMatcherPassBase, + stable_topological_sort, +) +from ..utils import is_cpu_device, pass_execution_and_save +from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS +from .misc_patterns import numpy_compat_normalization +from .split_cat import PRE_GRAD_PATTERNS + + +PatternMatcherPass = functools.partial( + PatternMatcherPassBase, subsystem="pre_grad_passes" +) +GraphTransformObserver = functools.partial( + GraphTransformObserverBase, subsystem="pre_grad_passes" +) + +log = logging.getLogger(__name__) + +efficient_conv_bn_eval_pass = PatternMatcherPass( + pass_name="efficient_conv_bn_eval_pass" +) + +fuse_split_linear_add_pass = PatternMatcherPass( + pass_name="fuse_split_linear_add_pass", +) +fuse_chunk_squeeze_cat_pass = PatternMatcherPass( + pass_name="fuse_chunk_squeeze_cat_pass", +) +remove_reshape_pass = PatternMatcherPass( + pass_name="remove_reshape_pass", +) + +# based on predispatch aten IR +normalization_pass_aten = PatternMatcherPass(pass_name="normalization_pass_aten") +merge_splits_pass_aten = PatternMatcherPass(pass_name="merge_splits_pass_aten") +split_cat_pass_aten = PatternMatcherPass(pass_name="split_cat_pass_aten") +unbind_stack_pass_aten = PatternMatcherPass(pass_name="unbind_stack_pass_aten") +merge_getitem_cat_pass_aten = PatternMatcherPass( + pass_name="merge_getitem_cat_pass_aten" +) +merge_stack_tahn_unbind_pass_aten = PatternMatcherPass( + pass_name="merge_stack_tahn_unbind_pass_aten" +) +mutate_cat_pass_aten = PatternMatcherPass(pass_name="mutate_cat_pass_aten") +remove_split_with_size_one_pass_aten = PatternMatcherPass( + pass_name="remove_split_with_size_one_pass_aten" +) + + +def save_inductor_dict(pass_to_compare=None): + if not pass_to_compare: + pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list( + config.post_grad_fusion_options.keys() + ) + return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare} + + +def is_same_dict(inductor_dict, optimus_dict): + for pass_name, count in optimus_dict.items(): + if count != dict(inductor_dict).get(pass_name, 0): + return False + return True + + +def shape_prop(mod) -> None: + return None + + +def normalize_node_kwargs_pass(graph): + return None + + +def fuse_parallel_linear_pass(graph): + return None + + +def remove_split_ops(graph, shape_prop): + return None + + +def remove_split_ops_pass(graph): + remove_split_ops(graph.owning_module, shape_prop) + + +def fuse_chunk_reshape_unsqueeze_concat_pass(graph): + return None + + +def fuse_chunk_reshape_concat_pass(graph): + return None + + +def remove_noop_pass(graph): + return None + + +def stack_to_unsqueeze_pass(graph): + return None + + +def merge_concats_pass(graph): + return None + + +def relu_nan_to_num(graph): + return None + + +def fuse_split_getitem_squeeze_cat(graph): + return None + + +def use_triton_dot_compress(graph): + return None + + +def use_triton_lce_replace_simple_LCE_helper(gm, shape_prop): + return None + + +def use_triton_lce_replace_simple_LCE(graph): + return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop) + + +def use_triton_lce_replace_normal_LCE_helper(gm, shape_prop): + return None + + +def use_triton_lce_replace_normal_LCE(graph): + return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop) + + +def use_matmul_lce_replace_normal_LCE(graph): + return None + + +def use_matmul_fuse_lce_replace_first_LCE(graph): + return None + + +@init_once_fakemode +def lazy_init(): + from . import efficient_conv_bn_eval, split_cat # noqa: F401 + + if config.is_fbcode(): + from . import fb # type: ignore[attr-defined] # noqa: F401 + + +def _get_pass_name_func(p): + if isinstance(p, PatternMatcherPassBase): + pass_name = p.pass_name + pass_func = p.apply + elif isinstance(p, types.FunctionType): + pass_name = p.__name__.lstrip("_") + pass_func = p + else: + pass_name = None + pass_func = None + + return pass_name, pass_func + + +def _run_pre_dispatch_passes( + gm: torch.fx.GraphModule, + example_inputs: Sequence[object] = (), + add_passes: str | None = None, + remove_passes: str | None = None, +) -> None: + # order matters + default_pass_list = [ + # normalize passes, must be called as the first passes + normalization_pass_aten, + normalize_node_kwargs_pass, + remove_noop_pass, + relu_nan_to_num, + fuse_chunk_reshape_concat_pass, + group_batch_fusion_passes, + normalize_node_kwargs_pass, + fuse_chunk_squeeze_cat_pass, + merge_concats_pass, + fuse_split_linear_add_pass, + remove_reshape_pass, + fuse_parallel_linear_pass, + remove_split_ops_pass, + stack_to_unsqueeze_pass, # run before fuse_chunk_reshape_unsqueeze_concat_pass + fuse_chunk_reshape_unsqueeze_concat_pass, + ] + + full_pass_list = default_pass_list + [ + fuse_split_getitem_squeeze_cat, + use_triton_dot_compress, + use_triton_lce_replace_simple_LCE, + use_triton_lce_replace_normal_LCE, + use_matmul_fuse_lce_replace_first_LCE, + use_matmul_lce_replace_normal_LCE, + ] + + log.info( + f"pre_grad_passes: add_passes: {add_passes}, remove_pass: {remove_passes}" # noqa: G004 + ) + add_passes_list = [] + remove_passes_list = [] + if add_passes: + add_passes_list = add_passes.split(",") + if remove_passes: + remove_passes_list = remove_passes.split(",") + + shape_prop = lambda mod: ShapeProp( # noqa: E731 + gm=mod, + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode=detect_fake_mode(example_inputs), + ).propagate(*tuple(example_inputs)) + + for p in default_pass_list: + pass_name, pass_func = _get_pass_name_func(p) + # should not happen + if pass_name is None or pass_func is None: + continue + if pass_name in remove_passes_list: + continue + pass_execution_and_save( + pass_func, + gm, + example_inputs, + f"[Pre grad(predispatch IR)] Apply {pass_name} pass", + ) + + for p in full_pass_list: + pass_name, pass_func = _get_pass_name_func(p) + if pass_name is None or pass_func is None: + continue + if pass_name in add_passes_list: + pass_execution_and_save( + pass_func, + gm, + example_inputs, + f"[Pre grad(predispatch IR)] Apply {pass_name} pass", + ) + + if "remove_noop" not in remove_passes_list: + # Remove noops at the end, which may be generated other passes. + pass_execution_and_save( + remove_noop_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply remove_noop pass", + ) + shape_prop(gm) + + +def pre_grad_passes( + gm: torch.fx.GraphModule, + example_inputs: Sequence[object] = (), + add_passes: str | None = None, + remove_passes: str | None = None, +) -> torch.fx.GraphModule: + """ + Apply passes on the input FX graph using Torch IR. + + WARNING: + The IR before grad is not functional or normalized, so it is harder + to write passes on this IR. Passes must be safe with respect to + aliasing and mutation and need to handle all possible arg schemas. + + Consider adding a new pass to post_grad.py or joint_graph.py which + are after functionalization and normalization. + """ + if config.pattern_matcher: + lazy_init() + if hasattr( + config, "fx_passes_numeric_check" + ) and config.fx_passes_numeric_check.get("pre_grad", False): + gm_before_fx_passes = gm.__copy__() + # explicitly run with predispatch atenIR based passes + if config.is_predispatch: + _run_pre_dispatch_passes(gm, example_inputs, add_passes, remove_passes) + else: + # We only log the graph with changes to avoid the excessive compilation time + # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/ + if example_inputs is not None: + gm = fuse_fx(gm, example_inputs) + numpy_compat_normalization(gm.graph) + # We should always do the normalization_pass first + if "normalization_pass" in config.pre_grad_fusion_options: + pattern_matcher_pass = PRE_GRAD_PATTERNS["normalization_pass"] + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + group_batch_fusion_passes(gm.graph, pre_grad=True) + for pass_name in config.pre_grad_fusion_options: + # skip all patterns for group batch fusions + if pass_name in PRE_GRAD_FUSIONS or pass_name == "normalization_pass": + continue + pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + # we support run same pattern multiple times, the default is to run only once + counter = config.pre_grad_fusion_options[pass_name].get("counter", 1) + for _ in range(counter): + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + if not is_same_dict(counters["inductor"], inductor_before_change): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"{pattern_matcher_pass.pass_name}_pre_grad", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + # TODO: move efficient_conv_bn_eval_pass to the fusions dict too. + efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] + + if config.pre_grad_custom_pass is not None: + GraphTransformObserver(gm, "pre_grad_custom_pass").apply_graph_pass( + config.pre_grad_custom_pass + ) + stable_topological_sort(gm.graph) + + from .quantization import quant_lift_up + + quant_lift_up(gm) + + gm.graph.lint() + gm.recompile() + + if ( + config.pattern_matcher + and hasattr(config, "fx_passes_numeric_check") + and config.fx_passes_numeric_check.get("pre_grad", False) + and example_inputs is not None + ): + from .numeric_utils import numeric_check_if_enabled + + gm_after_fx_passes = gm.__copy__() + numeric_check_if_enabled( + gm_before_fx_passes, # type: ignore[possibly-undefined] + gm_after_fx_passes, + example_inputs, + config.fx_passes_numeric_check.get("num_iterations", 1), + config.fx_passes_numeric_check.get("precision", 1e-4), + ) + + return gm + + +def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: + is_cpu = is_cpu_device(example_inputs) + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode = detect_fake_mode(example_inputs) + + gm = sink_cat_after_pointwise(gm) + if config.permute_fusion and not is_cpu: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + with GraphTransformObserver(gm, "linear_permute_fusion"): + gm = linear_permute_fusion(gm) + with GraphTransformObserver(gm, "permute_linear_fusion"): + gm = permute_linear_fusion(gm) + with GraphTransformObserver(gm, "permute_matmul_fusion"): + gm = permute_matmul_fusion(gm) + + # make sure the autograd is disabled. + if torch.is_grad_enabled() or not is_cpu: + return gm + if config.freezing: + with GraphTransformObserver(gm, "remove_identity"): + gm = remove_identity(gm) + with GraphTransformObserver(gm, "fuse_conv_bn"): + gm = fuse_conv_bn(gm) + return gm + + +def fetch_attr(target: str, mod): + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Removes all identity layers from the module. + """ + + class IdentityRemover(torch.fx.Transformer): + def call_module(self, target, args, kwargs): + if isinstance(self.submodules[target], nn.Identity): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return IdentityRemover(gm).transform() + + +def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule: + """ + Fuses Convolution/BN layers for inference purposes. + """ + modules_patterns = [ + (torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d), + ] + module_function_patterns = [ + (torch.nn.Conv1d, F.batch_norm), + (torch.nn.Conv2d, F.batch_norm), + (torch.nn.Conv3d, F.batch_norm), + ] + modules = dict(gm.named_modules()) + + class ConvBNFusion: + def __init__( + self, + bn_node, + conv_module, + bn_module=None, # For BN Module + bn_running_mean=None, # For Functional BN + bn_running_var=None, + bn_eps=None, + bn_weight=None, + bn_bias=None, + ) -> None: + self.bn_nodes = [ + bn_node, + ] + self.conv_module = conv_module + self.bn_module = bn_module + self.bn_running_mean = bn_running_mean + self.bn_running_var = bn_running_var + self.bn_eps = bn_eps + self.bn_weight = bn_weight + self.bn_bias = bn_bias + self.fusion_enabled = True + + def add_bn_node(self, bn_node): + self.bn_nodes.append(bn_node) + + def disable_fusion(self): + self.fusion_enabled = False + + def is_fusion_enabled(self): + return self.fusion_enabled + + conv_bn_to_fuse: dict[int, ConvBNFusion] = {} + for pattern in modules_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + eval_mode = all(not n.training for n in [conv, bn]) + if not eval_mode: + continue + if not bn.track_running_stats: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn) + else: + if bn == conv_bn_to_fuse[hash_id].bn_module: + # Do fusion if same bn module + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn = conv_bn_fusion.bn_module + + # pyrefly: ignore [bad-argument-type] + fused_conv = fuse_conv_bn_eval(conv, bn) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + + gm.graph.lint() + for pattern in module_function_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_function_pattern(pattern, node, modules): + # TODO: support kwargs. + if len(node.args) != 8: + continue + conv = modules[node.args[0].target] + bn_training = node.args[5] + bn_eps = node.args[7] + if conv.training or bn_training: + continue + if type(bn_eps) is not float: + continue + + def _used_by_same_conv_module(users): + conv_module_name = users[0].args[0].target + return all( + conv_module_name == user.args[0].target for user in users + ) + + bn_args_is_constant = all( + n.op == "get_attr" + and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users))) + for n in node.args[1:5] + ) + if not bn_args_is_constant: + continue + bn_running_mean = fetch_attr(node.args[1].target, gm) + bn_running_var = fetch_attr(node.args[2].target, gm) + bn_weight = fetch_attr(node.args[3].target, gm) + bn_bias = fetch_attr(node.args[4].target, gm) + if bn_running_mean is None or bn_running_var is None: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion( + node, + conv, + bn_running_mean=bn_running_mean, + bn_running_var=bn_running_var, + bn_eps=bn_eps, + bn_weight=bn_weight, + bn_bias=bn_bias, + ) + else: + if ( + hash(bn_running_mean) + == hash(conv_bn_to_fuse[hash_id].bn_running_mean) + and hash(bn_running_var) + == hash(conv_bn_to_fuse[hash_id].bn_running_var) + and torch.allclose( + torch.tensor(bn_eps), + torch.tensor(conv_bn_to_fuse[hash_id].bn_eps), + ) + and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight) + and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias) + ): + # Do fusion if same functional bn + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn_running_mean = conv_bn_fusion.bn_running_mean + bn_running_var = conv_bn_fusion.bn_running_var + bn_eps = conv_bn_fusion.bn_eps + bn_weight = conv_bn_fusion.bn_weight + bn_bias = conv_bn_fusion.bn_bias + + fused_conv = copy.deepcopy(conv) + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + # pyrefly: ignore [bad-argument-type] + bn_running_mean, + # pyrefly: ignore [bad-argument-type] + bn_running_var, + # pyrefly: ignore [bad-argument-type] + bn_eps, + bn_weight, + bn_bias, + ) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + gm.graph.lint() + gm.recompile() + + return gm + + +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target is torch.nn.functional.linear + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["weight"] # type: ignore[return-value] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] # type: ignore[return-value] + else: + return self.node.kwargs.get("bias", None) # type: ignore[return-value] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["other"] # type: ignore[return-value] + + +def check_permute(node: torch.fx.Node) -> bool: + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type] + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[operator, union-attr] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + def one_user(node): + users = list(node.users) + return users[0] if len(users) == 1 else None + + def is_view(node): + return node.op == "call_method" and node.target == "view" + + def is_pointwise_unary(node): + ops = "call_function", "call_method" + pointwise = torch.relu, torch.tanh, "relu", "tanh" + return node.op in ops and node.target in pointwise + + g = module.graph + for node in g.nodes: + if node.op != "call_function" or node.target != torch.cat: + continue + + cat_or_view = node + while True: + user = one_user(cat_or_view) + if not user or not is_view(user): + break + cat_or_view = user + + if user and is_pointwise_unary(user): + with g.inserting_before(node): + + def cat_args(tensors, dim=0): + return tensors, dim + + tensors, dim = cat_args(*node.args, **node.kwargs) + new_kwargs = { + name: val for name, val in user.kwargs.items() if name != "input" + } + new_tensors = [ + g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs) + for arg in tensors + ] + new_cat = g.create_node( + "call_function", torch.cat, args=(new_tensors, dim) + ) + user.replace_all_uses_with(cat_or_view) + node.replace_all_uses_with(new_cat) + g.erase_node(user) + g.erase_node(node) + g.lint() + module.recompile() + return module + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes(op="call_method", target="permute"): + if check_permute(node): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target is torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None +) -> torch.Tensor: + if bias is None: + return torch.matmul(weight, input.transpose(-1, -2)) + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes( + op="call_function", target=torch.nn.functional.linear + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in itertools.chain( + module.graph.find_nodes(op="call_function", target=torch.bmm), + module.graph.find_nodes(op="call_function", target=torch.matmul), + ): + normalized = NormalizedMatmulNode(node) + input_A_node = normalized.get_input() + input_B_node = normalized.get_other() + input_A = input_A_node + input_B = input_B_node + Atrans = Btrans = False + if ( + input_A_node.op == "call_method" + and input_A_node.target == "permute" + and check_permute(input_A_node) + ): + Atrans = True + if len(input_A_node.args) > 0: + input_A = input_A_node.args[0] # type: ignore[assignment] + else: + input_A = input_A_node.kwargs["input"] # type: ignore[assignment] + + if ( + input_B_node.op == "call_method" + and input_B_node.target == "permute" + and check_permute(input_B_node) + ): + Btrans = True + if len(input_B_node.args) > 0: + input_B = input_B_node.args[0] # type: ignore[assignment] + else: + input_B = input_B_node.kwargs["input"] # type: ignore[assignment] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(input_A, input_B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if Atrans and len(input_A_node.users) == 0: + module.graph.erase_node(input_A_node) + if Btrans and len(input_B_node.users) == 0: + module.graph.erase_node(input_B_node) + + module.graph.lint() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None +) -> torch.Tensor: + if bias is None: + return torch.matmul(input.transpose(-1, -2), weight.t()) + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul( + A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool +) -> torch.Tensor: + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..951a62acf227610007db48a9aae1aa6795d01ee8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py @@ -0,0 +1,3968 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import copy +import functools +import itertools +import math +import operator +from typing import Any + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.fx.node import map_arg + +from .. import config +from ..lowering import lowerings as L, require_channels_last +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + KeywordArg, + ListOf, + Match, + stable_topological_sort, +) +from ..utils import pad_listlike +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed +quantized = torch.ops.quantized + +# Only for per tensor quant since permute may changes the channel idx +_PER_TENSOR_QUANTIZE_OPS = [ + quantized_decomposed.quantize_per_tensor.default, + quantized_decomposed.quantize_per_tensor.tensor, +] + +_VIEW_OPS = [ + aten.transpose.int, + aten.permute.default, + aten.view.default, + aten.reshape.default, +] + +""" +The quantization.py file primarily incorporates passes related to quantization fusion +in inductor, includes: +1. Dequant Promotion; +2. Conv/GEMM weight prepack with oneDNN Library; +3. Conv/GEMM quantization fusion with output quant node (if have); +4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; + +It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference +of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is +1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. +2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. +Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 +quantization. +""" + + +def _get_pattern_output_dtype(match: Match): + """ + Get the pattern's output dtype from node's meta + Assume only 1 output node in this matched pattern. + """ + pattern_output_nodes = match.output_nodes() + assert len(pattern_output_nodes) == 1 + output_node = pattern_output_nodes[0] + assert isinstance(output_node, torch.fx.Node) + output_dtype = output_node.meta["val"].dtype + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + ] + return output_dtype + + +def _may_generate_pattern_with_dtype_convert( + pattern, dtype=Arg(), with_dtype_convert=True, users=1 +): + if with_dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + _users=users, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): + # only insert to_dtype if is_bf16 is True + computation_call = _may_generate_pattern_with_dtype_convert( + call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users + ) + return unary_fusion(computation_call) + + +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) + return dequantize_per_tensor_activation_pattern + + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) + return CallFunction( + qconv_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) + return CallFunction( + qconv_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("accum"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("accum_scale"), + KeywordArg("accum_zero_point"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("x_2"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("x2_scale"), + KeywordArg("x2_zp"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +dequantize_accum_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("accum"), + KeywordArg("accum_scale"), + KeywordArg("accum_zp"), + Arg(), + Arg(), + KeywordArg("accum_dq_dtype"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + dtype_convert=False, + swap_inputs=False, +): + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + dtype_convert, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, + ), + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value == expected_value + else: + assert len(check_node.args) >= (args_index + 1) + actual_value = check_node.args[args_index] + return actual_value == expected_value + + +def _is_valid_quantized_conv_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qconv_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv_pointwise + )[0] + return _check_node_kwarg_arg_value( + qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype + ) + return True + + return fn + + +def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): + return ( + _is_valid_qconv_binary_optimization_pattern() + if has_binary_post_op + else _is_valid_quantized_conv_optimization_pattern() + ) + + +def _is_valid_qconv_lowering_pattern(): + def fn(match): + if len(match.nodes) != 1: + return False + return match.nodes[0].target in ( + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv_pointwise.tensor, + torch.ops.onednn.qconv2d_pointwise.binary, + torch.ops.onednn.qconv2d_pointwise.binary_tensor, + ) + + return fn + + +def _register_quantized_conv_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_lowering_pattern(), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float8_e4m3fn, + torch.float32, + torch.bfloat16, + ] + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + output_dtype = kwargs["output_dtype"] + # post op + postop_name = kwargs["postop_name"] + postop_args = kwargs["postop_args"] + postop_algorithm = kwargs["postop_algorithm"] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + postop_name, + postop_args, + postop_algorithm, + ) + counters["inductor"]["qconv_unary_lower_count"] += 1 + counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv + + +def _is_valid_quantized_linear_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qlinear_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qlinear_pointwise + )[0] + return _check_node_kwarg_arg_value( + qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype + ) + return True + + return fn + + +def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False): + return ( + _is_valid_qlinear_binary_optimization_pattern() + if has_binary_post_op + else _is_valid_quantized_linear_optimization_pattern() + ) + + +def _is_valid_qlinear_lowering_pattern(): + def fn(match): + if len(match.nodes) != 1: + return False + return match.nodes[0].target in ( + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ) + + return fn + + +def _register_quantized_linear_unary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_lowering_pattern(), + pass_number=pass_number, + ) + def qlinear(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs.get("b") + + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + + # post op + postop_name = kwargs["postop_name"] + postop_args = kwargs["postop_args"] + postop_algorithm = kwargs["postop_algorithm"] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + postop_name, + postop_args, + postop_algorithm, + ) + counters["inductor"]["qlinear_unary_lower_count"] += 1 + counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear + + +def _register_quantized_linear_binary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_lowering_pattern(), + pass_number=pass_number, + ) + def qlinear_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + x2 = kwargs["x_2"] + x2_scale = kwargs["x2_scale"] + x2_zp = kwargs["x2_zp"] + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # bias + b = kwargs.get("b") + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + + x2.realize() + from .mkldnn_fusion import _qlinear_binary_can_be_inplace + + binary_op_name = kwargs["binary_op_name"] + alpha = kwargs["alpha"] + unary_op_name = kwargs["unary_op_name"] + unary_op_args = kwargs["unary_op_args"] + unary_op_algorithm = kwargs["unary_op_algorithm"] + if ( + # TODO Ensure sum is safe and remove such check, i.e., + # x2 is not used by other operations + # or current qlinear sum is the last user of x2. + # This needs to be ensured when registering + # the lowering pattern of quantized_linear_binary. + binary_op_name == "sum" and (not _qlinear_binary_can_be_inplace(x2)) + ): + binary_op_name = "add" + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + x2, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ) + counters["inductor"]["qlinear_binary_lower_count"] += 1 + counters["inductor"]["qlinear_binary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear_binary + + +def _is_valid_qconv_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qconv_pointwise + ) + + +def _is_valid_qlinear_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qlinear_pointwise, + # we don't insert q-dq for extra input due to accuracy issues + extra_input_from_dequant=False, + ) + + +def _is_valid_quantized_op_binary_optimization_pattern( + qop, extra_input_from_dequant=True +): + # Check if it's a valid Binary Pattern for qconv2d and qlinear: + # * qop_pointwise should only has one users + # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern + # * the two inputs of binary node should have attribute "meta" and should be tensors + # * the two inputs of binary node should have the same shape + # * All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype in [torch.float32, torch.bfloat16]: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + + from .mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if (output_dtype in [torch.uint8, torch.int8]) + or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) + ) + if ( + len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _register_quantized_conv_binary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_lowering_pattern(), + pass_number=pass_number, + ) + def qconv_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] + accum = kwargs["accum"] + accum_scale = kwargs["accum_scale"] + accum_zp = kwargs["accum_zero_point"] + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + # Output QParams + output_scale = kwargs["output_scale"] + output_zero_point = kwargs["output_zero_point"] + + # post ops + binary_op_name = kwargs["binary_op_name"] + alpha = kwargs["alpha"] + unary_op_name = kwargs["unary_op_name"] + unary_op_args = kwargs["unary_op_args"] + unary_op_algorithm = kwargs["unary_op_algorithm"] + + accum.realize() + from .mkldnn_fusion import _can_be_inplace + + assert _can_be_inplace(accum), ( + "QConv Binary Inplace Fusion requires accum is not an alias or mutation." + ) + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ) + counters["inductor"]["qconv2d_binary_lower_count"] += 1 + counters["inductor"]["qconv2d_binary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv_binary + + +def _register_quantization_unary_lowering(): + # QConv2d + for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]): + qconv_pattern = get_qconv_pt2e_pattern(x_scale_zp_are_tensors, users) + computation_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) + _register_quantized_conv_lowering( + qconv_pattern, + 2, # pass_number + computation_op, + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + _register_quantized_linear_unary_lowering( + qlinear_pattern, + 2, # pass_number + computation_op, + ) + + +def _register_quantization_binary_lowering(): + # QConv2d + for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]): + qconv_pattern = get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors, users) + computation_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) + _register_quantized_conv_binary_lowering( + qconv_pattern, + 2, # pass_number + computation_op, + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + _register_quantized_linear_binary_lowering( + qlinear_pattern, + 2, # pass_number + computation_op, + ) + + +def _is_valid_quantized_maxpool2d_optimization_pattern(): + def fn(match): + # Only match the pattern which max_pool2d_with_indices returns value + # instead of indices. + get_item_node = filter_nodes(match.nodes, operator.getitem)[0] + return get_item_node.args[1] == 0 + + return fn + + +def _register_quantized_maxpool2d_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), + ) + def qmaxpool2d(match: Match, *args, **kwargs): + x = kwargs["x"] + kernel_size = kwargs["kernel_size"] + stride = kwargs.get("stride") + padding = kwargs.get("padding", 0) + dilation = kwargs.get("dilation", 1) + ceil_mode = kwargs.get("ceil_mode", False) + + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + + computation_args = ( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + computation_args, _ = require_channels_last(computation_op, *computation_args) + counters["inductor"]["qmaxpool2d_matcher_count"] += 1 + counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qmaxpool2d + + +def _register_quantization_maxpool2d(): + # Currently, the default parameters are not in FX Graph generated by Dynamo export. + # So, if user defines nn.MaxPool2d with different assignment of default parameter, + # it will generate graph with different number of input nodes and hence + # different pattern to be matched. + # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 + max_pool2d_args_list = [ + [ + KeywordArg("stride"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("ceil_mode"), + ], + ] + for max_pool2d_args in max_pool2d_args_list: + dequantize_maxpool2d_pattern = CallFunction( + aten.max_pool2d_with_indices.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + ) + dequantize_lowmem_maxpool2d_pattern = CallFunction( + prims._low_memory_max_pool_with_offsets.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + KeywordArg("offset_dtype"), + ) + dequantize_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_maxpool2d_pattern, + Arg(), + ) + dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_lowmem_maxpool2d_pattern, + Arg(), + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), + quantized.max_pool2d.default, + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant( + dequantize_lowmem_maxpool2d_get_item_pattern + ), + quantized.max_pool2d.default, + ) + + +def _is_input_output_same_scale_zp(check_node): + def fn(match): + # Ensure all the inputs and output has same scale and zero point + # Step 1: Check inputs/output zero point + # Get dequant nodes at input + dequant_nodes = filter_nodes( + match.nodes, quantized_decomposed.dequantize_per_tensor.default + ) + zero_points = [node.args[2] for node in dequant_nodes] + # Get quant nodes at output + quant_nodes = filter_nodes( + match.nodes, quantized_decomposed.quantize_per_tensor.default + ) + assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(quant_nodes[0].args[2]) + if not all(zero_point == zero_points[0] for zero_point in zero_points): + return False + + # Step 2: Check inputs/output scale + scales = [node.args[1] for node in dequant_nodes] + scales.append(quant_nodes[0].args[1]) + if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] + return False + + return True + + return fn + + +def _register_quantized_cat_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.cat.default), + ) + def qcat(match: Match, inputs, dim, **kwargs): + # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] + uint8_inputs = [input[0] for input in inputs] + counters["inductor"]["qcat_matcher_count"] += 1 + counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) + return L[computation_op](uint8_inputs, dim) + + return qcat + + +_raw_dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), +) + + +def _register_quantization_cat(): + dequantize_cat_pattern = CallFunction( + aten.cat.default, + ListOf(_raw_dequantize_per_tensor_activation_pattern), + KeywordArg("dim"), + ) + _register_quantized_cat_lowering( + generate_pattern_with_output_quant(dequantize_cat_pattern), + aten.cat, + ) + + +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + +def _is_valid_concat_linear_int8_woq_optimization_pattern(): + def fn(match): + if not config.cpp.enable_concat_linear: + return False + assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") + for key in ["x", "w1", "w2", "w3", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + w1 = match.kwargs["w1"].meta["val"] + w2 = match.kwargs["w2"].meta["val"] + w3 = match.kwargs["w3"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + if len(match.kwargs["scales"].meta["val"].size()) > 1: + return False + num_scales = match.kwargs["scales"].meta["val"].numel() + w1_cols = match.kwargs["w1"].meta["val"].size()[0] + w2_cols = match.kwargs["w2"].meta["val"].size()[0] + w3_cols = match.kwargs["w3"].meta["val"].size()[0] + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and w1.dtype == torch.int8 + and w2.dtype == torch.int8 + and w3.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + and x.device.type in ("cpu", "cuda") + and x.device == w1.device + and w1.device == w2.device + and w2.device == w3.device + and x.device == scales.device + and num_scales == w1_cols + w2_cols + w3_cols + ) + + return fn + + +def _is_valid_woq_optimization_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("x", "weight", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") for key in ["x", "weight", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + weight = match.kwargs["weight"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and weight.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + and x.device.type in ("cpu", "cuda") + and x.device == weight.device + and x.device == scales.device + ) + + return fn + + +def _register_concat_linear_int8_woq_lowering( + pattern, computation_woq, computation_reshape +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(), + pass_number=4, + ) + def woq_int8(match: Match, *args, **kwargs): + x = kwargs["x"] + w1 = kwargs["w1"] + w2 = kwargs["w2"] + w3 = kwargs["w3"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = ( + w1.meta["val"].size()[0] + + w2.meta["val"].size()[0] + + w3.meta["val"].size()[0] + ) + origin_x_size = tuple(x.meta["val"].size()) + x_shape = [-1, origin_x_size[-1]] + out_shape = list(origin_x_size[:-1] + (out_features,)) + mm_node_of_x = None + for candidate in iter(x.users.keys()): + if ( + candidate.target is aten.mm.default + and list(candidate._input_nodes)[1].target is aten.cat.default + ): + mm_node_of_x = candidate + break + assert mm_node_of_x is not None, "unable to find mm node" + _, cat_wgt_node = mm_node_of_x._input_nodes + scaling_node = next(iter(mm_node_of_x.users.keys())) + user_of_scaling_node = next(iter(scaling_node.users.keys())) + # Some other pass is making some changes that entails + # adding a node before it's used, but it can only be found when + # lint is run. stable_topological_sort() is being run before lint, + # so that error was not being being discovered. + # We call stable_topological_sort here as a workaround. + stable_topological_sort(match.graph) + with match.graph.inserting_before(user_of_scaling_node): + new_cat_node = match.graph.call_function( + aten.cat.default, + args=([w1, w2, w3], 0), + ) + x_reshape_node = match.graph.call_function( + computation_reshape, args=(x, x_shape) + ) + new_woq_node = match.graph.call_function( + computation_woq, + args=(x_reshape_node, new_cat_node, scales), + ) + new_woq_node.meta = copy.copy(x.meta) + output_reshape_node = match.graph.call_function( + computation_reshape, args=(new_woq_node, out_shape) + ) + scaling_node.replace_all_uses_with(output_reshape_node) + match.graph.erase_node(scaling_node) + match.graph.erase_node(mm_node_of_x) + match.graph.erase_node(cat_wgt_node) + match.graph.lint() + + return woq_int8 + + +def _register_woq_lowering(pattern, computation_woq, computation_reshape): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_woq_optimization_pattern(), + ) + def woq_int8(match: Match, *args, **kwargs): + x = kwargs["x"] + weight = kwargs["weight"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = weight.get_size()[0] + origin_x_size = x.get_size() + x_shape = [-1, origin_x_size[-1]] + out_shape = origin_x_size[:-1] + [ + out_features, + ] + func1 = L[computation_reshape](x, x_shape) + func2 = L[computation_woq](func1, weight, scales) + return L[computation_reshape](func2, out_shape) + + return woq_int8 + + +def _register_woq_mm_int8_pattern1(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, with x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + CallFunction(aten.reshape.default, KeywordArg("x"), Arg()), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern2(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, w/o x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern3(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to bmm + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.bmm.default, + CallFunction(aten.expand.default, KeywordArg("x"), Arg()), + CallFunction( + aten.expand.default, + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern4(): + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg("weight"), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_int8_woq_concat_linear_pattern(): + def _create_wgt_node(wgt_node_name: str): + return CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg(wgt_node_name), + Arg(), + ), + Arg(), + ) + + cat_wgt = CallFunction( + aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1 + ) + + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt), + KeywordArg("scales"), + ) + _register_concat_linear_int8_woq_lowering( + _woq_pattern, aten._weight_int8pack_mm.default, aten.reshape + ) + + +def _register_quantization_lowerings(): + _register_quantization_unary_lowering() + _register_quantization_binary_lowering() + _register_quantization_maxpool2d() + _register_quantization_cat() + _register_quantization_reshape() + + +def _register_woq_lowerings(): + _register_woq_mm_int8_pattern1() + _register_woq_mm_int8_pattern2() + _register_woq_mm_int8_pattern3() + _register_woq_mm_int8_pattern4() + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + dequant_node = ( + dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- reshape <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- dequant + ) + else: + dequant_node = ( + dequant_pattern_end_node # pattern: linear <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- dequant + ) + + if ( + dequant_node.target + in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_promotion_pattern(dtype), + pass_number=pass_number, + ) + def dequant_promotion(match: Match, *args, **kwargs): + # Dequant_promotion will transform + # graph 1: + # quant + # + - - - | - - - + + # | dequant | + # | / \ | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # into: + # graph 2: + # quant + # + - - / - \ - - + + # |dequant dequant| + # | | | | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # In graph 1, the dequant node is shared by node1 and node2, + # as a result, neither node1 nor node2 could form an int8 + # fusion pattern. + # After this transformation, the graph 2 could hit the int8 + # fusion pattern: dequant-node-quant, respectively for + # node1 and node2. + assert dtype in [torch.float32, torch.bfloat16] + + def clone_to_new_node(graph, source_node, user_node): + # Clone the source_node to a new node + # Replace user_node's input from source_node to new_node + assert source_node.op == "call_function", ( + "clone_to_new_node only support node.op call_function" + ) + with graph.inserting_before(user_node): + new_node = graph.call_function( + source_node.target, + args=source_node.args, + kwargs=source_node.kwargs, + ) + new_node.meta = copy.copy(source_node.meta) + user_node.replace_input_with(source_node, new_node) + return new_node + + # Find the start node and end node of a dequant pattern + # * End node should be the match.output_node() + # * Start node should be the node of dequantize_per_tensor + dequant_pattern_end_node = match.output_node() + assert dequant_pattern_end_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ] + + # For a dequant pattern, we should expect see the node list as: + # * OPT(aten.reshape.default) + # * OPT(prims.convert_element_type.default) (to_bf16) + # * dequantize_per_tensor + def _find_first_node_in_dequant_pattern(_node): + if _node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For a dequant pattern, we expect the start node is a dequantize_per_tensor node + return _node + else: + assert len(_node.args) >= 1, ( + "In in dequant pattern, each node should have more than 1 arg." + ) + return _find_first_node_in_dequant_pattern(_node.args[0]) + + dequant_pattern_start_node = _find_first_node_in_dequant_pattern( + dequant_pattern_end_node + ) + + assert dequant_pattern_start_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + # Clone the dequant pattern for each user node + graph = match.graph + user_node_list = list(dequant_pattern_end_node.users) + for user_node in user_node_list[1:]: + _source_node = dequant_pattern_end_node + _user_node = user_node + while _source_node != dequant_pattern_start_node.args[0]: + _user_node = clone_to_new_node(graph, _source_node, _user_node) + _source_node = _source_node.args[0] # type: ignore[assignment] + + counters["inductor"]["dequant_promotion_matcher_count"] += 1 + counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) + + +def _is_valid_dequant_conv_pattern(dtype, with_dtype_convert): + def _inner(match): + # Here we do some further check to ensure: + # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. + # 2. The dequant pattern has only 1 user of conv2d node. + # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") + or meta_value.dim() not in [3, 4] + ): + # Only support conv1d/2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + + if not with_dtype_convert: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass( + pattern, pass_number, dtype=torch.float32, with_dtype_convert=False +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv_pattern(dtype, with_dtype_convert), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if not with_dtype_convert: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if with_dtype_convert: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_node) # type: ignore[arg-type] + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) # type: ignore[arg-type] + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32, with_dtype_convert=False +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("autocast_act_dtype"), + with_dtype_convert, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns( + dtype=torch.float32, with_dtype_convert=False +): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + with_dtype_convert, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. + # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + with_dtype_convert, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_node( + linear_node, + input_index, + input_dim_exceeds_two, + input_contiguous, + with_dtype_convert, +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if not with_dtype_convert: + # pattern: linear -> reshape -> dequant + dequant_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> dequant + activation_to_bf16_node = act_reshape_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if not with_dtype_convert: + dequant_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + if not with_dtype_convert: + # pattern: linear -> dequant + dequant_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> dequant + activation_to_bf16_node = linear_node.args[input_index] + dequant_node = activation_to_bf16_node.args[0] + return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert +): + def _inner(match): + # Check dequant pattern has only 1 user. + ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, + input_index, + input_dim_exceeds_two, + input_contiguous, + with_dtype_convert, + ) + + assert dequant_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_dtype_convert=False, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, + input_index, + input_dim_exceeds_two, + input_contiguous, + with_dtype_convert, + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs.get("b") + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if with_dtype_convert: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + input_dim_exceeds_two=False, + is_tensor_overload=False, + with_dtype_convert=False, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + with_dtype_convert, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + with_dtype_convert, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, + is_tensor_overload=False, + with_dtype_convert=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + with_dtype_convert, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, + is_tensor_overload=False, + with_dtype_convert=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + is_tensor_overload, + with_dtype_convert, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + input_dim_exceeds_two, + is_tensor_overload, + with_dtype_convert, + ) + + +def _generate_linear_dynamic_fp16_pattern( + _dequant_weight_pattern, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + dtype = torch.float32 + t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype) + + if input_dim_exceeds_two and not input_contiguous: + # pattern is + # x -> expand -> bmm (-> add) (-> relu) + # w -> dequant -> permute -> expand / + pattern_no_bias = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + KeywordArg("x"), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + pattern_with_bias = CallFunction( + aten.add.Tensor, + pattern_no_bias, + KeywordArg("b"), + ) + if relu_fused: + pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias) + pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias) + return pattern_with_bias, pattern_no_bias + + x_pattern_with_reshape = _may_generate_pattern_with_reshape( + KeywordArg("x"), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + dequant_linear_no_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=is_tensor_overload + ), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype, with_dtype_convert in itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ): + if dtype == torch.float32 and with_dtype_convert: + continue + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns( + dtype, with_dtype_convert + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + with_dtype_convert=with_dtype_convert, + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for ( + dtype, + input_dim_exceeds_two, + is_tensor_overload, + with_dtype_convert, + ) in linear_weight_prepack_cases: + if dtype == torch.float32 and with_dtype_convert: + continue + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, + input_dim_exceeds_two, + is_tensor_overload=is_tensor_overload, + with_dtype_convert=with_dtype_convert, + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + with_dtype_convert=with_dtype_convert, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for ( + dtype, + with_bias, + is_tensor_overload, + with_dtype_convert, + ) in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] + ): + if dtype == torch.float32 and with_dtype_convert: + continue + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + is_tensor_overload=is_tensor_overload, + with_dtype_convert=with_dtype_convert, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_dtype_convert=with_dtype_convert, + ) + + +def _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + def _extra_check_fn(match: Match): + return match.kwargs["dtype_fp16"] == torch.float16 + + @register_freezing_graph_pattern( + pattern, + extra_check=_extra_check_fn, + pass_number=pass_number, + ) + def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + fp32 activation + | + mm/addmm <- t <- to_fp32 <- to_fp16 <- weight + | + (reshape) <- (relu) + + OR + + fp32 activation + | + expand + | + bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight + | + (add) <- (relu) + + Insert weight prepack node and change the pattern to: + fp32 activation + | + onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight + (or onednn.linear_relu_dynamic_fp16) + """ + # find params + x = kwargs["x"] + w = kwargs["w"] + bias = kwargs.get("b") + + # find linear node + nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default] + linear_nodes = [] + for node in nodes_to_find: + linear_nodes.extend(filter_nodes(match.nodes, node)) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + assert isinstance(linear_node, torch.fx.node.Node) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + # find relu node + relu_node = None + if relu_fused: + relu_node = match.output_node() + assert isinstance(relu_node, torch.fx.node.Node) + + # find reshape node, expand node and add node + ( + act_reshape_node, + output_reshape_node, + expand_x_node, + expand_w_node, + add_bias_node, + ) = (None, None, None, None, None) + t_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + t_node = linear_node.args[weight_index] + output_reshape_node = next(iter(linear_node.users)) + assert output_reshape_node.target is aten.reshape.default + else: + expand_x_node = linear_node.args[input_index] + expand_w_node = linear_node.args[weight_index] + assert isinstance(expand_w_node, torch.fx.node.Node) + t_node = expand_w_node.args[0] + if bias: + add_bias_node = next(iter(linear_node.users)) + assert add_bias_node.target is aten.add.Tensor + else: + t_node = linear_node.args[weight_index] + assert isinstance(t_node, torch.fx.node.Node) + + w_to_fp32_node = t_node.args[0] + assert ( + isinstance(w_to_fp32_node, torch.fx.node.Node) + and w_to_fp32_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + w_to_fp16_node = w_to_fp32_node.args[0] + assert ( + isinstance(w_to_fp16_node, torch.fx.node.Node) + and w_to_fp16_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + w, + x_shape, + ) + packed_weight_op = torch.ops.onednn.linear_prepack_fp16 + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + # create new linear node and insert on graph + new_args: tuple[Any, ...] = ( + x, + prepack_weight_node, + bias, + ) + linear_op = ( + torch.ops.onednn.linear_relu_dynamic_fp16.default + if relu_fused + else torch.ops.onednn.linear_dynamic_fp16.default + ) + new_linear_node = graph.call_function(linear_op, args=new_args) + out_node = match.output_node() + out_node.replace_all_uses_with(new_linear_node) + + # Erase the original nodes in the reverse order + new_linear_node.meta.update(out_node.meta) + if relu_node is not None: + graph.erase_node(relu_node) + if output_reshape_node is not None: + graph.erase_node(output_reshape_node) + if add_bias_node is not None: + graph.erase_node(add_bias_node) + graph.erase_node(linear_node) + if act_reshape_node is not None: + assert isinstance(act_reshape_node, torch.fx.node.Node) + graph.erase_node(act_reshape_node) + if expand_x_node is not None: + assert isinstance(expand_x_node, torch.fx.node.Node) + graph.erase_node(expand_x_node) + if expand_w_node is not None: + assert isinstance(expand_w_node, torch.fx.node.Node) + graph.erase_node(expand_w_node) + graph.erase_node(t_node) + graph.erase_node(w_to_fp32_node) + graph.erase_node(w_to_fp16_node) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _register_linear_dynamic_fp16_weight_prepack(): + to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + weight_pattern = CallFunction( + to_dtype_op, + CallFunction( + to_dtype_op, + KeywordArg("w"), + KeywordArg("dtype_fp16"), + ), + KeywordArg("dtype_fp32"), + ) + cases = itertools.product( + [False, True], # input_dim_exceeds_two + [True, False], # input_contiguous + [False, True], # relu fused + ) + for input_dim_exceeds_two, input_contiguous, relu_fused in cases: + patterns = _generate_linear_dynamic_fp16_pattern( + weight_pattern, + input_dim_exceeds_two, + input_contiguous, + relu_fused, + ) + for pattern in patterns: + _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number=0 if relu_fused else 1, + input_dim_exceeds_two=input_dim_exceeds_two, + input_contiguous=input_contiguous, + relu_fused=relu_fused, + ) + + +def _register_smooth_quant_int_mm_pattern(): + """ + The pattern is: + (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape + or + (with bias) pattern_no_bias -> add (-> reshape -> reshape) + """ + + # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist + # When torch.compile'ing with dynamic=False, they don't exist + def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True): + return CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mul.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten._int_mm.default, + CallFunction( + aten.reshape.default, + KeywordArg("a"), + KeywordArg("in_shape"), + ) + if reshape_a + else KeywordArg("a"), + KeywordArg("b"), + ), + KeywordArg("dtype"), + ), + ( + CallFunction( + aten.expand.default, + KeywordArg("x_scale"), + Arg(), + ) + if expand_a_scale + else KeywordArg("x_scale") + ), + ), + KeywordArg("w_scale"), + ) + + def _with_outer_reshape(pattern): + return CallFunction( + aten.reshape.default, pattern, KeywordArg("out_shape_no_bias") + ) + + # for torch.compile(dynamic=False) + pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False)) + pattern_with_bias_1 = CallFunction( + aten.add.Tensor, + pattern_no_bias_1, + KeywordArg("bias"), + ) + # for torch.compile(dynamic=True) + pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True)) + pattern_with_bias_2 = CallFunction( + aten.reshape.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.add.Tensor, + pattern_no_bias_2, + KeywordArg("bias"), + ), + Arg(), + ), + KeywordArg("out_shape_with_bias"), + ) + + # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, + # when both activation and weights are symmetrically quantized. + # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. + # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. + # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op. + pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=False, reshape_a=False + ) + pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=True, reshape_a=False + ) + + def _validate_pattern(match: Match): + if len(match.nodes) not in [4, 5, 6, 7, 10]: + return False + # Make sure weight is a constant + aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0] + if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node): + return False + if aten_int_mm_node.args[1].op != "get_attr": + return False + + if len(match.nodes) == 10: + # Check the two tailing reshape nodes can be fused + if match.nodes[9].args[1] != match.nodes[6].args[1]: + return False + if len(match.nodes) == 10 or ( + len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor + ): + bias_idx = 7 if len(match.nodes) == 10 else 6 + # Check bias shape + bias_node = match.nodes[bias_idx].args[1] + if not isinstance(bias_node, torch.fx.node.Node): + return False + if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr] + return False + return True + + pattern_to_pass_number = { + pattern_no_bias_2: 0, + pattern_with_bias_2: 0, + pattern_no_bias_1: 1, + pattern_with_bias_1: 1, + pattern1_with_no_outer_or_act_reshape: 2, + pattern2_with_no_outer_or_act_reshape: 2, + } + for pattern, pass_number in pattern_to_pass_number.items(): + + @register_freezing_graph_pattern( + pattern, + extra_check=_validate_pattern, + pass_number=pass_number, + ) + def _int_mm_weight_prepack(match: Match, *args, **kwargs): + bias = kwargs.get("bias") + x = kwargs["a"] + weight = kwargs["b"] + dtype = kwargs["dtype"] + x_scale = kwargs["x_scale"] + w_scale = kwargs["w_scale"] + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + transpose_node = match.graph.call_function( + aten.permute.default, args=(weight, [1, 0]) + ) + contig_node = match.graph.call_function( + aten.contiguous.default, args=(transpose_node,) + ) + packed_weight_inputs = ( + contig_node, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = match.graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + dummy_zp = None + w_scale = match.graph.call_function( + prims.convert_element_type.default, args=(w_scale, torch.float32) + ) + + x_scale_shape = x_scale.meta.get("tensor_meta").shape + x_scale_is_scalar = False + if not has_free_symbols(x_scale_shape): + prod = 1 + for d in x_scale_shape: + prod *= d + x_scale_is_scalar = prod == 1 + + new_args: tuple[Any, ...] + if x_scale_is_scalar: + # in this case, we can call onednn.qlinear directly + new_args = ( + x, + x_scale, + dummy_zp, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + else: + # onednn.qlinear does not support per-channel quantization of x + # so in this case, we have to apply x scale and add bias ourselves after qlinear + in_shape = kwargs.get("in_shape") + if in_shape is None: + x_reshaped = x + else: + x_reshaped = match.graph.call_function( + aten.reshape.default, args=(x, in_shape) + ) + new_args = ( + x_reshaped, + 1.0, # x_scale + 0, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + None, # bias + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise, args=new_args + ) + # apply x scale + new_out_node = match.graph.call_function( + aten.mul.Tensor, args=(new_linear_node, x_scale) + ) + + # Add bias and reshape + has_outer_reshape = ( + kwargs.get("out_shape_with_bias") is not None + or kwargs.get("out_shape_no_bias") is not None + ) + + if has_outer_reshape: + out_shape = kwargs.get( + "out_shape_with_bias", kwargs["out_shape_no_bias"] + ) + if bias is not None: + new_out_node = match.graph.call_function( + aten.add.Tensor, args=(new_out_node, bias) + ) + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + else: + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + out_node.replace_all_uses_with(new_out_node) + new_out_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +class PostOpAttr: + def __init__( + self, + binary_op_name: str = "none", + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ) -> None: + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + +def _register_qconv_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + if post_op_attr.unary_op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + post_op_attr.scalars_attr = [min_value, max_value] + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + accum = ( + kwargs["accum"] + if output_dtype in [torch.uint8, torch.int8] + else kwargs["accum_after_dequant"] + ) + accum_scale = ( + kwargs["accum_scale"] + if output_dtype in [torch.uint8, torch.int8] + else 1.0 + ) + accum_zp = ( + kwargs["accum_zp"] + if output_dtype in [torch.uint8, torch.int8] + else 0 + ) + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_conv_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qconv2d_binary_matcher_count" + if has_binary_post_op + else "qconv_unary_matcher_count" + ) + nodes_key = ( + "qconv2d_binary_matcher_nodes" + if has_binary_post_op + else "qconv_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + return qconv + + +def _register_qconv_unary_fusion(): + from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + conv_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + get_qconv_pt2e_pattern(users=1), + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_qconv_pt2e_pattern(users=1), aten.relu.default + ), + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(users=1), + 1, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + get_qconv_pt2e_pattern(users=1), aten.relu.default + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(users=1), + 1, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + +def _register_qconv_binary_fusion(): + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + swap_binary_inputs_list = [False, True] + binary_replace_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(users=1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + ), + PostOpAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(users=1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ), + ), + } + ) + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(users=1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ) + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + else: + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(users=1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + +def _register_qlinear_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qlinear_post_op_fusion(match: Match, *args, **kwargs): + """ + Match the pattern: + qlinear - post op + """ + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs.get("b") + + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype in [torch.uint8, torch.int8]) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + other = kwargs["other"] if "other" in kwargs else kwargs["accum"] + x2_scale = 1.0 + x2_zp = 0 + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + other, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_linear_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qlinear_binary_matcher_count" + if has_binary_post_op + else "qlinear_unary_matcher_count" + ) + nodes_key = ( + "qlinear_binary_matcher_nodes" + if has_binary_post_op + else "qlinear_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + +def _register_qlinear_unary_fusion(): + from .mkldnn_fusion import ( + _gelu_fusion_1 as _gelu_fusion_erf, + _gelu_fusion_2 as _gelu_fusion_tanh, + ) + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + +def _register_qlinear_binary_fusion(): + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. + """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + PostOpAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + +@functools.cache +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() + _register_linear_dynamic_fp16_weight_prepack() + + # Step 4: weight prepack for SmoothQuant from Torchao + _register_smooth_quant_int_mm_pattern() + + # Step 5: QLinear post op Fusion + if not torch.ops.mkldnn._is_mkldnn_acl_supported(): + # skip fusion on ARM + _register_qconv_unary_fusion() + _register_qconv_binary_fusion() + _register_qlinear_unary_fusion() + _register_qlinear_binary_fusion() + + +def _is_valid_concat_linear_woq_int4_fusion(computation_nodes): + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + in_feature_size = wgt.meta.get("val").size(1) # type: ignore[union-attr] + group_size = computation_nodes[0].args[2] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act # share same activation + and ( + node.args[1].meta.get("val").size(1) == in_feature_size + ) # same in feature size + and (node.args[1] != wgt or gemm_idx == 0) + and node.args[1].op == "get_attr" # wgt are all constants + and node.args[2] == group_size # same group size + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + +def concat_linear_woq_int4(gm: torch.fx.GraphModule): + """ + Concat Linear optimization pass for WOQ int4 + This pass fuses the original pattern: + def ... + return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...) + into a single operation: + def ... + concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp) + return split(concat_res, split_size_list) + """ + + def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype): + # Concat the wgts and scale_zps, and repack the wgt + unpacked_wgts = [] + for packed_wgt in packed_wgts: + # Get the unpacked weight list + # Same as https://github.com/pytorch/pytorch/pull/156174 + K = packed_wgt.size(1) * 2 + N = packed_wgt.size(0) + x = torch.eye(K).to(dtype=act_dtype) + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .to(dtype=act_dtype) + .expand(K // group_size, N, 2) + .contiguous() + ) + unpacked_wgts.append( + torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + packed_wgt, + group_size, + qscales_and_zeros, + ) + .t() + .contiguous() + .to(torch.int32) # N, K + ) + concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0) + repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + concat_unpacked_wgt, 1 + ) + concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous() + return repack_w, concat_scale_zp + + graph = gm.graph + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_concat_linear_woq_int4_fusion(users): + with graph.inserting_before(node): + assert all(user.args[1].op == "get_attr" for user in users) + computation_node_0 = users[0] + packed_wgts = [getattr(gm, user.args[1].target) for user in users] + group_size = computation_node_0.args[2] + scale_zps = [getattr(gm, user.args[3].target) for user in users] + out_feature_size_list = [ + packed_wgt.size(0) for packed_wgt in packed_wgts + ] + repack_w, concat_scale_zp = concat_wgt( + packed_wgts, scale_zps, group_size, act.meta.get("val").dtype + ) + repack_w_node_name = computation_node_0.args[1].target + "_concat" + concat_scale_zp_node_name = ( + computation_node_0.args[3].target + "_concat" + ) + gm.register_buffer(repack_w_node_name, repack_w) + setattr(gm, repack_w_node_name, repack_w) + gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp) + setattr(gm, concat_scale_zp_node_name, concat_scale_zp) + + repack_w_node = graph.create_node( + "get_attr", repack_w_node_name, (), {} + ) + with graph.inserting_after(repack_w_node): + concat_scale_zp_node = graph.create_node( + "get_attr", concat_scale_zp_node_name, (), {} + ) + + with graph.inserting_after(concat_scale_zp_node): + concat_int4_gemm_node = graph.create_node( + "call_function", + computation_op, + ( + act, + repack_w_node, + group_size, + concat_scale_zp_node, + ), + ) + with graph.inserting_after(concat_int4_gemm_node): + split_node = graph.create_node( + "call_function", + torch.ops.aten.split_with_sizes.default, + ( + concat_int4_gemm_node, + out_feature_size_list, + 1, # split dim + ), + ) + with graph.inserting_after(split_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + split_node, + gemm_idx, + ), + ) + with graph.inserting_after(get_item): + clone_node = graph.create_node( + "call_function", + torch.ops.aten.clone.default, + (get_item,), + {"memory_format": torch.contiguous_format}, + ) + user.replace_all_uses_with(clone_node) + graph.erase_node(user) + + +def quant_lift_up(graph_module: torch.fx.GraphModule): + """ + Lift up the quant node before view like nodes. It can benefit performance + of Attention like block. For example, we have the pattern as: + + DQ + DQ LINEAR + LINEAR VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + Q Q + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + We want to lift up the quant nodes from matmul before view like nodes + as the output of Linear node. + + DQ + DQ LINEAR + LINEAR Q + Q VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + It produces a DQ->LINEAR->Q pattern which can be fused by backend. + """ + + def is_view_op(node): + return node.op == "call_function" and node.target in _VIEW_OPS + + for node in graph_module.graph.nodes: + # Leslie: Here we verify that the quant node has exactly + # one input FX node, with constant scalar value for scale and zero point. + # For the case input of quant node has more than one input FX nodes, + # extend the implementation to lift up all the connected nodes + # before the view nodes to keep the topological order. + if ( + node.op == "call_function" + and node.target in _PER_TENSOR_QUANTIZE_OPS + and len(node.all_input_nodes) == 1 + and is_view_op(node.all_input_nodes[0]) + ): + quant_node = node + input_node_of_quant = quant_node.args[0] + + # Check the nodes along lift up path has only 1 user node + # Propagate view like node to find where to insert the new quant node + could_lift_up = True + current_node = quant_node + input_node = current_node.args[0] + while is_view_op(input_node): + if len(input_node.users) != 1: + could_lift_up = False + break + current_node = input_node + input_node = current_node.args[0] + + # Further check the input node of the first view node has only 1 user node + if could_lift_up and len(input_node.users) == 1: + counters["inductor"]["quant_lift_up_count"] += 1 + # Replace dequant's input from quant to quant's input + quant_node.replace_all_uses_with(input_node_of_quant) + # Insert the new quant node + with graph_module.graph.inserting_before(current_node): + new_quant_node = graph_module.graph.node_copy(quant_node) + input_node.replace_all_uses_with(new_quant_node) + + # Update inputs of new_quant_node + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == input_node_of_quant: + return input_node + else: + return n + + new_args = map_arg(new_quant_node.args, maybe_replace_node) + new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node) + new_quant_node.args = new_args # type: ignore[assignment] + new_quant_node.kwargs = new_kwargs # type: ignore[assignment] + graph_module.graph.erase_node(quant_node) + + graph_module.graph.lint() + graph_module.recompile() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..e42e8a1139770d488929f772b0441fe4f616d449 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py @@ -0,0 +1,795 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from collections import defaultdict +from collections.abc import Callable, Sequence +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.fx.node +from torch._C._dynamo.guards import compute_overlapping_tensors +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger +from torch._guards import detect_fake_mode +from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_functional, +) +from torch._inductor import config, inductor_prims +from torch._inductor.fx_utils import get_node_storage, is_node_realized +from torch._inductor.lowering import ( + inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, +) +from torch._inductor.virtualized import V +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + GuardOnDataDependentSymNode, +) +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.reinplace import _is_view_op +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass(frozen=True) +class InplaceableOp: + inplace_op: Callable[..., Any] + mutated_arg: int + extra_check: Callable[[torch.fx.Node], bool] = lambda node: True + + +_SCATTER_OP_TO_VIEW = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} +_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()} + + +def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs): + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (args, kwargs), + ) + with V.fake_mode: + fake_result = fn(*fake_args, **fake_kwargs) + + node = graph.call_function(fn, args, kwargs) + + node.meta["val"] = fake_result + + return node + + +@dataclass +class ViewOp: + target: torch._ops.OpOverload + args: tuple[Any, ...] + kwargs: dict[str, Any] + + +def _inplace_generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp] +) -> torch.Tensor: + tmp = inp + for view in view_ops: + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (view.args, view.kwargs), + ) + # slice and select can allocate new unbacked symints, but those won't be reflected + # in the output of this function, hence shall be ignored. + fake_mode = detect_fake_mode(fake_args) + with ( + fake_mode.shape_env.ignore_fresh_unbacked_symbols() + if fake_mode and fake_mode.shape_env + else nullcontext() + ): + tmp = view.target(tmp, *fake_args, **fake_kwargs) + try: + tmp.copy_(src) + except RuntimeError as e: + raise RuntimeError( + f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}" + ) from e + return inp + + +def _generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp] +) -> torch.Tensor: + out = inp.clone() + return _inplace_generalized_scatter(out, src, view_ops) + + +def _decompose_scatter_functional_helper( + graph: torch.fx.Graph, + inp: torch.Tensor, + src: torch.Tensor, + view_ops: list[ViewOp], +) -> torch.fx.Node: + view_op, view_ops_tail = view_ops[0], view_ops[1:] + + if view_ops_tail: + view = graph_call_function( + graph, view_op.target, inp, *view_op.args, **view_op.kwargs + ) + src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment] + + return graph_call_function( + graph, + _VIEW_OP_TO_SCATTER[view_op.target], + inp, + src, + *view_op.args, + **view_op.kwargs, + ) + + +def _decompose_scatter_functional( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter to a sequence of view_scatter operations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + view = aten.slice(inp, 0, 0, 10) + view_updated = aten.slice_scatter(view, src, 1, 10, -10) + inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10) + """ + assert node.target is _generalized_scatter + return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type] + + +def _decompose_scatter_mutating( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter using mutations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + inp_updated = aten.clone(inp) + slice1 = aten.slice(inp_updated, 0, 0, 10) + slice2 = aten.slice(slice1, 1, 10, -10) + slice2.copy_(src) + + """ + assert node.target in (_generalized_scatter, _inplace_generalized_scatter) + inp, src, view_ops = node.args + assert not node.kwargs + + if node.target is _generalized_scatter: + inp = graph_call_function(graph, aten.clone, inp) + + tmp = inp + for view in view_ops: # type: ignore[union-attr] + tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr] + # we need to set unbacked bindings that could have been created in the view ops. + if (V.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings( + V.fake_mode.shape_env, tmp.meta["val"] + ) + ): + tmp.meta["unbacked_bindings"] = symbol_to_path + + graph_call_function(graph, aten.copy_.default, tmp, src) + return inp # type: ignore[return-value] + + +# View ops whose view_scatter op is lowered into mutations anyway, +# so is never a pessimisation to decompose. +_ALWAYS_MUTATING_SCATTER_OPS = OrderedSet( + [ + aten.as_strided.default, + aten.diagonal.default, + ] +) + + +def scatter_always_uses_mutation(node: torch.fx.Node) -> bool: + _, _, view_ops = node.args + view_ops = cast(Sequence[torch.fx.node.Argument], view_ops) + return any( + target in _ALWAYS_MUTATING_SCATTER_OPS + for view in view_ops + if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload) + ) + + +def should_reinplace_scatter(node: torch.fx.Node) -> bool: + """Choose between mutating and functional scatter decompositions + + Reinplacing view scatter ops can be pessimising as it blocks fusion with the + input or output tensor computations. However, it is still profitable if the + input and output would have been realized anyway. + + """ + inp, _src, _view_ops = node.args + + # Mutating scatter ops unconditionally realize input and output + if scatter_always_uses_mutation(node): + return True + + if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type] + return True + + # If the output is copied back into the input, this forces both to be + # realized as the output is a user of the input + if inp.op in ("placeholder", "get_attr") and any( # type: ignore[union-attr] + user.target is aten.copy_.default and user.args[0] is inp for user in node.users + ): + return True + + # Otherwise, assume fusions will make functional variants profitable + return False + + +def decompose_generalized_scatter(graph: torch.fx.Graph) -> None: + """Replace _generalized_scatter with normal aten ops""" + for node in itertools.chain( + graph.find_nodes(op="call_function", target=_generalized_scatter), + graph.find_nodes(op="call_function", target=_inplace_generalized_scatter), + ): + use_mutation = ( + node.target is _inplace_generalized_scatter + or scatter_always_uses_mutation(node) + ) + + with graph.inserting_before(node): + if use_mutation: + new_node = _decompose_scatter_mutating(graph, node) + else: + new_node = _decompose_scatter_functional(graph, node) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None: + """ + This canonicalizes view scatter ops into a generalized form, defined as: + def scatter(inp, src, views): + tmp = inp.clone() + for view in views: + tmp = view(tmp) + tmp.copy_(src) + + We also fuse consecutive view scatter ops of the form + a = scatter(view2(self), src, [view1]) + b = scatter(self, a, [view2]) + which can be rewritten as + b = scatter(self, src, [view2, view1]) + a = view2(b) + + This is both more efficient as we only do a single scatter, and also + easier to reinplace since there is only one use of `self` + """ + + node_to_view_base: dict[torch.fx.Node, torch.fx.Node] = {} + node_to_view_op: dict[torch.fx.Node, list[ViewOp]] = defaultdict(list) + + def handle_views(node: torch.fx.Node): + inp = node.args[0] + node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type, assignment] + node_to_view_op[node] = [ + *node_to_view_op[inp], # type: ignore[index] + ViewOp( + node.target, # type: ignore[arg-type] + args=node.args[1:], + kwargs=node.kwargs, + ), + ] + + def handle_view_scatter(node: torch.fx.Node): + assert len(node.args) >= 2 + inp, src = node.args[:2] + + assert isinstance(node.target, torch._ops.OpOverload) + scatter_view_op = ViewOp( + _SCATTER_OP_TO_VIEW[node.target], + args=node.args[2:], + kwargs=node.kwargs, + ) + + def can_fuse(): + if src.target is not _generalized_scatter: # type: ignore[union-attr] + return False + src_inp, _src_src, _src_scatter_view_op = src.args # type: ignore[union-attr] + + inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type] + return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index] + *node_to_view_op[inp], # type: ignore[index] + scatter_view_op, + ] + + if not can_fuse(): + with graph.inserting_before(node): + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src, + [scatter_view_op], + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + return + + _src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + with graph.inserting_before(src): # type: ignore[arg-type] + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src_src, + [scatter_view_op, *src_scatter_view_op], # type: ignore[misc] + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if src.users: # type: ignore[union-attr] + new_src = graph_call_function( + graph, + _SCATTER_OP_TO_VIEW[node.target], + new_node, + *node.args[2:], + **node.kwargs, + ) + + handle_views(new_src) + src.replace_all_uses_with(new_src) # type: ignore[union-attr] + + graph.erase_node(src) # type: ignore[arg-type] + + for node in graph.nodes: + if _is_view_op(node.target): + handle_views(node) + elif node.target in _SCATTER_OP_TO_VIEW: + handle_view_scatter(node) + + +inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = { + aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), + _generalized_scatter: InplaceableOp( + _inplace_generalized_scatter, + 0, + extra_check=should_reinplace_scatter, + ), +} + +try: + c10d_functional = torch.ops._c10d_functional + inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = { + c10d_functional.all_reduce.default: InplaceableOp( + c10d_functional.all_reduce_.default, 0 + ), + c10d_functional.all_reduce_coalesced.default: InplaceableOp( + c10d_functional.all_reduce_coalesced_.default, 0 + ), + } + inplaceable_ops.update(inplaceable_collective_ops) +except AttributeError: + # _c10d_functional ops are only available when torch + # is built with USE_DISTRIBUTED=1. + pass + +inplaceable_foreach_ops: dict[torch._ops.OpOverload, InplaceableOp] = {} +for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items(): + inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0) + + +inplaceable_triton_ops = OrderedSet([triton_kernel_wrapper_functional]) + + +# Operators that don't depend on the tensor data +META_ONLY_OPS = OrderedSet( + [ + aten.sym_size.int, + aten.sym_stride.int, + aten.sym_numel.default, + aten.sym_storage_offset.default, + ] +) + + +def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: + """ + Reinplaces in-placeable operations. + If there are no uses of a view of the mutated arg after the current node, + it is possible to inplace the op. + This above algorithm could be justified by observing side effects. While + we traverse the graph in forwards direction, only latter nodes could view + side effects of the current node. If the current node is not used later as + well as no view of this node is used later in the graph, then it is safe to + inplace as there would be no way to observe the side effects. + This condition is slightly different for graph inputs where they can only + be inplaced if the above condition is true and there's a copy_ in the + epilogue that signals that the caller wants to observe the mutation. + + Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from + input args, so instead of checking mutation on placeholder, AOTInductor + checks mutation on get_attr. This is subject to change in future. + """ + + copy_args_to_copy_nodes = {} + # maps argument to the first copy_ node that mutates it. + copy_nodes = {} + mutated_inputs = OrderedSet[Any]() + storage_to_nodes = defaultdict(list) + node_order: dict[Any, int] = {} + for i, node in enumerate(reversed(graph.nodes)): + node_order[node] = len(graph.nodes) - i - 1 + storage_to_nodes[get_node_storage(node)].append(node) + if node.target is aten.copy_.default and node.args[0].op in ( + "placeholder", + "get_attr", + ): + dst = node.args[0] + src = node.args[1] + # If the target is a getitem and it indexes a possible clone, + # then skip over it + if src.target is operator.getitem and ( + ( + src.args[0].target == triton_kernel_wrapper_functional + and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + ) + or (src.args[0].target in inplaceable_foreach_ops) + or (src.args[0].target is torch.ops.higher_order.auto_functionalized) + ): + src = src.args[0] + + copy_args_to_copy_nodes[(dst, src)] = node + copy_nodes[dst] = node + + mutated_inputs.add(node.args[0]) + + def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg): + node_loc = node_order[node] + copy_node_loc = node_order[copy_node] if copy_node is not None else None + + def is_meta_only_user(node): + if _is_view_op(node.target): + return all(is_meta_only_user(u) for u in node.users) + return node.target in META_ONLY_OPS + + for view in shared_view_nodes: + for user in view.users: + user_loc = node_order[user] + # Skip all users before node + if user_loc <= node_loc: + continue + # Ignore uses after the copy_ epilogue node, where the input + # has already been mutated anyway + if copy_node_loc is not None and copy_node_loc <= user_loc: + continue + # Reinplacing does not change shape metadata + if is_meta_only_user(user): + continue + # If our graph looks like: + # foo(mutated_arg) + # mutated_arg.copy_(other) + # then it's safe for us to reinplace foo because mutated_arg + # will get overwritten anyways. + if ( + user.target is torch.ops.aten.copy_.default + and mutated_arg is user.args[0] + ): + continue + return True + return False + + def can_inplace(node, mutated_arg): + # ls should be a list of tensors that all shares the same storage. + def _overlap(ls) -> bool: + try: + return len(compute_overlapping_tensors(ls)) != 0 + except GuardOnDataDependentSymNode: + # If we fail with data dependent error we assume they all overlap. + return True + + if isinstance(mutated_arg, (list, tuple)): + # TODO Using _overlap here causes a several issues. + unique_storages = OrderedSet(get_node_storage(arg) for arg in mutated_arg) + if len(unique_storages) != len(mutated_arg): + # At least two Tensors in mutated_arg alias each other, so we can't reinplace it. + # We can probably do better (that is, reinplace one of them and clone the other) + # but that requires more work and mutable List[Tensor] are not that common. + return False + return all(can_inplace(node, arg) for arg in mutated_arg) + + if get_node_storage(mutated_arg) is None: + return False + + shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + + # Only keep tensor that might overlap with mutated_arg. + shared_view_nodes = [ + v + for v in shared_view_nodes + if _overlap([mutated_arg.meta["val"], v.meta["val"]]) + ] + + if mutated_arg.op in ("placeholder", "get_attr"): + # Get the first copy_ node that mutates the mutated_arg. + copy_node = copy_nodes.get(mutated_arg) + if copy_node is None: + # There is no copy_ back to the candidate mutated_arg (which is a graph input). + # Therefore the semantics of the program are that it does not mutate + # mutated_arg, so we cannot re-inplace it. + return False + if any_use_of_views_after_node( + node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg + ): + return False + + return True + elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes): + # This should never happen in auto_functionalize_v2 non-inference mode, + # since all mutated_arg are bases. + + # If mutated arg is view of any of the inputs of the graph, + # do not allow for inplacing. + # This would require more sophisticated algorithm to handle + return False + else: + return not any_use_of_views_after_node( + node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg + ) + + def log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_nodes, + trigger, + ): + # Total size of possibly_missed_reinplacing_opportunities for tensors with static shapes. + missed_bytes = 0 + + def bytes(node): + t = node.meta.get("val", None) + if ( + t is not None + and isinstance(t.element_size(), int) + and isinstance(t.numel(), int) + ): + return t.element_size() * t.numel() + else: + return 0 + + for node in missed_nodes: + if isinstance(node, (list, tuple)): + for n in node: + missed_bytes += bytes(n) + else: + missed_bytes += bytes(node) + + log.info( + "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " + "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " + "memory usage and performance. Total size of missed opportunities with static shapes is" + " : %s bytes.", + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_bytes, + ) + + ReinplaceCounters.add_missed_opportunities(trigger, len(missed_args)) + ReinplaceCounters.add_missed_bytes(trigger, missed_bytes) + + replace_dict: dict[torch.fx.Node, torch.fx.Node] = {} + + def reinplace_and_refine_tensors_to_clone( + old_tensors_to_clone, kwargs, node_name, trigger + ): + tensors_to_clone: list[str] = [] + storage_of_reinplaced_args = OrderedSet[int | None]() + + # Those used to count possibly_missed_reinplacing_opportunities + missed_nodes = [] + missed_args = [] + + # TODO this logic can be made more precise using _overlap + def tensor_with_same_storage_already_reinplaced(arg): + if isinstance(arg, (list, tuple)): + return any( + get_node_storage(a) in storage_of_reinplaced_args for a in arg + ) + return get_node_storage(mutated_arg) in storage_of_reinplaced_args + + for arg in old_tensors_to_clone: + assert arg in kwargs + + mutated_arg = kwargs[arg] + + # Let's say we have: + # - op(x, y) that mutates both x and y + # - new_x, new_y = functional_op(x, y) is the functional variant + # If we are presented with functional_op(x, x), we must not reinplace + # this into op(x, x), because then it would be writing to the same Tensor. + # Instead, it's OK to reinplace one of them and to clone the other: + # >>> y = x.clone() + # >>> op(x, y) + # This also applies if we have views: functional_op(x, x[0]) + # should not reinplace into op(x, x[0]). + should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced( + mutated_arg + ) + if should_attempt_reinplace and can_inplace(node, mutated_arg): + # In general, we probably do not need those optimizations. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + if trigger != ReInplaceTrigger.AUTO_FUNC_V2: + for user in node.users: + # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to + # output atindex size(out)+i. + # This used to compare string with integers before for auto_functionalize_v2. Not sure + # if it was needed for inplaceable_triton_ops? + if user.target is operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg + + if isinstance(mutated_arg, (list, tuple)): + for a in mutated_arg: + storage_of_reinplaced_args.add(get_node_storage(a)) + else: + storage_of_reinplaced_args.add(get_node_storage(mutated_arg)) + else: + if should_attempt_reinplace: + missed_args.append(arg) + missed_nodes.append(mutated_arg) + + tensors_to_clone.append(arg) + + log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_nodes, + trigger, + ) + return tensors_to_clone + + for node in graph.nodes: + if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: + mutated_arg = node.args[inplaceable_op.mutated_arg] + if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node): + # TODO(yifu): this doesn't properly remove copy epilogues for + # ops that mutate multiple inputs. Need to revise the copy + # node tracking logic to support the case. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + node.target = inplaceable_op.inplace_op + elif node.target is torch.ops.higher_order.auto_functionalized_v2: + _mutable_op = node.args[0] + kwargs = node.kwargs + + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + ReInplaceTrigger.AUTO_FUNC_V2, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone + elif node.target is torch.ops.higher_order.auto_functionalized: + _mutable_op = node.args[0] + from torch._higher_order_ops.auto_functionalize import get_mutable_args + + tensors_to_clone, _ = get_mutable_args(_mutable_op) + # Don't try to reinplace Tensor | None args that are None. + tensors_to_clone = [ + t for t in tensors_to_clone if node.kwargs[t] is not None + ] + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + tensors_to_clone, + node.kwargs, + _mutable_op._name, + ReInplaceTrigger.AUTO_FUNC_V1, + ) + + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = tensors_to_clone + elif node.target in inplaceable_triton_ops: + kernel_idx = node.kwargs["kernel_idx"] + kernel = kernel_side_table.get_kernel(kernel_idx) + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + if isinstance(kernel, JITFunction): + kernel_name = kernel.fn.__name__ + elif isinstance(kernel, Autotuner): + if config.is_fbcode(): + # Autotuner has different implementations for AMD and NV + if torch.version.hip is None: + kernel_name = kernel.base_fn.__name__ + else: + kernel_name = kernel.fn.__name__ + else: + kernel_name = kernel.base_fn.__name__ + else: + raise AssertionError("Unknown triton kernel type") + + # inplaceable_triton_ops take an additional argument called + # tensors_to_clone which contain a list of tensors to clone + # This pass iterates over them and sees which ones are safe + # to eliminate (i.e. no longer need the clones) + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + node.kwargs["tensors_to_clone"], + node.kwargs["kwargs"], + kernel_name, + ReInplaceTrigger.TRITON_OPS, + ) + + kwargs = dict(node.kwargs) + kwargs["tensors_to_clone"] = tensors_to_clone + node.kwargs = immutable_dict(kwargs) + if "eager_input_vals" in node.meta: + # We changed the kwargs, so we need to update eager_input_vals + # to something sane. + args, kwargs = node.meta["eager_input_vals"] + new_kwargs = {**kwargs} + new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone) + new_kwargs = immutable_dict(new_kwargs) + node.meta["eager_input_vals"] = (args, new_kwargs) + elif ( + inplaceable_op := inplaceable_foreach_ops.get(node.target, None) + ) is not None: + mutated_args = node.args[inplaceable_op.mutated_arg] + + if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args): + continue + + if can_inplace(node, mutated_args): + for arg in mutated_args: + copy_node = copy_args_to_copy_nodes[(arg, node)] + replace_dict[copy_node] = copy_node.args[0] + + node.target = inplaceable_op.inplace_op + for node, replacement in replace_dict.items(): + while replacement in replace_dict: + replacement = replace_dict[replacement] + replace_dict[node] = replacement + + node.replace_all_uses_with(replacement) + graph.erase_node(node) + + +def reinplace_inplaceable_ops( + fake_tensor_updater: torch._inductor.fx_utils.FakeTensorUpdater, + graph: torch.fx.Graph, +) -> None: + with enable_python_dispatcher(): + canonicalize_view_scatter_ops(graph) + # canonicalize_view_scatter_ops adds new operations to the graph. + # We run fake_tensor_updater to update the alias information. + # Correct alias information is required for `reinplace_inplaceable_ops_core`. + fake_tensor_updater.incremental_update() + reinplace_inplaceable_ops_core(graph) + decompose_generalized_scatter(graph) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py new file mode 100644 index 0000000000000000000000000000000000000000..150ba5cde4a7cb7c2e3f1a8987082ea11c766c3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +import collections +import logging + +import torch +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata + +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..virtualized import V + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass(subsystem="joint_graph_passes") +aten = torch.ops.aten + + +def replace_random_passes(gm: torch.fx.GraphModule): + """Modify the given FX graph to use backend-native random ops""" + if config.fallback_random: + return 0 + + count = patterns.apply(gm) + with GraphTransformObserver(gm, "fuse_seed_creation_pass", "joint_graph_passes"): + count += fuse_seed_creation_pass(gm.graph) + + return count + + +def fuse_seed_creation_pass(graph: torch.fx.Graph): + """ + Horizontally fuse all the seed generation on each device + + a = inductor_seed(dev) + b = inductor_seed(dev) + + Becomes: + seeds = inductor_seeds(2, dev) + a = inductor_lookup_seed(seeds, 0) + b = inductor_lookup_seed(seeds, 1) + + We do this because seed creation is entirely launch overhead bound. + """ + device_seeds = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.seed).match(node): + device_seeds[node.args[0]].append(node) + + if not device_seeds: + return 0 + + for device, seeds in device_seeds.items(): + with graph.inserting_before(seeds[0]): + combined = graph.call_function(inductor_prims.seeds, (len(seeds), device)) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(seeds)], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, seed in enumerate(seeds): + with graph.inserting_before(seed): + new_seed = graph.call_function( + inductor_prims.lookup_seed, (combined, idx) + ) + seed.replace_all_uses_with(new_seed) + new_seed.meta.update(seed.meta) + graph.erase_node(seed) + + return len(device_seeds) + + +def default_kwargs(device): + return {} + + +def get_device(device): + if device is not None: + return device + return torch.empty([]).device # default device + + +# pyrefly: ignore [bad-argument-type] +@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns) +# pyrefly: ignore [bad-argument-type] +@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns) +# pyrefly: ignore [bad-argument-type] +@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns) +# pyrefly: ignore [bad-argument-type] +@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns) +def replace_random( + match: Match, + size, + *, + generator=None, + dtype=None, + device=None, + layout=None, + pin_memory=None, +): + if generator is not None: + return + + def replacement(size): + result = inductor_prims.random( + size, inductor_prims.seed(device), mode, **default_kwargs(device) + ) + if dtype is not None: + result = result.to(dtype) + return result + + mode = { + aten.rand: "rand", + aten.randn: "randn", + }[ + match.output_node().target.overloadpacket # type: ignore[union-attr] + ] # type: ignore[union-attr] + device = get_device(device) + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(replacement, [size]) + + +# pyrefly: ignore [bad-argument-type] +@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns) +def replace_randint( + match: Match, + low, + high, + size, + *, + dtype=torch.int64, + device=None, + layout=None, + pin_memory=None, +): + def replacement(low, high, size): + result = inductor_prims.randint(low, high, size, inductor_prims.seed(device)) + return result.to(dtype) + + device = get_device(device) + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(replacement, [low, high, size]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..6347bda3b525c200ce21cb87ecc2b4a3a685e25c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,3040 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +import os +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import Any, TypeAlias + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false +from torch.utils._ordered_set import OrderedSet + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + FailedMatch, + get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + PatternMatcherPass, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS + + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = tuple[ + _Arguments | None, + _Arguments | None, + _Arguments | None, + _Arguments | None, +] +_Range: TypeAlias = tuple[int, int] + + +PRE_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {} +POST_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {} + +pre_grad_pass_names = [ + "normalization_pass", + "remove_split_with_size_one_pass", + "merge_getitem_cat_pass", + "merge_stack_tahn_unbind_pass", + "merge_splits_pass", + "mutate_cat_pass", + "split_cat_pass", + "unbind_stack_pass", + "split_cat_to_slices_pass", + "unbind_cat_to_view_pass", + "split_stack_to_cats_pass", + "unbind_stack_to_slices_pass", + "move_reshape_out_of_split_stack_pass", + "einsum_to_pointwise_pass", +] + +post_grad_pass_names = [ + "normalization_aten_pass", + "decompose_mm_pass", + "unbind_stack_aten_pass", + "shape_padding_multiplier", + "pad_aten_mm_pass", + "split_cat_aten_pass", + "select_cat_aten_pass", + "move_view_after_cat_aten_pass", +] + +backend = os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_BACKEND", "inductor") + +for pass_name in pre_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in PRE_GRAD_FUSIONS: + continue + PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + +for pass_name in post_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in POST_GRAD_FUSIONS: + continue + POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + + +def construct_pattern_matcher_pass(pass_name: str): + """ + Return the specific pattern_matcher_pass given the pass name. + """ + if pass_name in PRE_GRAD_PATTERNS: + return PRE_GRAD_PATTERNS[pass_name] + else: + return POST_GRAD_PATTERNS[pass_name] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +def _get_dim(node: Any): + assert isinstance(node, torch.fx.Node) + if "dim" in node.kwargs: + assert isinstance(node.kwargs["dim"], int) + return node.kwargs["dim"] + if node.target is torch.unbind: + if len(node.args) == 2: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + if node.target is torch.split: + if len(node.args) == 3: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + raise AssertionError( + f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}" + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], tuple[torch.fx.Node | None, Any | None, int | None] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters[backend]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +def remove_split_with_size_one(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + # remove the dummy split whose split sections size is one + # theoretically nodes with no users should be removed, but we have seen the corner case + # thus we add its users check to walk around the StopIteration error. + if len(split_sections) == 1 and len(split_node.users.keys()) > 0: + # find the grand children of the split_node + next_users = find_next_users(split_node) + user = next(iter(split_node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, split_input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(split_node) + counters[backend]["remove_split_with_size_one_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if not is_node_meta_valid(input): + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + # pyrefly: ignore [unsupported-operation] + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters[backend]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs([torch.cat, torch.concat], users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + # pyrefly: ignore [unsupported-operation] + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + new_args = (tensors,) + new_kwargs = {"dim": cat_dim} + if ( + cat_node.args == new_args + and cat_node.kwargs == new_kwargs + and cat_node.op == "call_function" + and cat_node.target is torch.cat + ): + return + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + kwargs=new_kwargs, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters[backend]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, # type: ignore[arg-type] + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters[backend]["normalization_pass"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> list[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users: + for getitem_user in getitem_node.users: + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + new_squeeze_node.meta.update(squeeze_node.meta) + match.graph.erase_node(squeeze_node) + + +@register_graph_pattern( + CallMethodVarArgs("reshape", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_reshape_default(match: Match, *args, **kwargs): + reshape_node = match.nodes[0] + if not is_node_meta_valid(reshape_node): + log.debug("example value absent for node: %s", reshape_node) + return + reshape_input = get_arg_value(reshape_node, 0) + + if free_symbols(reshape_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", reshape_node) + return + + with match.graph.inserting_after(reshape_node): + new_reshape_node = match.graph.call_function( + torch.reshape, + args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)), + ) + reshape_node.replace_all_uses_with(new_reshape_node) + new_reshape_node.meta.update(reshape_node.meta) + match.graph.erase_node(reshape_node) + + +@register_graph_pattern( + CallMethodVarArgs("clamp", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallFunctionVarArgs(torch.clamp, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_clamp_default(match: Match, *args, **kwargs): + clamp_node = match.nodes[0] + if not is_node_meta_valid(clamp_node): + log.debug("example value absent for node: %s", clamp_node) + return + + if free_symbols(clamp_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", clamp_node) + return + if len(clamp_node.args) > 1: + args = (get_arg_value(clamp_node, 0),) + kwargs = { + "min": get_arg_value(clamp_node, 1, kwarg_name="min"), + "max": get_arg_value(clamp_node, 2, kwarg_name="max"), + } + else: + args = clamp_node.args + kwargs = clamp_node.kwargs + with match.graph.inserting_after(clamp_node): + new_clamp_node = match.graph.call_function( + torch.clamp, + args=args, + kwargs=kwargs, + ) + clamp_node.replace_all_uses_with(new_clamp_node) + new_clamp_node.meta.update(clamp_node.meta) + match.graph.erase_node(clamp_node) + + +@register_graph_pattern( + CallMethodVarArgs("detach", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_detach_default(match: Match, *args, **kwargs): + detach_node = match.nodes[0] + if not is_node_meta_valid(detach_node): + log.debug("example value absent for node: %s", detach_node) + return + + if free_symbols(detach_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", detach_node) + return + + with match.graph.inserting_after(detach_node): + new_detach_node = match.graph.call_function( + torch.detach, + args=detach_node.args, + ) + detach_node.replace_all_uses_with(new_detach_node) + new_detach_node.meta.update(detach_node.meta) + match.graph.erase_node(detach_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. + """ + + def __init__(self, arg, sizes, func=torch.split) -> None: + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = OrderedSet[int]() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=construct_pattern_matcher_pass("merge_splits_pass"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: list[int], + next_split_sections: list[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = _get_dim(first_split) + + to_remove = [] + + with graph.inserting_before(first_split): # type: ignore[arg-type] + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + if is_node_meta_valid(first_split_input): + new_split.meta["example_value"] = torch.split( + first_split_input.meta["example_value"], + new_split_sections, + dim=first_split_dim, + ) + first_split_num_to_user = { + user.args[1]: user + for user in first_split.users # type: ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = {user.args[1]: user for user in node.users} + # It is not necessary all getitems from the split node are used. + for next_split_num in range(len(next_split_sections)): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + if next_split_num not in next_split_num_to_user: + continue + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters[backend]["merge_splits_pass"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: list[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. + simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, + split_node, + next_users, + user_inputs_list_new, + transform_params_list, # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + counters[backend]["unbind_stack_pass"] += 1 + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: list[torch.fx.Node] + ) -> list[list[torch.fx.Node | _Range]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: list[list[torch.fx.Node | _Range]] = [] + for user in next_users: + if user.target in (torch.cat, torch.stack): + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> list[torch.fx.Node | _Range]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = OrderedSet(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> list[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = OrderedSet(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: list[torch.fx.Node | int] + ) -> list[torch.fx.Node | _Range]: + """ + Merge consecutive inputs going into a user node. + + For e.g. + [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: list[list[torch.fx.Node | _Range]], + ) -> list[_Range] | None: + ranges = OrderedSet[Any]() + for user_inputs in user_inputs_list: + ranges.update(u for u in user_inputs if isinstance(u, tuple)) + + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. + return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters[backend]["scmerge_split_sections_removed"] = len(split_sections) - len( + split_ranges + ) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool: + for range_, next_range in itertools.pairwise(ranges): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: list[_Range], min_: int, max_: int) -> list[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list: list[list[torch.fx.Node | _Range]], + ) -> list[list[_TransformParam]] | None: + """ + Figure out what transforms are needed for each input to each cat node. + + We replace a split node with an unflatten followed by a movedim + """ + split_dim = _get_dim(split_node) + split_sections = split_node.args[1] + transform_params_list: list[list[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in (torch.cat, torch.stack): + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: list[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target is torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + + 1 # type: ignore[index] + ] + # All sections should be equal + if len(OrderedSet(subset_split_sections)) != 1: # type: ignore[arg-type] + return None + + num_splits = len(subset_split_sections) # type: ignore[arg-type] + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target is torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: list[int], + user_inputs_list: list[list[torch.fx.Node | _Range]], + split_ranges: list[_Range], + ) -> list[list[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. + """ + split_input = split_node.args[0] + split_dim = _get_dim(split_node) + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] + new_split.meta["example_value"] = torch.split( + split_input.meta["example_value"], # type: ignore[union-attr] + [r[1] - r[0] for r in split_ranges], + dim=split_dim, + ) + counters[backend]["scmerge_split_added"] += 1 + split_items = [] + with graph.inserting_after(new_split): + for i in range(len(split_ranges)): + getitem = graph.call_function(operator.getitem, args=(new_split, i)) + if is_node_meta_valid(new_split): + getitem.meta["example_value"] = new_split.meta["example_value"][ + i + ] + split_items.append(getitem) + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + # pyrefly: ignore [bad-argument-type] + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list_new, + transform_params_list: list[list[_TransformParam]], + ): + split_dim = _get_dim(split_node) + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in (torch.cat, torch.stack): + # Change the args and kwargs of non-cat/stack nodes. Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed, user_inputs_new_transformed_meta = [], [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack, to_stack_meta = [], [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + # pyrefly: ignore [bad-argument-type] + if not is_node_meta_valid(user_input_new): + log.debug("example value absent for node: %s", user_input_new) + return + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + # pyrefly: ignore [missing-attribute] + to_stack_meta.append(user_input_new.meta["example_value"]) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) + to_stack, to_stack_meta = [], [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + # pyrefly: ignore [missing-attribute] + to_stack_meta.append(user_input_new.meta["example_value"]) + continue + + if unflatten_params: + # pyrefly: ignore [missing-attribute] + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *unflatten_params, # type: ignore[arg-type] + ) + if movedim_params: + # pyrefly: ignore [missing-attribute] + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *movedim_params, # type: ignore[arg-type] + ) + if flatten_params: + # pyrefly: ignore [missing-attribute] + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type] + user_input_new_meta, + *flatten_params, # type: ignore[arg-type] + ) + user_inputs_new_transformed.append(user_input_new) + user_inputs_new_transformed_meta.append( + # pyrefly: ignore [missing-attribute] + user_input_new.meta["example_value"] + ) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + user_inputs_new_transformed_meta, + dim=cat_dim, + ) + counters[backend]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + new_cat_node.meta["example_value"] = ( + user_inputs_new_transformed_meta[-1] + ) + + if ( + user_node.target is torch.cat + and split_dim != cat_dim + and split_node.target is torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node_meta = new_cat_node.meta["example_value"] + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + new_cat_node.meta["example_value"] = torch.flatten( + new_cat_node_meta, + cat_dim, + cat_dim + 1, + ) + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + ): + to_remove = [split_node] + counters[backend]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in (torch.cat, torch.stack): + continue + counters[backend]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + if len(node.users.keys()) == 0: + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. + """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + if not is_node_meta_valid(unbind_node): + return + # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node + # before we do the unbind remove, otherwise it will hit the error when we unbind part of them + getitem_indices = [getitem_node.args[1] for getitem_node in unbind_node.users] + if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] + getitem_indices + ) != len(unbind_node.meta["example_value"]): + return + num_unbind = len(getitem_indices) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: list[int], + next_users: list[torch.fx.Node], + user_inputs_list: list[list[torch.fx.Node | _Range]], + ) -> list[_Range] | None: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list: list[list[torch.fx.Node | _Range]], + ) -> list[list[_TransformParam]] | None: + """ + Figure out what transforms are needed for each input to each cat node. + + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = _get_dim(split_node) + transform_params_list: list[list[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: list[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target is torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target is torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1) -> None: + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target is torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target is torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + if is_node_meta_valid(split_input): + unbind.meta["example_value"] = torch.unbind( + split_input.meta["example_value"], dim=dim + ) + for item_index, getitem_node in sorted( + [(getitem_node.args[1], getitem_node) for getitem_node in split.users] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters[backend]["split_cat_pass"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target is torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +reshape_getitem_split = ListOf( + CallFunction( + torch.reshape, + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def simplify_split_cat(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target is torch.split) + # pyrefly: ignore [bad-argument-type] + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... \ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: # type: ignore[union-attr] + return False + return True + + +def remove_zeros(split_sections: list[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: list[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in itertools.pairwise(arr)) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # pyrefly: ignore [bad-return] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"), +) +def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target is torch.split) + split_input, _split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target is torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + # type: ignore[union-attr] + indices = [arg.args[1] for arg in cat_user.args[0]] # type: ignore[union-attr] + # the getitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index] + split_node, + indices, # type: ignore[arg-type] + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 # type: ignore[index] + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters[backend]["merge_getitem_cat_pass"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"), +) +def mutate_cat_node(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target is torch.split) + _split_input, _split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target is torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the getitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] + # remove the cat node + graph.erase_node(cat_user) + counters[backend]["mutate_cat_pass"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, + list(range(indices[0])), # type: ignore[arg-type] + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, + indices, # type: ignore[arg-type] + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters[backend]["mutate_cat_pass"] += 1 + + +getitem_split_aten = ListOf( + CallFunction( + operator.getitem, + CallFunctionVarArgs([torch.ops.aten.split_with_sizes.default], users=MULTIPLE), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.split.Tensor, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_split_default_aten(match: Match, *args, **kwargs): + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("val absent for node: %s", split_node) + return + assert isinstance(split_node.meta["val"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["val"]] + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["val"].dim() + # we also need to check the input of the split_node + # primals =torch.randn(4096, 300) + # split = torch.ops.aten.split.Tensor(primals, 320, 1) -> truncate to 300 automatically + # split_2 = torch.ops.aten.split_with_sizes.default(primals, [320], dim = 1) -> runtime error + split_input_size = split_input.meta["val"].shape[split_dim] + split_size = min(split_size, split_input_size) + split_section_list = [split_size] * (len(split_node.meta["val"])) + new_args = (split_input, split_section_list) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.ops.aten.split_with_sizes.default, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters[backend]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.split_with_sizes.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_split_with_size_default_aten(match: Match, *args, **kwargs): + split_node = match.nodes[0] + graph = match.graph + split_input, split_sections, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_sections is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("val absent for node: %s", split_node) + return + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["val"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.ops.aten.split_with_sizes.default, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters[backend]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + getitem_split_aten, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_aten_pass"), +) +def merge_split_cat_aten(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + threshold_to_cat = torch._inductor.config.post_grad_fusion_options[ + "split_cat_aten_pass" + ].get("threshold_to_cat", 10) + # get the getitem nodes from the split node + getitem_nodes = list(split_node.users.keys()) + for cat_node in list(getitem_nodes[0].users.keys()): + cat_dim = get_arg_value(cat_node, 1, "dim") + cat_inputs = get_arg_value(cat_node, 0, "tensors") + try: + cat_input_len = len(cat_inputs) + except TypeError: + continue + if cat_input_len < threshold_to_cat: + continue + # check split node and cat node has same dim, and all getitem nodes have same parent node + parent_to_indices = defaultdict(list) # type: ignore[var-annotated] + parent_to_getitems = defaultdict(list) # type: ignore[var-annotated] + for cat_input in cat_inputs: + # skip all non-getitem cat input + if cat_input.target != operator.getitem: + continue + current_getitem_parent = cat_input.args[0] + split_dim = get_arg_value(current_getitem_parent, 2, "dim") + if split_dim != cat_dim: + break + getitem_idx = cat_input.args[1] + if ( + current_getitem_parent not in parent_to_indices + ) or getitem_idx != parent_to_indices[current_getitem_parent][-1][-1] + 1: + parent_to_indices[current_getitem_parent].append([getitem_idx]) + parent_to_getitems[current_getitem_parent].append([cat_input]) + else: + parent_to_getitems[current_getitem_parent][-1].append(cat_input) + parent_to_indices[current_getitem_parent][-1].append(getitem_idx) + + cat_inputs_list = list(cat_inputs) + update_cat_arg = [] + # iterate through the indices to construct the slice nodes + for parent, indices in parent_to_indices.items(): + for idx, indice in enumerate(indices): + start, end = indice[0], indice[-1] + split_sections = list(parent.args[1]) + input_of_current_getitem_parent = parent.args[0] + if len(indice) >= threshold_to_cat or len(indice) == len( + split_sections + ): + if len(indice) != len(split_sections): + # get the start and end slicing indices + slice_node = graph.call_function( + torch.ops.aten.slice.Tensor, + args=( + input_of_current_getitem_parent, + split_dim, # type: ignore[possibly-undefined] + sum(split_sections[:start]), + sum(split_sections[: end + 1]), + ), + ) + else: + slice_node = input_of_current_getitem_parent + # find the index in the cat_inputs_list given the getitem node + update_cat_arg.append( + ( + slice_node, + cat_inputs_list.index(parent_to_getitems[parent][idx][0]), + cat_inputs_list.index(parent_to_getitems[parent][idx][-1]), + ) + ) + + result = [] + i = 0 + for slice_tensor, start, end in update_cat_arg: + while i < start: + result.append(cat_inputs_list[i]) + i += 1 + result.append(slice_tensor) + i = end + 1 + while i < len(cat_inputs_list): + result.append(cat_inputs_list[i]) + i += 1 + + cat_node.update_arg(0, result) + for getitem_node in getitem_nodes: + if len(getitem_node.users) == 0: + graph.erase_node(getitem_node) + if len(split_node.users) == 0: + graph.erase_node(split_node) + counters[backend]["split_cat_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + ListOf( + CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE), + partial=True, + ), + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"), +) +def merge_select_cat_aten(match: Match, *args, **kwargs): + graph = match.graph + node = match.nodes[0] + node_input = get_arg_value(node, 0, "tensors") + # get the select nodes from the node + select_nodes = list(node_input.users.keys()) + for cat_node in list(node.users.keys()): + if cat_node.target is torch.ops.aten.cat.default: + cat_dim = get_arg_value(cat_node, 1, "dim") + cat_inputs = get_arg_value(cat_node, 0, "tensors") + # check all select nodes has same slice dim + if not all( + select_node.args[1] == select_nodes[0].args[1] + for select_node in select_nodes + ): + continue + # We only consider the case where selece slice dim and cat node has same dim + if select_nodes[0].args[1] != cat_dim: + continue + if not is_node_meta_valid(cat_node): + continue + # check the cat node has consecutive indices + indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr] + if ( + not is_sorted_and_consecutive(indices) # type: ignore[arg-type] + or len(select_nodes) != len(cat_inputs) + ): + continue + # check all the select nodes can be merged to the cat node input + if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr] + continue + # reshape the node input to be the same shape as the cat node + with graph.inserting_before(node): + view_node = graph.call_function( + torch.ops.aten.view.default, + args=(node_input, cat_node.meta["val"].shape), + ) + # replace the node input with the new node + cat_node.replace_all_uses_with(view_node) + view_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters[backend]["select_cat_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in tensor.meta: + log.debug("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) + + # pyrefly: ignore [unsupported-operation] + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters[backend]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target is torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target is torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters[backend]["unbind_stack_aten_pass"] += 1 + + +def divide_into_consecutive_sublists(indices: list[int]) -> list[list[int]]: + n = len(indices) + if n <= 1: + return [indices] + + # Initialize the list of sublists + sublists = [] + + # Iterate over the indices + i = 0 + while i < n: + # Initialize the current sublist + sublist = [indices[i]] + + # Iterate over the remaining indices + j = i + 1 + while j < n and indices[j] == indices[j - 1] + 1: + # Add the next index to the current sublist + sublist.append(indices[j]) + j += 1 + + # Add the current sublist to the list of sublists + sublists.append(sublist) + # Move to the next index + i = j + + return sublists + + +def update_args_from_split_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, + getitem_indices: list[int], + parents_seen: list[torch.fx.Node], + new_cat_args: list[torch.fx.Node], + new_cat_args_meta: list[torch.fx.Node], + idx_to_getitems: dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1]) + # case 1: the number of getitems is the same as the split size, eliminate the split + if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( + getitem_indices + ): + # we can merge the getitems from the previous parent + new_cat_args.append(split_input) + new_cat_args_meta.append(split_input.meta["example_value"]) + else: + if len(getitem_indices) > 0: + # case 2: the number of getitems is smaller than the split size but larger than the threshold, and + # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups + geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices) + for sublist in geitem_indices_sublist: + if len(sublist) >= threshold_to_cat: + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + start_fused_size = sum(split_size[: sublist[0]]) + end_fused_size = sum(split_size[: sublist[-1] + 1]) + slice_list = [] + for i in range(len(split_input.meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append( + slice(start_fused_size, end_fused_size, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(split_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = split_input.meta[ + "example_value" + ][tuple(slice_list)] + new_cat_args.append(slice_node) + new_cat_args_meta.append(slice_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in sublist: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append( + idx_to_getitems[i].meta["example_value"] + ) + + +def reshape_cat_node( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + unbind_input: torch.fx.Node, + cat_dim: int, + unbind_dim: int, + cat_shape: torch.Size, +) -> torch.fx.Node: + if cat_dim != unbind_dim: + # construct the permute node args, which has the same shape as the slice node + # then it has the same dim as the unbind_input, i.e., shape of cat + 1 + with graph.inserting_after(cat_node): + permute_list = list(range(len(cat_shape) + 1)) + permute_list[unbind_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[unbind_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(unbind_input, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + unbind_input.meta["example_value"], permute_list + ) # type: ignore[arg-type] + else: + permute_node = unbind_input + with graph.inserting_after(permute_node): + reshape_node = graph.call_function( + torch.reshape, args=(permute_node, tuple(cat_shape)) + ) + reshape_node.meta["example_value"] = torch.reshape( + permute_node.meta["example_value"], tuple(cat_shape) + ) # type: ignore[arg-type] + return reshape_node + + +def update_args_from_unbind_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, # cat or stack node + getitem_indices: list[int], + parents_seen: list[torch.fx.Node], + new_cat_args: list[torch.fx.Node], + new_cat_args_meta: list[torch.fx.Node], + idx_to_getitems: dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input + unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim + cat_dim = get_arg_value(node, 1, "dim") # cat or stack dim + # case 1: the number of getitems is the same as the split size, eliminate the split + size = list(unbind_input.meta["example_value"].shape)[unbind_dim] + if size == len(getitem_indices): + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + # we can merge the getitems from the previous parent + reshape_node = reshape_cat_node( + graph, node, unbind_input, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive( + getitem_indices + ): + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + slice_list = [] + for i in range(len(cat_shape) + 1): + if i != unbind_dim: + slice_list.append(slice(None, None, None)) # start, end, step + else: + slice_list.append( + slice(getitem_indices[0], getitem_indices[-1] + 1, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(unbind_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = torch.narrow( + unbind_input.meta["example_value"], + unbind_dim, + getitem_indices[0], + getitem_indices[-1] - getitem_indices[0] + 1, + ) + reshape_node = reshape_cat_node( + graph, node, slice_node, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in getitem_indices: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"]) + + +def construct_cat_args( + graph: torch.fx.Graph, + cat_or_stack_node: torch.fx.Node, + inputs: list[torch.fx.Node], + split_or_unbind_node: torch.fx.Node, + threshold_to_cat: int = 2, + run_update_func: Callable = update_args_from_split_getitem, # type: ignore[type-arg] +) -> tuple[list[torch.fx.Node], list[torch.Tensor]]: + new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {} # type: ignore[var-annotated] + new_cat_args_meta = [] # type: ignore[var-annotated] + for input in inputs: + if input.target != operator.getitem: + # update the last arg based on getitem_indices and parents_seens + if len(parents_seen) > 0: + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + # reset the indices array + getitem_indices, idx_to_getitems = [], {} + else: + # get the parent node of the getitem input + parent, idx = input.args[0], input.args[1] # type: ignore[union-attr] + if parent.target != split_or_unbind_node.target: # type: ignore[union-attr] + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # cannot use parents_seen to check since the first item could be non getitem node + if len(parents_seen) == 0: + parents_seen.append(parent) + idx_to_getitems[idx] = input + getitem_indices.append(idx) + # case: we only have one getitem input, and it is in the last position + if input == inputs[-1]: + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # if it is the last input in the tensors, we also check if it can be optimized + if parent != parents_seen[-1] or input == inputs[-1]: + if input == inputs[-1]: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + # reset the indices array for the next parent + # remember to add the last element since it is the first + # item in this round of parent + # add the parent to the list of seen parents + parents_seen.append(parent) + getitem_indices, idx_to_getitems = [idx], {idx: input} + else: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + return new_cat_args, new_cat_args_meta + + +def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.Node]): + nodes = OrderedSet[Any]() + for input in inputs: + if input.target is operator.getitem: + nodes.add(input.args[0]) # type: ignore[union-attr] + if len(input.users.keys()) == 0: + graph.erase_node(input) + # check the split node to remove if it has no users + for node in nodes: + if len(node.users.keys()) == 0: # type: ignore[union-attr] + graph.erase_node(node) # type: ignore[arg-type] + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# other inputs getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) other inputs -> -> user=multiple +# / \ +# cat (user=mul, dim=1, split_node) + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_nodes = [node for node in match.nodes if node.target is torch.split] + if split_nodes: + split_node = next(node for node in split_nodes) + else: + # Handle the case where there are no nodes with a target of torch.split + return + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_cat_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + next_users = find_next_users(split_node) + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, _ = construct_cat_args( + graph, + cat_node, + cat_inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: if new cat args has length 1, we can remove the cat node + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters[backend]["split_cat_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): + new_args = (new_cat_args,) + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + # split and cat have the same dim + kwargs={"dim": split_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) + counters[backend]["split_cat_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=0) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), +) +def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target is torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_cat_to_view_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + cat_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + # get the view shape + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters[backend]["unbind_cat_to_view_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(cat_node, 1, "dim") + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) # type: ignore[arg-type] + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters[backend]["unbind_cat_to_view_pass"] += 1 + + +def reshape_cat_node_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + split_or_unbind_dim: int, +) -> None: + # reshape the cat node to the stack node shape + stack_shape = stack_node.meta["example_value"].shape + stack_dim = _get_dim(stack_node) + if stack_dim != split_or_unbind_dim: + # case 1: the stack dim is not the same as the split dim + # we need to reshape the split input before we do the reshape + reshape_list = list(stack_shape) + reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = ( + reshape_list[split_or_unbind_dim], + reshape_list[stack_dim], + ) + reshape_node = graph.call_function( + torch.reshape, + args=(cat_node, tuple(reshape_list)), + ) + reshape_node.meta["example_value"] = torch.reshape( + cat_node.meta["example_value"], + tuple(reshape_list), # pyrefly: ignore [bad-argument-type] + ) + permute_list = list(range(len(stack_shape))) + permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( + permute_list[split_or_unbind_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(reshape_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + reshape_node.meta["example_value"], permute_list + ) + else: + # case 2: the stack dim is the same as the split dim + # we can directly reshape the split input + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, *stack_shape), # type: ignore[arg-type] + ) + stack_node.replace_all_uses_with(reshape_node) + reshape_node.meta.update(stack_node.meta) + stack_inputs = stack_node.args[0] # type: ignore[union-attr] + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + + +def convert_reshape_cat_arg_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + stack_node_shape: torch.Size, + stack_dim: int, + split_dim: int, +) -> torch.fx.Node: + # reshape the cat node to the stack node shape + cat_shape = cat_node.meta["example_value"].shape + if stack_dim != split_dim: + permute_list = list(range(len(cat_shape))) + permute_list[stack_dim], permute_list[split_dim] = ( + permute_list[split_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(cat_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + cat_node.meta["example_value"], permute_list + ) + else: + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] + ) + reshape_node.meta["example_value"] = torch.Tensor.view( + permute_node.meta["example_value"], + tuple(stack_node_shape), # type: ignore[arg-type] + ) + return reshape_node + + +# ############pattern to be optimized is######### +# | | +# split split (dim=1) +# / \ / \ +# getitem ... getitem other ops +# \ | / / +# stack(user=mul, dim=1 or 2) -> can be different dim +# | + +# ################after transformation############# + +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ / +# cat(user=mul, dim=1) cat_other_opts +# \ / +# cat +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"), +) +def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target is torch.split) + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_stack_to_cats_pass" + ].get("threshold_to_cat", 10) + # get the stack_node and check its inputs and meta data + next_users = find_next_users(split_node) + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) + counters[backend]["split_stack_to_cats_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": split_dim}, + ) + cat_node.meta["example_value"] = torch.cat( # type: ignore[arg-type] + new_cat_args_meta, dim=split_dim + ) + reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) + counters[backend]["split_stack_to_cats_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=1) -> user=multiple +# \ ... / \ +# others getitem getitem getitem -> user=multiple +# \ \ / \ +# stack(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view others +# | / +# stack +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), +) +def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target is torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_stack_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0 + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) + counters[backend]["unbind_stack_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(stack_node, 1, "dim") + with graph.inserting_after(stack_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) + reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) + counters[backend]["unbind_stack_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### +# input +# | +# split(dim=1) -> user=multiple +# \ \ +# others getitem getitem +# \ \ / +# reshape reshape reshape other_op +# \ \ / / +# stack(user=mul, dim=0) +# | + +# ################after transformation############# +# input +# | +# permute +# | +# reshape others +# | / +# cat (dim=0) +# | + + +def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]: + # cat_arg must be the split input + view_shape_list = [] + for user in cat_arg.users: + if user.target is torch.split: + for getitem in user.users: + if getitem.target is operator.getitem: + reshape_user = [ + user for user in getitem.users if user.target is torch.reshape + ] + if len(reshape_user) > 0: + view_shape_list = list( + reshape_user[0] + .meta["example_value"] + .unsqueeze(stack_dim) + .shape + ) + view_shape_list[stack_dim] = -1 + return view_shape_list + return view_shape_list + + +@register_graph_pattern( + CallFunction( + torch.stack, + reshape_getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), +) +def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): + split_node = next(node for node in match.nodes if node.target is torch.split) + split_dim = _get_dim(split_node) + split_users = list(split_node.users.keys()) + stack_nodes = [node for node in match.nodes if node.target is torch.stack] + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "move_reshape_out_of_split_stack_pass" + ].get("threshold_to_cat", 10) + for stack_node in stack_nodes: + if not is_node_meta_valid(stack_node): + log.debug("example value absent for node: %s", stack_node) + continue + stack_dim = _get_dim(stack_node) + stack_inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + inputs = [] + for stack_input in stack_inputs: + if stack_input.target != torch.reshape: + inputs.append(stack_input) + else: + inputs.append(stack_input.args[0]) # type: ignore[union-attr] + new_cat_args, _new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_node = convert_reshape_cat_arg_to_stack( + graph, + new_cat_args[0], + stack_node, + stack_node.meta["example_value"].shape, + stack_dim, + split_dim, + ) + stack_node.replace_all_uses_with(reshape_node) + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters[backend]["move_reshape_out_of_split_stack_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # decompose the cat args into multiple stack nodes, i.e., we stack + # all the nodes exist in the stack inputs and reshape the rest followed by a cat + stack_node_input, stack_node_input_meta, cat_inputs = [], [], [] # type: ignore[var-annotated] + for cat_arg in new_cat_args: + if cat_arg not in stack_inputs: + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + # cat_arg must be the split input + view_shape_list = get_view_shape_list(cat_arg, stack_dim) + stack_node_shape = torch.reshape( + cat_arg.meta["example_value"], tuple(view_shape_list) + ).shape # type: ignore[union-attr] + cat_inputs.append( + convert_reshape_cat_arg_to_stack( + graph, + cat_arg, + stack_node, + stack_node_shape, + stack_dim, + split_dim, + ) + ) + stack_node_input, stack_node_input_meta = [], [] + else: + stack_node_input.append(cat_arg) + stack_node_input_meta.append(cat_arg.meta["example_value"]) + + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(cat_inputs,), + kwargs={"dim": stack_dim}, + ) + stack_node.replace_all_uses_with(cat_node) + cat_node.meta.update(stack_node.meta) + graph.erase_node(stack_node) + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters[backend]["move_reshape_out_of_split_stack_pass"] += 1 + + +view_getitem_split_aten = ListOf( + CallFunction( + [torch.ops.aten.reshape.default], + CallFunction( + operator.getitem, + CallFunctionVarArgs( + torch.ops.aten.split_with_sizes.default, users=MULTIPLE + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + view_getitem_split_aten, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"), +) +def move_view_after_cat(match: Match, *args, **kwargs): + split_node = next( + node + for node in match.nodes + if node.target is torch.ops.aten.split_with_sizes.default + ) + split_input, split_section, split_dim = _get_split_args_default(split_node) + split_users = list(split_node.users.keys()) + getitem_indices = [ + getitem.args[1] for getitem in split_users if getitem.target is operator.getitem + ] + if not is_sorted_and_consecutive(getitem_indices): # type: ignore[arg-type] + return + cat_nodes = [ + node for node in match.nodes if node.target is torch.ops.aten.cat.default + ] + graph = match.graph + for cat_node in cat_nodes: + if not is_node_meta_valid(cat_node): + log.debug("example value absent for node: %s", cat_node) + continue + cat_dim = _get_dim(cat_node) + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + # we only consider the following special case + if len(cat_inputs) != len(split_section): + continue + # check if the cat inputs are all the view nodes + if not all( + view_node.target is torch.ops.aten.reshape.default + for view_node in cat_inputs + ): + continue + # check if the view nodes are all from getitem nodes + if not all( + view_node.args[0].target is operator.getitem for view_node in cat_inputs + ): + continue + view_indices = [view.args[0].args[1] for view in cat_inputs] + if not is_sorted_and_consecutive(view_indices): # type: ignore[arg-type] + continue + if cat_dim != split_dim: + # construct permute node + permute_list = list(range(len(cat_node.meta["val"].shape) + 1)) + permute_list[split_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[split_dim], + ) + permute_node = graph.call_function( + torch.ops.aten.permute.default, + args=(split_input, permute_list), + ) + else: + permute_node = split_input + + with graph.inserting_before(cat_node): + view_node = graph.call_function( + torch.ops.aten.reshape.default, + args=(permute_node, list(cat_node.meta["val"].shape)), + ) + cat_node.replace_all_uses_with(view_node) + view_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters[backend]["move_view_after_cat_aten_pass"] += 1 + + +def match_einsum_strings(s: str) -> bool: + """ + This function takes a string s as input, where s is in the format "3 letter string, + 4 letter string -> 3 letter string". + It checks if the strings match the rule and returns True if they do, False otherwise. + + The rule is: + - The three strings have the same first two characters. + - The first two strings have the same third character. + - The second and third strings have the same last character. + """ + + # Split the input string into parts + parts = s.replace("->", ",").split(",") + + # Strip leading/trailing whitespaces from each part + parts = [part.strip() for part in parts] + + # Check if we have exactly three parts + if len(parts) != 3: + return False + + # Extract the strings + s1, s2, s3 = parts + + # Check if the strings have the correct lengths + if len(s1) != 3 or len(s2) != 4 or len(s3) != 3: + return False + + # Check the rule + return s1[:2] == s2[:2] == s3[:2] and s1[2] == s2[2] and s2[3] == s3[2] + + +@register_graph_pattern( + CallFunctionVarArgs(torch.functional.einsum, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("einsum_to_pointwise_pass"), +) +def replace_einsum_to_pointwise(match: Match, *args, **kwargs): + def repl(input, weights): + return (input.unsqueeze(-1) * weights).sum(-2) + + def should_replace_einsum(einsum_node) -> bool: + equation = get_arg_value(einsum_node, 0) + users = einsum_node.users.keys() + # for now, we only consider the case of two operands + return ( + len(einsum_node.args) == 3 + and is_node_meta_valid(input) + and is_node_meta_valid(weights) + and any( + user.target == "add" or user.target is operator.add for user in users + ) + and match_einsum_strings(equation) + ) + + einsum_node = match.nodes[0] + input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2) + if should_replace_einsum(einsum_node): + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl, [input, weights]) + counters[backend]["einsum_to_pointwise_pass"] += 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9668f1b6c6e1d07c9a2744ab3894929826f39429 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py @@ -0,0 +1 @@ +from . import flex, mm, mm_common, mm_plus_mm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ffa1e1780b3b9f9911efc956fd487b9eca8b2d3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a59e31472aedaf19a9397bc9acec16b1f5db982a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0966ee8faf1043a90ba881dedd100e92a3dc8552 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py new file mode 100644 index 0000000000000000000000000000000000000000..a155d35b5d059154e20cb8a1e88e361098e8d4c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py @@ -0,0 +1,343 @@ +# mypy: allow-untyped-defs +import logging +from typing import TYPE_CHECKING, Union + +import torch +from torch._dynamo.utils import counters +from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate + +from .. import config as inductor_config, ir, lowering as L +from ..kernel_inputs import MMKernelInputs +from ..lowering import lowerings, make_pointwise, make_reduction, transform_args +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, +) +from ..utils import ( + _use_cutlass_for_op, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_cpp_bmm_template, + use_cutlass_template, + use_triton_template, +) +from ..virtualized import ops, V +from .mm_common import ( + _is_static_problem, + is_batch_stride_largest_or_zero, + mm_args, + use_native_matmul, +) + + +if TYPE_CHECKING: + from ..ir import ChoiceCaller + from ..select_algorithm import KernelTemplate + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@SymbolicGridFn +def bmm_grid(b, m, n, meta, *, cdiv): + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) + + +bmm_template = TritonTemplate( + name="bmm", + grid=bmm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", -2)}} + N = {{size("B", -1)}} + K = {{size("A", -1)}} + + stride_aq = {{stride("A", 0)}} + stride_am = {{stride("A", 1)}} + stride_ak = {{stride("A", 2)}} + + stride_bq = {{stride("B", 0)}} + stride_bk = {{stride("B", 1)}} + stride_bn = {{stride("B", 2)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} +""", + cache_codegen_enabled_for_template=True, +) + +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out", op_overload=aten.bmm.out) +aten_bmm_dtype = ExternKernelChoice( + torch.bmm, + "at::_bmm_out_dtype_cuda", + name="bmm_dtype", + op_overload=aten.bmm.dtype_out, +) +aten_baddbmm = ExternKernelChoice( + torch.baddbmm, "at::baddbmm_out", op_overload=aten.baddbmm.out +) + + +@L.register_lowering(aten.bmm) +def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None): + """ + Lowering for autotuning aten.bmm with different backends (Aten, Triton, CUTLASS, etc.) + """ + if all(x.get_device().type == "cpu" for x in [mat1, mat2]): + # decompose to small ops when memory bound + if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1: + mat1 = L.unsqueeze(mat1, -1) + mat2 = L.unsqueeze(mat2, 1) + return L.sum_(L.mul(mat1, mat2), axis=2) + + def is_valid_to_require_contiguous(t): + if not ir.is_storage_and_layout(t): + return True + _, layout = ir.as_storage_and_layout(t, freeze=False) + return isinstance(layout, ir.FlexibleLayout) + + def is_preferred_layout_as_bmm_input(sizes, strides): + # contiguous on one of the last two dims + return ( + strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1]) + ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2])) + + # Make the input of bmm contiguous + # if it is not contiguous on either of the last two dims, + # because bmm cpu implementation would do contiguous() if not. + # This is to avoid additional copies in bmm. + def may_require_contiguous(t, meta_t): + sizes = meta_t.meta["val"].size() + strides = meta_t.meta["val"].stride() + if not is_preferred_layout_as_bmm_input(sizes, strides): + t = ir.ExternKernel.require_contiguous(t) + return t + + if is_valid_to_require_contiguous(mat1): + meta_mat1 = V.graph.current_node.args[0] + mat1 = may_require_contiguous(mat1, meta_mat1) + if is_valid_to_require_contiguous(mat2): + meta_mat2 = V.graph.current_node.args[1] + mat2 = may_require_contiguous(mat2, meta_mat2) + + if use_native_matmul(mat1, mat2): + mat1 = lowerings[aten.unsqueeze](mat1, -1) + mat2 = lowerings[aten.unsqueeze](mat2, 1) + args, kwargs = transform_args( + args=[mat1, mat2], + kwargs={}, + broadcast=True, + type_promotion_kind=None, + convert_input_to_bool=False, + ) # Handles broadcasting the arguments + + if inductor_config.triton.codegen_upcast_to_fp32 and mat1.dtype in [ + torch.float16, + torch.bfloat16, + ]: + + def _to_dtype(x): + return ops.to_dtype(x, mat1.dtype, use_compute_types=False) + + args = [make_pointwise(_to_dtype)(x) for x in args] + + mul_pointwise = make_pointwise(ops.dot)(*args) + dot_reduction = make_reduction("dot")(mul_pointwise, 2) + + return dot_reduction + + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=out_dtype + ) + name = "bmm" + + # Create MMKernelInputs for BMM at the top + kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype) + + # below is for getting an overview logging info of inductor mms + batch_size = mat1.get_size()[0] # Extract batch dimension + counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + batch_size, + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + aten_handler: ExternKernelChoice = aten_bmm + aten_extra_kwargs = {} + if out_dtype: + assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA" + aten_handler = aten_bmm_dtype + aten_extra_kwargs = {"out_dtype": out_dtype} + + choices: list[ChoiceCaller] = [] + + # Collect all templates for unified call + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + kwarg_overrides = {} + + if use_aten_gemm_kernels(): + templates_to_use.append(aten_handler) + kwarg_overrides[aten_handler.uid] = aten_extra_kwargs + + if use_triton_template(layout, check_max_autotune=False) and ( + out_dtype is None or out_dtype == mat1.get_dtype() + ): + # TODO: add out_dtype support for Triton Template + templates_to_use.append(bmm_template) + + # Single unified call for all templates + choices.extend( + V.choices.get_template_configs( + kernel_inputs, + templates_to_use, + name, + kwarg_overrides=kwarg_overrides, + ) + ) + _, is_nonzero = _is_static_problem(layout) + batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout) + if ( + batch_stride_largest_or_zero + and is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op(name) + ): + from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate + + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, kernel_inputs.nodes() + ) # type: ignore[arg-type] + + if use_cpp_bmm_template(layout, mat1, mat2): + from ..codegen.cpp_bmm_template import CppBmmTemplate + + CppBmmTemplate.add_choices( + choices, + layout, + kernel_inputs.nodes(), + ) + + if use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) + + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + + +@L.register_lowering(aten.baddbmm) +def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + """ + Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) + """ + if use_native_matmul(mat1, mat2): + if beta == 0: + arg1 = 0 + else: + arg1 = lowerings[aten.mul](beta, inp) + + if alpha == 0: + arg2 = 0 + else: + arg2 = lowerings[aten.mul](alpha, lowerings[aten.bmm](mat1, mat2)) + + return lowerings[aten.add](arg1, arg2) + + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + + # Create MMKernelInputs for BadDBMM at the top + kernel_inputs = MMKernelInputs( + [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) + ) + + # below is for getting an overview logging info of inductor mms + batch_size = mat1.get_size()[0] + counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s", + batch_size, + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + inp.get_dtype(), + layout, + ) + name = "baddbmm" + # options to tune from + choices: list[ChoiceCaller] = [] + + # Collect all templates for unified call + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + if use_aten_gemm_kernels(): + templates_to_use.append(aten_baddbmm) + + if use_triton_template(layout, check_max_autotune=False): + templates_to_use.append(bmm_template) + + # Single unified call for all templates + choices.extend( + V.choices.get_template_configs(kernel_inputs, templates_to_use, name) + ) + + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5a2aa09d4ea229ebb56ac589f56fc3900ba6ae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py @@ -0,0 +1,687 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Optional, TYPE_CHECKING, TypedDict + +import torch +from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate + +from .. import config, ir +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, +) +from ..utils import ( + is_ones, + is_zeros, + pad_listlike, + sympy_product, + use_ck_conv_template, + use_triton_template, +) +from ..virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..ir import TensorBox + +log = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +@SymbolicGridFn +def conv2d_grid(n, c, h, w, meta, *, cdiv): + return ( + cdiv(n * h * w, meta["BLOCK_M"]), + cdiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +@SymbolicGridFn +def conv3d_grid(n, c, d, h, w, meta, *, cdiv): + return ( + cdiv(n * d * h * w, meta["BLOCK_M"]), + cdiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +LOOP_BODY_2D = """ + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +""" +This is a relatively simple conv implementation that can likely be +improved. Many alternate conv versions can be found here: +https://github.com/pytorch/torchdynamo/pull/971 +""" +conv2d_template = TritonTemplate( + name="convolution2d", + grid=conv2d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_H = {{size("X", 2)}} + IN_W = {{size("X", 3)}} + OUT_C = {{size(None, 1)}} + OUT_H = {{size(None, 2)}} + OUT_W = {{size(None, 3)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xh = {{stride("X", 2)}} + stride_xw = {{stride("X", 3)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wh = {{stride("W", 2)}} + stride_ww = {{stride("W", 3)}} + + nhw = tl.program_id(0).to(INDEX_DTYPE) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = nhw % OUT_W + nh = nhw // OUT_W + idx_y_h = nh % OUT_H + idx_n = nh // OUT_H + idx_y_c = tl.program_id(1).to(INDEX_DTYPE) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2).to(INDEX_DTYPE) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_2D + + """ +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (ijk % BLOCK_K_COUNT) * BLOCK_K + ij = ijk // BLOCK_K_COUNT + i = ij // KERNEL_W + j = ij % KERNEL_W + """ + + LOOP_BODY_2D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} +""", +) + +LOOP_BODY_3D = """ + idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_d * stride_xd)[:, None] + + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_d >= 0)[:, None] + & (idx_x_d < IN_D)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + + (d * stride_wd) + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +conv3d_template = TritonTemplate( + name="convolution3d", + grid=conv3d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_D = {{size("X", 2)}} + IN_H = {{size("X", 3)}} + IN_W = {{size("X", 4)}} + OUT_C = {{size(None, 1)}} + OUT_D = {{size(None, 2)}} + OUT_H = {{size(None, 3)}} + OUT_W = {{size(None, 4)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xd = {{stride("X", 2)}} + stride_xh = {{stride("X", 3)}} + stride_xw = {{stride("X", 4)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wd = {{stride("W", 2)}} + stride_wh = {{stride("W", 3)}} + stride_ww = {{stride("W", 4)}} + + ndhw = tl.program_id(0).to(INDEX_DTYPE) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = ndhw % OUT_W + ndh = ndhw // OUT_W + idx_y_h = ndh % OUT_H + nd = ndh // OUT_H + idx_y_d = nd % OUT_D + idx_n = nd // OUT_D + idx_y_c = tl.program_id(1).to(INDEX_DTYPE) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2).to(INDEX_DTYPE) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for d in range(KERNEL_D) %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + d = {{d}} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_3D + + """ +{% endfor %} +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for d in range(KERNEL_D): + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (dijk % BLOCK_K_COUNT) * BLOCK_K + dij = dijk // BLOCK_K_COUNT + j = dij % KERNEL_W + di = dij // KERNEL_W + i = di % KERNEL_H + d = di // KERNEL_H + """ + + LOOP_BODY_3D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_d < OUT_D)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_d = idx_y_d[:, None] + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} +""", +) + +aten_convolution = ExternKernelChoice( + torch.convolution, + "at::convolution", + has_out_variant=False, + op_overload=aten.convolution.default, +) + + +def conv1x1_via_mm(x, w, *, out): + w = torch.squeeze(torch.squeeze(w, -1), -1) + return torch.matmul( + x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1) + ) + + +aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None) + + +class ConvLayoutParams(TypedDict): + stride: tuple[int, ...] + padding: tuple[int, ...] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] + groups: int + + +def conv_layout( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, +) -> ir.Layout: + """Determine output layout for a convolution""" + with V.graph.fake_mode: + output = torch.ops.aten.convolution( + ir.ir_node_to_tensor(x, guard_shape=True), + ir.ir_node_to_tensor(weight, guard_shape=True), + ir.ir_node_to_tensor(bias, guard_shape=True), + V.graph.sizevars.size_hints(stride), # type: ignore[arg-type] + V.graph.sizevars.size_hints(padding), # type: ignore[arg-type] + V.graph.sizevars.size_hints(dilation), # type: ignore[arg-type] + transposed, + V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type] + groups, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] + + return ir.FixedLayout( + x.get_device_or_error(), + x.get_dtype(), + sizes, + stride, + ) + + +def channels_last_order(rank): + order = list(reversed(range(rank))) + order.insert(1, order.pop(-1)) + return order + + +def convert_1x1_conv_to_mm(x, weight, bias): + # special case for 1x1 convolution, which is actually just a matmul + rank = len(weight.get_size()) + for _ in range(rank - 2): + weight = L[aten.squeeze](weight, dim=-1) + weight = L[aten.permute](weight, [1, 0]) + + x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank)) + x_permute = list(range(rank)) + x_permute.append(x_permute.pop(1)) + x = L[aten.permute](x, x_permute) + *sizes, in_chan = x.get_size() + x = L[aten.reshape](x, [sympy_product(sizes), in_chan]) + if bias is None: + result = L[aten.mm](x, weight) + else: + result = L[aten.addmm](bias, x, weight) + result = L[aten.reshape](result, [*sizes, -1]) + result_permute = list(range(rank)) + result_permute.insert(1, result_permute.pop(-1)) + return L[aten.permute](result, result_permute) + + +@register_lowering(aten.convolution) +def convolution( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: int, +): + stride = tuple(stride) + padding = tuple(padding) + dilation = tuple(dilation) + output_padding = tuple(output_padding) + if not isinstance(groups, int): + groups = V.graph.sizevars.guard_int(groups) + assert isinstance(groups, int) + + # Need use hint for triton template since the template does not + # work with a dynamic shape. + # + # No need to guard_int for dilation and output_padding + # since the template is only used when dilation is 1 and output_padding + # is 0. + stride = tuple(V.graph.sizevars.guard_int_seq(stride)) + padding = tuple(V.graph.sizevars.guard_int_seq(padding)) + + kwargs: ConvLayoutParams = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + } + + device_type = ir.get_device_type(x) + + if len(x.get_size()) == len(weight.get_size()) - 1: + # add batch dimension to simplify rest of function + return L[aten.squeeze]( + convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs), + dim=0, + ) + + out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size()) + + # Always convert conv1D to 2D for Intel GPU. + # Only conv2D can be converted to channel last layout, + # which have much better performance. + if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu": + kwargs.update( + { + "stride": (1,) + stride, + "padding": (0,) + padding, + "dilation": (1,) + dilation, + "output_padding": (0,) + output_padding, + } + ) + # (N, C, L) -> (N, C, 1, L) + x = L[aten.unsqueeze](x, dim=2) + weight = L[aten.unsqueeze](weight, dim=2) + + return L[aten.squeeze]( + convolution(x, weight, bias, **kwargs), + dim=2, + ) + + ndim = len(kernel_shape) + stride = pad_listlike(stride, ndim) + padding = pad_listlike(padding, ndim) + dilation = pad_listlike(dilation, ndim) + output_padding = pad_listlike(output_padding, ndim) + + def channels_last_conv(): + if V.graph.layout_opt and ndim == 2: + return True + + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + return req_stride_order == ir.NHWC_STRIDE_ORDER + + autotuning_gemm = config.max_autotune or config.max_autotune_gemm + + if ( + (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv())) + and is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + and groups == 1 + and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0) + ): + return convert_1x1_conv_to_mm(x, weight, bias) + + if bias is not None and device_type != "cpu": + # peel off the bias, cudnn is slower with it + result = convolution(x, weight, None, **kwargs) + return L[aten.add]( + result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1]) + ) + + x.realize() + weight.realize() + + # ndim can be 1 for convolution in models such as demucs + # TODO: check if it's beneficial to convert Conv1d to Conv2d and then + # apply channels last. + if V.graph.layout_opt and ndim == 2: + V.graph.num_channels_last_conv += 1 + x = ir.ExternKernel.require_channels_last(x) # type: ignore[assignment] + # TODO maybe we can convert weights to channels last just once before + # running the model. + weight = ir.ExternKernel.require_channels_last(weight) # type: ignore[assignment] + layout = conv_layout(x, weight, None, **kwargs) + else: + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment] + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment] + + ordered_kwargs_for_cpp_kernel = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + if bias is None: + args = [x, weight] + kwargs["bias"] = None # type: ignore[typeddict-unknown-key] + ordered_kwargs_for_cpp_kernel.insert(0, "bias") + else: + args = [x, weight, bias] + bias.realize() + bias.freeze_layout() + V.graph.sizevars.guard_int_seq(bias.get_size()) + + choices = [] + if torch._inductor.utils._use_conv_autotune_backend("ATEN"): + choices = [ + aten_convolution.bind( + args, + layout, + ordered_kwargs_for_cpp_kernel, + **kwargs, + ) + ] + + if ( + torch._inductor.utils._use_conv_autotune_backend("TRITON") + and use_triton_template(layout) + # templates only support these: + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0) + and V.graph.sizevars.statically_known_equals(in_chan * groups, x.get_size()[1]) # type: ignore[arg-type] + ): + if ( + is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and groups == 1 + ): + choices.append(aten_conv1x1_via_mm.bind(args, layout)) + + conv_configs = V.choices.get_conv_configs(device_type) + + dtype_size = x.get_dtype().itemsize + for cfg in conv_configs( + sympy_product([x.get_size()[0], *x.get_size()[2:]]), + out_chan, + in_chan, + dtype_size=dtype_size, + ): + if ndim == 2: + conv2d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_H=kernel_shape[0], + KERNEL_W=kernel_shape[1], + STRIDE_H=stride[0], + STRIDE_W=stride[1], + PADDING_H=padding[0], + PADDING_W=padding[1], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + elif ndim == 3: + conv3d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_D=kernel_shape[0], + KERNEL_H=kernel_shape[1], + KERNEL_W=kernel_shape[2], + STRIDE_D=stride[0], + STRIDE_H=stride[1], + STRIDE_W=stride[2], + PADDING_D=padding[0], + PADDING_H=padding[1], + PADDING_W=padding[2], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + if use_ck_conv_template(layout): + CKGroupedConvFwdTemplate.add_ck_conv_choices( + choices, + layout, + input_nodes=(x, weight) + ((bias,) if bias is not None else tuple()), + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=ndim, + ) + return autotune_select_algorithm("convolution", choices, args, layout) + + +@register_lowering(aten._convolution) +def _convolution( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32, +): + return convolution( + x, weight, bias, stride, padding, dilation, transposed, output_padding, groups + ) + + +def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): + assert fx_node.target is torch.ops.aten.convolution.default + if V.graph.layout_opt: + return args, kwargs + else: + return constrain_to_fx_strides(fx_node, *args, **kwargs) + + +add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/custom_op.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/custom_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a641ce83b17eade82a85cd10962dc377dab7e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/custom_op.py @@ -0,0 +1,537 @@ +# Owner(s): ["module: inductor"] + +import functools +import logging +from collections.abc import Callable +from typing import Any, Optional, Union + +import torch +from torch._inductor.codegen.subgraph import SubgraphTemplate +from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox +from torch._inductor.lowering import lowerings, validate_ir +from torch._inductor.select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, +) +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +def _detect_collective_ops(choices: list) -> bool: + """ + Detect if choices contain collective operations. + """ + from torch._inductor.utils import is_collective_op + + for choice in choices: + if not hasattr(choice, "gm") or choice.gm is None: + continue + + for node in choice.gm.graph.nodes: + if node.op == "call_function" and node.target is not None: + op_name = str(node.target) + + if is_collective_op(op_name) or is_collective_op( + f"torch.ops.{op_name}" + ): + return True + + return False + + +class CustomOpConfig: + """Config for custom op autotuning. + + Specifies optional decomposition function with parameter values. + Each config creates exactly one variant. + + Args: + decomposition: Optional functions to autotune. If not provided, default will be used. + **params: Parameters passed to the function + + Examples: + CustomOpConfig(attention_impl, head_dim=32, method='chunked') + CustomOpConfig(head_dim=32, method='chunked') + """ + + def __init__( + self, + decomposition: Optional[Callable[..., Any]] = None, + **params: Any, + ): + if decomposition is not None and not callable(decomposition): + raise TypeError( + f"decomposition must be callable, got {type(decomposition)}" + ) + + self.decomposition = decomposition + self.params = params + + def get_decomposition( + self, default_impl: Optional[Callable[..., Any]] = None + ) -> Callable[..., Any]: + """Return the decomposition function for this config. + When decomposition is not specified, return the default implementation. + """ + if self.decomposition is not None: + return self.decomposition + + if default_impl is not None and callable(default_impl): + return default_impl + + raise TypeError( + "No decomposition specified in config and no default implementation provided. " + "Please provide a decomposition function in CustomOpConfig." + ) + + def __repr__(self) -> str: + decomp_name = self.decomposition.__name__ if self.decomposition else "default" + if self.params: + params_str = ", ".join(f"{k}={v}" for k, v in self.params.items()) + return f"CustomOpConfig({decomp_name}, {params_str})" + return f"CustomOpConfig({decomp_name})" + + +__all__ = [ + "autotune_custom_op", + "register_custom_op_autotuning", + "CustomOpConfig", +] + + +def _extract_tensor_inputs( + args: tuple[Any, ...], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + """Extract tensor inputs from mixed args/kwargs. + Separates tensors (for autotuning input_nodes) from non-tensor parameters. + Non-tensor kwargs are later functools.partial'd into decomposition functions. + + Args: + args: Positional arguments (mix of tensors and scalars) + kwargs: Keyword arguments (mix of tensors and scalars) + + Returns: + Tuple of (tensor_inputs_list, non_tensor_kwargs) + """ + tensor_inputs = [] + non_tensor_kwargs = {} + + # Process args and kwargs: separate tensor inputs and non tensor args + for i, arg in enumerate(args): + if isinstance(arg, (TensorBox, Buffer)): + tensor_inputs.append(arg) + else: + # Add non-tensor positional args to kwargs with generated names + non_tensor_kwargs[f"arg_{i}"] = arg + + for key, value in kwargs.items(): + if isinstance(value, (TensorBox, Buffer)): + tensor_inputs.append(value) + else: + non_tensor_kwargs[key] = value + + return tensor_inputs, non_tensor_kwargs + + +def _merge_config_and_runtime_kwargs( + config_params: dict[str, Any], + runtime_kwargs: dict[str, Any], +) -> dict[str, Any]: + """Merge config parameters with runtime kwargs. Runtime kwargs take precedence. + If there are conflicts, log a warning and use runtime value. + + Args: + config_params: Parameters from CustomOpConfig + runtime_kwargs: Runtime non-tensor kwargs from _extract_tensor_inputs + + Returns: + Merged kwargs dictionary with runtime values taking precedence + """ + merged_kwargs = config_params.copy() + + # Check for conflicts and let runtime kwargs dominate + conflicts = OrderedSet(config_params.keys()).intersection(runtime_kwargs.keys()) + + for key in conflicts: + log.warning( + "Parameter '%s' specified both in CustomOpConfig (%s) " + "and at runtime (%s). Using runtime value.", + key, + config_params[key], + runtime_kwargs[key], + ) + + # Runtime kwargs override config params + merged_kwargs.update(runtime_kwargs) + + return merged_kwargs + + +def _adapt_user_input_gen_fns( + inputs: list[Any], + arg_names: list[str], + user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]], +) -> dict[int, Callable[[Any], torch.Tensor]]: + """Convert user input generators from name-based to index-based format. + Inductor autotune's input_gen_fns expects index of arg_names as key. + + Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes. + """ + + name_to_index = {name: i for i, name in enumerate(arg_names)} + index_based_fns = {} + + for name, gen_fn in user_input_gen_fns.items(): + if name in name_to_index: + index_based_fns[name_to_index[name]] = gen_fn + else: + log.warning( + "Unknown argument name '%s' in input_gen_fns. " + "Available argument names: %s", + name, + list(name_to_index.keys()), + ) + + def create_internal_input_gen_fn( + user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str + ) -> Callable[[Any], torch.Tensor]: + """Create internal input generator that converts IR buffer to user's fake tensor.""" + + def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor: + fake_tensor = ir_node_to_tensor(ir_buffer) + assert fake_tensor is not None, "ir_node_to_tensor returned None" + return user_function(fake_tensor) + + return internal_input_gen_fn + + return { + i: create_internal_input_gen_fn( + user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}" + ) + for i, user_gen_fn in index_based_fns.items() + if i < len(inputs) + } + + +def _create_fallback_choice( + name: str, + default_impl: Callable[..., Any], + fake_output: torch.Tensor, + kwargs: dict[str, Any], +) -> ExternKernelChoice: + """Create fallback choice for default implementation.""" + + def fallback_wrapper(*args: Any) -> Any: + return default_impl(*args, **kwargs) + + return ExternKernelChoice( + kernel=fallback_wrapper, + name=f"{name}_fallback_default", + has_out_variant=False, + op_overload=default_impl, + use_fallback_kernel=True, + ) + + +def autotune_custom_op( + name: str, + decompositions: list[Callable[..., Any]], + inputs: list[Any], + non_tensor_args: list[dict[str, Any]], + op_overload: torch._ops.OpOverload, + user_input_gen_fns: Optional[ + dict[str, Callable[[torch.Tensor], torch.Tensor]] + ] = None, +) -> Union[TensorBox, Any]: + """Autotune custom operations by comparing multiple decomposition implementations. + + Currently supports SINGLE OUTPUT custom ops only. + TODO: Add support for multiple output custom ops (tuple/list returns). + + This function generates multiple implementation choices for a custom operation and + uses Inductor's autotuning system to select the best performing variant at runtime. + After selecting the best choice, applies inline fusion if the winning choice has a graph. + + Args: + name: Unique identifier for the autotuning operation + decompositions: List of alternative implementation functions to benchmark + inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects) + non_tensor_args: List of kwargs dicts, paired with corresponding decompositions arg + op_overload: OpOverload of the custom op, used as fallback implementation + user_input_gen_fns: Optional custom input generators for benchmarking. + Maps input indices to functions that take fake tensors + and return real tensors for performance measurement. + + Returns: + IR node representing the optimized operation result + + Raises: + TypeError: If decompositions is not a list/tuple + RuntimeError: If no inputs or no valid choices generated + """ + if not isinstance(decompositions, (list, tuple)): + raise TypeError( + f"decompositions must be a list or tuple of callables, got {type(decompositions)}" + ) + + if not inputs: + raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning") + + if len(decompositions) != len(non_tensor_args): + raise ValueError( + f"decompositions and non_tensor_args must have same length, " + f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs" + ) + + template = SubgraphTemplate(name=name) + choices = template.generate_custom_op_choices( + name=name, + # pyrefly: ignore [bad-argument-type] + decompositions=decompositions, + input_nodes=list(inputs), + non_tensor_args=non_tensor_args, + ) + + # Add default implementation as fallback + if op_overload and hasattr(op_overload, "_op"): + fallback_name = f"{name}_fallback_default" + from torch._inductor.select_algorithm import extern_kernels + + # Skip if extern_kernel already registered to avoid duplicate registration error + if not hasattr(extern_kernels, fallback_name): + with V.fake_mode: + fake_inputs = [ir_node_to_tensor(inp) for inp in inputs] + fallback_kwargs = non_tensor_args[0] if non_tensor_args else {} + fake_output = op_overload(*fake_inputs, **fallback_kwargs) + + fallback_choice = _create_fallback_choice( + name, op_overload, fake_output, fallback_kwargs + ) + fallback_choice.maybe_append_choice( + choices=choices, + input_nodes=list(inputs), + layout=FixedLayout( + device=fake_output.device, + dtype=fake_output.dtype, + size=fake_output.shape, + stride=fake_output.stride(), + ), + ) + + if not choices: + raise RuntimeError(f"No valid choices generated for {name}") + + # Convert user input generation functions to internal format + input_gen_fns = {} + if user_input_gen_fns: + import inspect + + arg_names = ( + list(inspect.signature(decompositions[0]).parameters.keys()) + if decompositions + else [] + ) + input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) + + is_collective = _detect_collective_ops(choices) + + # Run autotuning and get both result and winning choice + selected_result, winning_choice = autotune_select_algorithm( + name=name, + choices=choices, + input_nodes=list(inputs), + layout=choices[0].layout, + input_gen_fns=input_gen_fns, + return_choice=True, + is_collective=is_collective, + ) + + # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) + if winning_choice.gm is not None: + log.debug( + "Inlining winning choice: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes + + return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name) + + log.debug( + "Winning choice does not support inlining: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + return selected_result + + +def _generate_dynamic_configs( + tensor_inputs: list[Buffer], + config_generator: Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]], + default_impl: Callable[..., Any], + operation_name: str, +) -> list[CustomOpConfig]: + """Generate configs dynamically based on input tensors at lowering time.""" + import inspect + + sig = inspect.signature(default_impl) + param_names = list(sig.parameters.keys()) + + with V.fake_mode: + fake_tensors = [ir_node_to_tensor(inp) for inp in tensor_inputs] + + fake_tensors_dict = dict(zip(param_names, fake_tensors)) + + configs = config_generator(fake_tensors_dict) + + if not isinstance(configs, (list, tuple)): + raise TypeError( + f"config_generator must return a list or tuple of CustomOpConfig, " + f"got {type(configs)}" + ) + if not configs: + raise ValueError(f"config_generator returned empty list for {operation_name}. ") + + return list(configs) + + +def register_custom_op_autotuning( + custom_op: torch._library.custom_ops.CustomOpDef, + configs: Optional[Union[list[CustomOpConfig], list[Callable[..., Any]]]] = None, + config_generator: Optional[ + Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]] + ] = None, + name: Optional[str] = None, + input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None, +) -> None: + """Register custom op for autotuning with custom_op configs where each config + specifies a decomposition implementation function with its parameter values. + + Args: + custom_op: Custom operation (decorated function from @torch.library.custom_op) + configs: List of CustomOpConfig objects for static inputs. Mutually exclusive with config_generator. + config_generator: Dynamic config generator function that takes a dict mapping + parameter names to fake tensors, and returns list[CustomOpConfig] + based on input tensor properties. Mutually exclusive with configs. + name: Operation name (default: "{op_name}_autotuned") + input_gen_fns: Custom input generators for benchmarking + + Examples: + # Static configs + @torch.library.custom_op("mylib::attention", mutates_args=()) + def my_attention(query, key, value, head_dim=32): + ... + + register_custom_op_autotuning( + my_attention, + configs=[ + CustomOpConfig(attention_impl, head_dim=32, method='chunked'), + CustomOpConfig(attention_impl, head_dim=64, method='tiled'), + CustomOpConfig(head_dim=128), # No decomposition specified, use default + ], + input_gen_fns={ + "query": lambda fake: torch.randn_like(fake, device='cuda'), + "key": lambda fake: torch.randn_like(fake, device='cuda'), + "value": lambda fake: torch.randn_like(fake, device='cuda'), + }, + ) + + # Dynamic config generation based on input tensor properties + def generate_k_split_configs(fake_tensors: dict[str, torch.Tensor]) -> list[CustomOpConfig]: + # Access tensor shapes, dtypes, devices, etc. + m, k = fake_tensors["mat1"].shape + _, n = fake_tensors["mat2"].shape + k_splits = ... # compute possible k splits based on tensor properties + return [CustomOpConfig(k_splits=k) for k in k_splits] + + register_custom_op_autotuning( + matmul_decomposeK_op, + config_generator=generate_k_split_configs, + input_gen_fns={...}, + ) + """ + from torch._library.custom_ops import CustomOpDef + + if not isinstance(custom_op, CustomOpDef): + raise TypeError( + f"custom_op must be a CustomOpDef (decorated function from @torch.library.custom_op), " + f"got {type(custom_op)}." + ) + + # Validate configs and config_generator are mutually exclusive + if configs is not None and config_generator is not None: + raise ValueError( + "Cannot specify both 'configs' and 'config_generator'. " + "Use 'config_generator' for shape-dependent configs." + ) + + if configs is None and config_generator is None: + raise ValueError("Must specify either 'configs' or 'config_generator'") + + op_overload = custom_op._opoverload + default_impl = custom_op._init_fn + + # Process and validate static configs at registration time + static_configs = None + if configs is not None: + if not isinstance(configs, (list, tuple)): + raise TypeError(f"configs must be a list or tuple, got {type(configs)}") + + static_configs = [] + for cfg in configs: + if isinstance(cfg, CustomOpConfig): + static_configs.append(cfg) + else: + raise TypeError( + f"Each config must be a CustomOpConfig object, got {type(cfg)}" + ) + + if not static_configs: + raise ValueError("At least one config must be provided") + + if name is None: + name = f"{op_overload._name}_autotuned" + + @functools.wraps(op_overload) + def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: + """Inductor lowering function that replaces custom op calls with autotuned versions.""" + # Extract tensor inputs and non-tensor parameters (runtime kwargs) + tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs) + + # Get configs: either generate dynamically or use static configs + if config_generator is not None: + configs_to_use = _generate_dynamic_configs( + tensor_inputs, config_generator, default_impl, name + ) + else: + assert static_configs is not None + configs_to_use = static_configs + + # Prepare decompositions and kwargs for autotuning + decompositions = [] + non_tensor_args = [] + + for cfg in configs_to_use: + decomp = cfg.get_decomposition(default_impl=default_impl) + decompositions.append(decomp) + + # Merge config params with runtime kwargs (runtime takes precedence) + merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs) + non_tensor_args.append(merged_kwargs) + + result = autotune_custom_op( + name=name, + decompositions=decompositions, + inputs=tensor_inputs, + non_tensor_args=non_tensor_args, + op_overload=op_overload, + user_input_gen_fns=input_gen_fns, + ) + + validate_ir(result) + return result + + lowerings[op_overload] = autotuning_lowering diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6471278229a1c6f5107f0d6dfd41756bfe43cd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de524f8345c635e18f8748f01372141d935ff1a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..945268637aa508c3dd1f401024e6403a98de94c8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_cpu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_cpu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa49225ab80267a7ceaa4211781d2033fd4bcd97 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_cpu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_decoding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_decoding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29e25d94b2f2fcab8b3ce8c3112f1b566f08ab75 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_decoding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_flash_attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_flash_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..870b1ef3dfde52e669ea5e44d9fa3165a5628315 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_flash_attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/common.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/common.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..f95beb14612924cfe2877710a4fe99c2e6c15084 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/common.py.jinja @@ -0,0 +1,204 @@ + + +# Common Imports +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( + desc_k, + [kv_base_offset, 0], + ) + {%- else %} + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + {%- endif %} + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_base_offset, 0], + ) + {%- else %} + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..0e83853fa5de8e2ae1a66726bf6e67b1a45fd212 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja @@ -0,0 +1,76 @@ +{% if NEEDS_BLOCK_MASK %} +{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} +{% else %} +{{def_kernel("Q", "K", "V", "LOGSUMEXP")}} +{% endif %} + from flash_attn.cute.interface import _flash_attn_fwd + from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch + + # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D) + q_transposed = Q.transpose(1, 2) + k_transposed = K.transpose(1, 2) + v_transposed = V.transpose(1, 2) + + @cute.jit + def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + {{unpack_buffers("aux_tensors", indent_width=8)}} + {{ modification( + subgraph_number=0, + output_name="tSrS_ssa", + score="tSrS_ssa", + b="b_idx", + h="h_idx", + m="q_idx", + n="kv_idx", + out="tSrS_ssa" + ) | indent_except_first(2) }} + return tSrS_ssa + {{ set_cute_hash("score_mod", "score") }} + + # (B,M,H,D) -> (B,H,M,D) + output = {{get_output()}} + output_transposed = output.transpose(1, 2) + + {% if NEEDS_BLOCK_MASK %} + @cute.jit + def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors): + {{unpack_buffers("aux_tensors", indent_width=8)}} + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + b="b_idx", + h="h_idx", + m="q_idx", + n="kv_idx", + ) | indent_except_first(2) }} + return mask_mod_output + {{ set_cute_hash("mask_mod", "mask") }} + block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX) + {% else %} + block_sparse_tensors = None + mask_mod = None + {% endif %} + + # Collect any additional tensor buffers that were added during modifications + {% set tensor_buffers = get_tensor_buffers() -%} + {% if tensor_buffers -%} + buffers = [{% for buffer in tensor_buffers %}{{buffer}}{% if not loop.last %}, {% endif %}{% endfor %}] + buffers = list(buffers) + {% else -%} + buffers = None + {% endif -%} + + # Out and LSE filled inplace + _flash_attn_fwd( + q_transposed, + k_transposed, + v_transposed, + softmax_scale={{SM_SCALE}}, + return_lse=True, + score_mod=score_mod, + mask_mod=mask_mod, + out=output_transposed, + lse=LOGSUMEXP, + block_sparse_tensors=block_sparse_tensors, + aux_tensors=buffers + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..2831ba6af5b60ef469d122a4886dbed9b557ede3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja @@ -0,0 +1,28 @@ +{{def_kernel("Q", "K", "V", "OUT", "D_OUT", "LSE", "DK", "DV")}} + from flash_attn.cute.interface import _flash_attn_bwd + + q_transposed = Q.transpose(1, 2) + k_transposed = K.transpose(1, 2) + v_transposed = V.transpose(1, 2) + out_transposed = OUT.transpose(1, 2) + d_out_transposed = D_OUT.transpose(1, 2) + + dq_transposed, dk_transposed, dv_transposed = _flash_attn_bwd( + q_transposed, + k_transposed, + v_transposed, + out_transposed, + d_out_transposed, + LSE, + softmax_scale={{SM_SCALE}}, + ) + + dq = dq_transposed.transpose(1, 2) + dk = dk_transposed.transpose(1, 2) + dv = dv_transposed.transpose(1, 2) + + dq_out = {{get_output()}} + {# TODO: add out support to flash #} + dq_out.copy_(dq) + DK.copy_(dk) + DV.copy_(dv) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..b92ea6c14a33fe11bb0c9bd485ca15be60317ded --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja @@ -0,0 +1,215 @@ +{{def_kernel("Q", "K", "V", "LSE", "MAX", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN, QK_HEAD_DIM], + strides=[stride_qm, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask", val_shape=("BLOCK_M", "V_HEAD_DIM_ROUNDED"))}} + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..3467d84475d0ce70fba937df75f7d93caabfb5dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja @@ -0,0 +1,620 @@ +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8, val_shape=("BLOCK_N1", "QK_HEAD_DIM_ROUNDED"))}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + + {# Note: Selective masking DQ + We load elements beyond KV_LEN w/ zero, some score mods may convert this elements to NaN + Example: lambda x, *_: 1 / score, this NaN would propagate regardless of other masking + We only need to do this on the m1 dim since these elements take part in the final reduction + for DQ #} + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + {# See Note Selective masking DQ #} + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + {# Note: Selective masking DK/DV + We load elements beyond Q_LEN w/ zero, some score mods may convert this elements to NaN + Example: lambda x, *_: 1 / score, this NaN would propagate regardless of other masking + We only need to do this on the m1 dim since these elements take part in the final reduction + for DK/DV #} + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + {# See Note: Selective masking DK/DV#} + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..e5f0e118c5631404b0f1fda5086e2447f64e4fbe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja @@ -0,0 +1,242 @@ + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = {{size("KV_IDX", -1)}} + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask", val_shape=("GQA_SHARED_HEADS", "BLOCK_M_PER_HQ", "V_HEAD_DIM"))}} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/utilities.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/utilities.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..0c40b43277f8ae2da748487803758ff46c338ced --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/utilities.py.jinja @@ -0,0 +1,59 @@ + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..5b57c458f46e62862ea997c5a0fad0d4729d65d5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py @@ -0,0 +1,1102 @@ +# mypy: allow-untyped-defs +import functools +import logging +from typing import Any, Optional, Union + +import torch +from torch._dynamo.utils import counters +from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + context_add_strides, + context_add_using_tf32, + mm_operations, +) +from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.remote_gemm_autotune_cache import gen_best_config +from torch._inductor.virtualized import ops, V +from torch.fx.experimental.proxy_tensor import make_fx +from torch.nn.functional import ScalingType # type: ignore[attr-defined] +from torch.torch_version import TorchVersion + +from .. import config as inductor_config, distributed_autotune +from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate +from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate +from ..ir import Buffer, ChoiceCaller, is_triton, Layout +from ..kernel_inputs import MMKernelInputs +from ..lowering import ( + lowerings, + make_pointwise, + make_reduction, + register_lowering, + transform_args, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + KernelTemplate, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + _use_cutlass_for_op, + ceildiv, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_ck_tile_gemm_template, + use_cpp_gemm_template, + use_cutlass_template, + use_decompose_k_choice, + use_triton_blackwell_tma_template, + use_triton_template, + use_triton_tma_template, +) +from .mm_common import ( + _is_static_problem, + load_kernel_template, + mm_args, + mm_grid, + persistent_mm_grid, + use_native_matmul, +) + + +try: + import triton + + triton_version = TorchVersion(triton.__version__) + has_triton = True +except ImportError: + triton_version = TorchVersion("0.0.0") + has_triton = False + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# We define each template kernel in a separate file which is the name of the input to load_kernel_template +# (e.g. triton_mm for templates/triton_mm.py.jinja). +# If you are adding a new template, please follow that pattern and add a new file with your implementation in the templates folder. +mm_template = TritonTemplate( + name="mm", + grid=mm_grid, + source=load_kernel_template("triton_mm") + if (torch.version.hip is None) or triton_version >= "3.3.0" + # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 + # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. + # See more details in https://github.com/pytorch/pytorch/pull/146293 + else load_kernel_template("triton_mm_rocm"), + cache_codegen_enabled_for_template=True, + prologue_loads_all_inputs=True, +) + +persistent_tma_mm_template = TritonTemplate( + name="mm_persistent_tma", + grid=persistent_mm_grid, + source=load_kernel_template("triton_persistent_tma_mm"), +) + + +scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate( + name="scaled_mm_device_tma_epilogue_scaling", + grid=persistent_mm_grid, + source=load_kernel_template("triton_epilogue_scaled_mm"), +) + + +scaled_mm_device_tma_main_loop_scaling_template = TritonTemplate( + name="scaled_mm_device_tma_main_loop_scaling", + grid=persistent_mm_grid, + source=load_kernel_template("triton_main_loop_scaled_mm"), +) + +blackwell_ws_persistent_device_tma_mm_template = TritonTemplate( + name="blackwell_ws_persistent_device_tma", + grid=persistent_mm_grid, + source=load_kernel_template("triton_blackwell_ws_persistent_device_tma_mm"), +) + + +# prevent duplication registration of extern functions +@functools.cache +def lazy_register_extern_choice(fn): + return ExternKernelChoice(fn) + + +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out", op_overload=aten.mm.out) +aten_mm_dtype = ExternKernelChoice( + torch.mm, + "at::_mm_dtype_out_cuda", + name="mm_dtype", + op_overload=aten.mm.dtype_out, +) + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.out +) + +aten__int_mm = ExternKernelChoice( + torch._int_mm, "at::_int_mm_out", op_overload=aten._int_mm.out +) + +aten__sparse_semi_structured_mm = ExternKernelChoice( + torch._sparse_semi_structured_mm, + "at::_sparse_semi_structured_mm", + has_out_variant=False, + op_overload=aten._sparse_semi_structured_mm.default, +) + +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out +) + + +def _is_int8_mat(mat): + return mat.get_dtype() in (torch.int8, torch.uint8) + + +def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): + """ + Giving torch.addmm a 1D tensor calls a different (faster) cublasLt + kernel under the hood. There are a few shapes where this is slower, + but they are rare. + """ + if (inp.stride(0) == 0 and inp.size(0) != 0) or inp.size(0) == 1: + return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) + return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) + + +def check_supported_striding(mat_a, mat_b) -> None: + def is_row_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[1], 1) + + def is_col_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[0], 1) + + def has_zero_dim(size) -> bool: + return bool( + V.graph.sizevars.statically_known_equals(size[0], 0) + or V.graph.sizevars.statically_known_equals(size[1], 0) + ) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +aten_bias_addmm = ExternKernelChoice(bias_addmm, None) + + +def decomposeK(a, b, k_splits): + m = a.shape[0] + n = b.shape[1] + k = a.shape[1] + + k_parts = k // k_splits + B = k_splits + a_reshaped = torch.permute(a.reshape(m, B, k_parts), (1, 0, 2)) + b_reshaped = b.reshape(B, k_parts, n) + result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32) + reduced_buf = torch.sum(result, 0) + return reduced_buf.to(a.dtype) + + +class DecomposeKSugraphTemplate(SubgraphTemplate): + def __init__(self): + super().__init__( + name="decompose_k", + ) + + def generate( # type: ignore[override] + self, + input_nodes: list[Buffer], + layout: Layout, + k_split: int, + ) -> SubgraphChoiceCaller: + from torch._dispatch.python import enable_python_dispatcher + + from ..decomposition import select_decomp_table + + name = f"decompose_k_mm_{k_split}_split" + description = f"{k_split=}" + + with enable_python_dispatcher(): + decompositions = select_decomp_table() + fn = make_fx( + functools.partial(decomposeK, k_splits=k_split), + decompositions, + ) + + return super().generate( + name=name, + input_nodes=input_nodes, + layout=layout, + make_fx_graph=fn, + description=description, + ) + + +decompose_k_subgraph_template = DecomposeKSugraphTemplate() + + +class ContiguousTemplate(SubgraphTemplate): + def __init__(self, name: str, description: str, fn: Any): + self.name = name + self.description = description + self.fn = fn + super().__init__( + name=name, + ) + + def generate( # type: ignore[override] + self, + input_nodes: list[Buffer], + layout: Layout, + ) -> SubgraphChoiceCaller: + from torch._dispatch.python import enable_python_dispatcher + + from ..decomposition import select_decomp_table + + with enable_python_dispatcher(): + decompositions = select_decomp_table() + fn = make_fx( + self.fn, + decompositions, + ) + + return super().generate( + name=self.name, + input_nodes=input_nodes, + layout=layout, + make_fx_graph=fn, + description=self.description, + ) + + +def contiguous_mm(a, b): + return torch.mm(a, b.contiguous()) + + +def contiguous_addmm(inp, a, b): + return torch.addmm(inp, a, b.contiguous()) + + +mm_contiguous_subgraph_template = ContiguousTemplate( + "contiguous_mm", "contiguous mm", contiguous_mm +) +addmm_contiguous_subgraph_template = ContiguousTemplate( + "contiguous_addmm", "contiguous addmm", contiguous_addmm +) + + +@register_lowering(aten.mm, type_promotion_kind=None) +def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None): + """ + Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) + """ + if out_dtype is not None: + input_dtype = mat1.get_dtype() + torch._check( + mat2.get_dtype() == input_dtype, + lambda: "input dtypes must be the same", + ) + torch._check( + mat1.get_device().type == "cuda", + lambda: "out_dtype is only supported for CUDA", + ) + torch._check( + out_dtype == input_dtype + or ( + out_dtype == torch.float32 + and input_dtype in (torch.float16, torch.bfloat16) + ), + lambda: "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs", + ) + + # Lower matmul-related operations (e.g., torch.matmul / torch.bmm / torch.addmm) + # into native matmul IR using `ops.dot`. When we see a matmul pattern + # (C[y, x] = A[y, r] * B[r, x]), the core idea is to emulate a broadcasted + # multiply followed by a sum. + # + # For example, given `C = torch.matmul(A, B)`, this can be rewritten as: + # + # Prod = A.unsqueeze(-1) * B.unsqueeze(0) + # C = Prod.sum(dim=1) + # + # Instead of explicitly using `ops.mul` and `ops.reduction("sum")`, we lower + # these into `ops.dot` (pointwise) and `ops.reduction("dot")`. These IR nodes + # are semantically equivalent to the `ops.mul` + `ops.reduction("sum")` + # combination, but are lowered to `tl.dot` during the code generation phase. + if use_native_matmul(mat1, mat2): + mat1 = lowerings[aten.unsqueeze](mat1, -1) + mat2 = lowerings[aten.unsqueeze](mat2, 0) + args, kwargs = transform_args( + args=[mat1, mat2], + kwargs={}, + broadcast=True, + type_promotion_kind=None, + convert_input_to_bool=False, + ) # Handles broadcasting the arguments + + if inductor_config.triton.codegen_upcast_to_fp32 and mat1.dtype in [ + torch.float16, + torch.bfloat16, + ]: + + def _to_dtype(x): + return ops.to_dtype(x, mat1.dtype, use_compute_types=False) + + args = [make_pointwise(_to_dtype)(x) for x in args] + + mul_pointwise = make_pointwise(ops.dot)(*args) + dot_reduction = make_reduction("dot")(mul_pointwise, 1) + + return dot_reduction + + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=out_dtype + ) + static_shape, is_nonzero = _is_static_problem(layout) + name = "mm" + + # Create MMKernelInputs for standard MM at the top + kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype) + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + choices: list[ChoiceCaller] = [] + static_shape, is_nonzero = _is_static_problem(layout) + + aten_handler: ExternKernelChoice = aten_mm + aten_extra_kwargs: dict[str, Any] = {} + if out_dtype is not None: + aten_handler = aten_mm_dtype + aten_extra_kwargs = {"out_dtype": out_dtype} + + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + kwarg_overrides: dict[str, dict[str, Any]] = {} + if use_aten_gemm_kernels(): + templates_to_use.append(aten_handler) + if aten_extra_kwargs: + kwarg_overrides[aten_handler.uid] = aten_extra_kwargs + + if ( + out_dtype is None + and is_nonzero + and use_triton_template(layout, check_max_autotune=True) + ): + if use_decompose_k_choice(m, n, k): + templates_to_use.append(decompose_k_subgraph_template) + # Triton Templates typically perform very poorly for large K. + # Its highly unlikely that if we want to use decompose_k, then + # Triton will ever win. + # + # To be conservative we increase this threshold for N/M by 2. + is_exhaustive = inductor_config.max_autotune_gemm_search_space == "exhaustive" + if is_exhaustive or not use_decompose_k_choice(m, n, k, threshold_multiple=2): + templates_to_use.append(mm_template) + + if use_triton_tma_template(mat1, mat2, output_layout=layout): + templates_to_use.append(persistent_tma_mm_template) + + if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): + templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) + + templates_to_use.append(mm_contiguous_subgraph_template) + + choices.extend( + V.choices.get_template_configs( + kernel_inputs, + templates_to_use, + "mm", + kwarg_overrides=kwarg_overrides, + ) + ) + + if ( + out_dtype is None + and is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("mm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, kernel_inputs.nodes() + ) + + if out_dtype is None and is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) + if out_dtype is None and is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): + CKTileGemmTemplate.add_choices(choices, layout, kernel_inputs.nodes()) + + if out_dtype is None and use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + kernel_inputs.nodes(), + ) + + input_nodes = [mat1, mat2] + if ( + out_dtype is None + and is_nonzero + and use_triton_template(layout) + and torch._inductor.config.run_autoheuristic(name) + and is_triton(mat1) + ): + always_included = [] + if use_aten_gemm_kernels(): + always_included.append("extern_mm") + num_choices_before_extra_configs = len(choices) + choices.extend( + V.choices.get_template_configs( + # TODO(coconutruben): remove once we deprecate ah + # mm-extra is a hack to keep the ah functionality alive + # while we transition to the unified kwargs retrieval + kernel_inputs, + [mm_template], + "mm-ah", + ) + ) + + # using AutoHeuristic for ranking + ah_choices = mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + mm_operations(), + None, + top_k=10, + always_included=always_included, + ) + if not torch._inductor.config.collect_autoheuristic(name): + # if we are collecting data, we do not want to modify choices + if ah_choices is not None and len(ah_choices) > 0: + # the order in which autoheuristic returns choices is not the same as + # as the order of choices, which affects things like epilogue fusion. + # once epilogue fusion benchmarks choices in sorted order, I think we can + # just use the order returned by autoheuristic + choices = [choice for choice in choices if choice in ah_choices] + else: + choices = choices[:num_choices_before_extra_configs] + + if out_dtype is None: + for k in inductor_config.external_matmul: + choices.append( + lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout) + ) + + best_config_future = None + if out_dtype is None and torch._inductor.config.remote_gemm_autotune_cache: + # Purposely not awaiting the future here - this kicks off the best config lookup at lowering time + # The future will be awaited at scheduling time in select_algorithm.py + best_config_future = gen_best_config(mat1, mat2) + + if box := distributed_autotune.maybe_autotune_remote( + name, choices, kernel_inputs.nodes(), layout + ): + return box + + return autotune_select_algorithm( + name, + choices, + kernel_inputs.nodes(), + layout, + best_config_future=best_config_future, + ) + + +@register_lowering(aten._int_mm, type_promotion_kind=None) +def tuned_int_mm(mat1, mat2, *, layout=None): + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + name = "int_mm" + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + static_shape, is_nonzero = _is_static_problem(layout) + use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) + choices: list[ChoiceCaller] = [] + + # Create MMKernelInputs for Int MM + kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=torch.int32) + + # Collect all templates for unified call + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + if use_aten_gemm_kernels(): + templates_to_use.append(aten__int_mm) + + if is_nonzero and use_triton_template( + layout, enable_int32=True, check_max_autotune=False + ): + templates_to_use.append(mm_template) + + # Single unified call for all templates + choices.extend( + V.choices.get_template_configs(kernel_inputs, templates_to_use, name) + ) + + if use_cutlass and _use_cutlass_for_op(name): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True + ) + + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + + +@register_lowering(aten.addmm, type_promotion_kind=None) +def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + """ + Lowering for autotuning aten.addmm with different backends (Aten, Triton, CUTLASS, etc.) + """ + if use_native_matmul(mat1, mat2): + if beta == 0: + arg1 = 0 + else: + arg1 = lowerings[aten.mul](beta, inp) + + if alpha == 0: + arg2 = 0 + else: + arg2 = lowerings[aten.mul](alpha, lowerings[aten.mm](mat1, mat2)) + + return lowerings[aten.add](arg1, arg2) + + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + static_shape, is_nonzero = _is_static_problem(layout) + name = "addmm" + + # Create MMKernelInputs for AddMM at the top + kernel_inputs = MMKernelInputs( + [inp_expanded, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) + ) + choices: list[ChoiceCaller] = [] + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + if (not is_nonzero) or ( + not (inductor_config.max_autotune or inductor_config.max_autotune_gemm) + ): + # TODO(coconutruben): combine this with the main flow of addmm through + # a subgraph or something as inp vs inp_expanded causes some slight numeric + # differences + kernel_inputs = MMKernelInputs( + [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) + ) + choices.extend( + V.choices.get_template_configs( + kernel_inputs, + [aten_addmm], + name, + ) + ) + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + + # Collect all templates for unified call + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + if use_aten_gemm_kernels(): + templates_to_use.extend([aten_bias_addmm, aten_addmm]) + + if is_nonzero and use_triton_template(layout, check_max_autotune=False): + templates_to_use.append(mm_template) + + if use_triton_tma_template(mat1, mat2, output_layout=layout): + templates_to_use.append(persistent_tma_mm_template) + + if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): + templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) + + templates_to_use.append(addmm_contiguous_subgraph_template) + + # Single unified call for all templates + choices.extend( + V.choices.get_template_configs(kernel_inputs, templates_to_use, name) + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op(name) + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + # reorder here because CUTLASS expects (x, w, bias) but torch + # is bias, x, w + kernel_inputs.nodes(reorder=[1, 2, 0]), + alpha=alpha, + beta=beta, + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices( + choices, + layout, + # reorder here because CK expects (x, w, bias) but torch + # is bias, x, w + kernel_inputs.nodes(reorder=[1, 2, 0]), + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + ) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + kernel_inputs.nodes(), + alpha=alpha, + beta=beta, + has_bias=True, + ) + + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + + +@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None) +def tuned_sparse_semi_structured_mm( + mat1, mat1_meta, mat2, *, out_dtype=None, layout=None +): + from torch._inductor.select_algorithm import realize_inputs + + # TODO(coconturuben): support V.choices.get_mm_configs for sparse_semi_structured_mm + mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2) + m1, k1 = mat1.get_size() + m2, _ = mat1_meta.get_size() + k2, n = mat2.get_size() + m = V.graph.sizevars.check_equals_and_simplify(m1, m2) + k = V.graph.sizevars.check_equals_and_simplify(2 * k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + layout = FixedLayout( + mat2.get_device(), + out_dtype if out_dtype else mat2.get_dtype(), + [m, n], + [n, 1], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + choices = ( + [ + aten__sparse_semi_structured_mm.bind( + (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + m * n != 0 + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("sparse_semi_structured_mm") + ): + CUTLASS2xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True + ) + + return autotune_select_algorithm( + "sparse_semi_structured_mm", choices, (mat1, mat1_meta, mat2), layout + ) + + +scaling_pairs = [ + (ScalingType.TensorWise, ScalingType.TensorWise), + (ScalingType.RowWise, ScalingType.RowWise), + (ScalingType.BlockWise1x128, ScalingType.BlockWise128x128), + (ScalingType.BlockWise1x128, ScalingType.BlockWise1x128), +] + + +epilogue_scaling_types = [ScalingType.TensorWise, ScalingType.RowWise] +main_loop_scaling_types = [ScalingType.BlockWise1x128, ScalingType.BlockWise128x128] + + +def _is_tensorwise_scaling(sz: Any) -> bool: + return (len(sz) == 0) or all( + V.graph.sizevars.statically_known_equals(d, 1) for d in sz + ) + + +def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool: + idx = 0 if transpose else -1 + return V.graph.sizevars.statically_known_equals(sz[idx], 1) + + +def _is_blockwise1xTILESIZE_scaling( + sz: Any, tensor_sz: Any, tile_size: int, transpose: bool +) -> bool: + lhs = 1 if transpose else 0 + rhs = 0 if transpose else 1 + return V.graph.sizevars.statically_known_equals( + sz[lhs], tensor_sz[lhs] + ) and V.graph.sizevars.statically_known_equals( + sz[rhs], ceildiv(tensor_sz[rhs], tile_size) + ) + + +def _is_blockwise128x128_scaling(sz: Any, tensor_sz: Any) -> bool: + return V.graph.sizevars.statically_known_equals( + sz[0], ceildiv(tensor_sz[0], 128) + ) and V.graph.sizevars.statically_known_equals(sz[1], ceildiv(tensor_sz[1], 128)) + + +def is_desired_scaling( + t: Any, + scale_size: torch.Tensor, + scaling_type: ScalingType, + transpose: bool = False, +) -> bool: + match scaling_type: + case ScalingType.TensorWise: + return _is_tensorwise_scaling(scale_size) + case ScalingType.RowWise: + return _is_rowwise_scaling(scale_size, transpose) + case ScalingType.BlockWise1x128: + return _is_blockwise1xTILESIZE_scaling( + scale_size, t.get_size(), 128, transpose + ) + case ScalingType.BlockWise128x128: + return _is_blockwise128x128_scaling(scale_size, t.get_size()) + case _: + raise AssertionError(f"Unsupported scaling type {scaling_type}") + + +def get_tile_size(scale_option) -> int: + match scale_option: + case ScalingType.BlockWise128x128: + return 128 + case ScalingType.BlockWise1x128: + return 128 + case _: + raise AssertionError( + f"Unsupported scaling type {scale_option} in get_tile_size" + ) + + +def get_scaling_options( + mat_a: Any, + mat_b: Any, + scale_a_size: torch.Tensor, + scale_b_size: torch.Tensor, +) -> tuple[ScalingType, ScalingType]: + for scale_option_a, scale_option_b in scaling_pairs: + if is_desired_scaling( + mat_a, scale_a_size, scale_option_a + ) and is_desired_scaling(mat_b, scale_b_size, scale_option_b, transpose=True): + return scale_option_a, scale_option_b + + raise AssertionError( + f"Inductor Triton does not support scale_a.shape = {scale_a_size}, scale_b.shape = {scale_b_size}" + ) # verify that shapes are supported by at least one existing pairing + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + layout=None, +): + """ + Performs an optimized matrix multiplication where scaling factors are applied + to the inputs and/or output. + + Args: + mat1 (Tensor): First input matrix + mat2 (Tensor): Second input matrix + scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting) + scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting) + bias (Tensor, optional): Optional bias tensor to add to the result + layout: Layout hint for optimization + + Returns: + Tensor: The result of the scaled matrix multiplication + """ + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + name = "scaled_mm" + check_supported_striding(mat_a, mat_b) + + scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) + + input_nodes: list[Any] + + if not bias: + input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real] + else: + bias_real = realize_inputs(bias) + input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real] + + # Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1) + kernel_inputs = MMKernelInputs( + input_nodes, mat1_idx=0, mat2_idx=1, out_dtype=out_dtype + ) + + choices: list[ChoiceCaller] = [] + + # Collect all templates for unified call + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + kwarg_overrides = {} + + if use_aten_gemm_kernels(): + templates_to_use.append(aten__fp8_mm) + kwarg_overrides[aten__fp8_mm.uid] = dict( + out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + _, is_nonzero = _is_static_problem(layout) + + if ( + # We dont have triton lowerings for the MX variants yet + scale_a.dtype == torch.float32 + and is_nonzero + and use_triton_template(layout, enable_float8=True, check_max_autotune=False) + ): + overriders = dict(USE_FAST_ACCUM=use_fast_accum) + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exist + if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias: + scale_a_size, scale_b_size = scale_a_real.shape, scale_b_real.shape + + scale_option_a, scale_option_b = get_scaling_options( + mat_a, mat_b, scale_a_size, scale_b_size + ) + overriders["SCALE_RECIPE_A"] = scale_option_a.value + overriders["SCALE_RECIPE_B"] = scale_option_b.value + + if ( + scale_option_a in epilogue_scaling_types + and scale_option_b in epilogue_scaling_types + ): + templates_to_use.append(scaled_mm_device_tma_epilogue_scaling_template) + kwarg_overrides[scaled_mm_device_tma_epilogue_scaling_template.uid] = ( + overriders + ) + elif ( + scale_option_a in main_loop_scaling_types + and scale_option_b in main_loop_scaling_types + ): + overriders["TILE_SIZE_A"] = get_tile_size(scale_option_a) + overriders["TILE_SIZE_B"] = get_tile_size(scale_option_b) + + templates_to_use.append(scaled_mm_device_tma_main_loop_scaling_template) + kwarg_overrides[scaled_mm_device_tma_main_loop_scaling_template.uid] = ( + overriders + ) + else: + raise AssertionError( + "Inductor Triton does not support scaling options that are present " + + "in both epilogue scaling and main loop scaling" + ) + + if ( + use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout) + and not bias + ): + templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) + kwarg_overrides[blackwell_ws_persistent_device_tma_mm_template.uid] = ( + overriders + ) + + templates_to_use.append(mm_template) + kwarg_overrides[mm_template.uid] = overriders + + # Single unified call for all templates + choices.extend( + V.choices.get_template_configs( + kernel_inputs, + templates_to_use, + name, + kwarg_overrides=kwarg_overrides, + ) + ) + + # Early return for MX variants + if scale_a.dtype != torch.float32: + return autotune_select_algorithm(name, choices, input_nodes, layout) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op(name) + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + kernel_inputs.nodes(), # type: ignore[arg-type] + use_fast_accum=use_fast_accum, # type: ignore[arg-type] + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) + + return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + + +@functools.cache +def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: + props = torch.cuda.get_device_properties(index or 0) + return props.major <= 7 + + +def dims_are_int(dims): + return all(isinstance(dim, int) for dim in dims) + + +def mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + ops, + precondition, + top_k: Optional[int] = None, + always_included=None, +): + m, n, k = get_size_hints(mat1, mat2, m, n, k) + if not dims_are_int([m, n, k]): + return None + mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2) + + def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride): + context = AHContext() + context.add_feature("m", m) + context.add_feature("k", k) + context.add_feature("n", n) + context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True) + context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True) + context_add_strides(context, "mat1", mat1_stride) + context_add_strides(context, "mat2", mat2_stride) + context.add_feature( + "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True + ) + context.add_feature( + "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True + ) + if name == "mm": + context_add_using_tf32(context, mat1.layout.dtype) + return context + + def fallback(): + return None + + context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride) + autoheuristic = AutoHeuristicSelectAlgorithm( + fallback=fallback, + choices=choices, + input_nodes=input_nodes, + context=context, + name=name, + augment_context=ops, + precondition=precondition, + ) + + if top_k is not None: + # TODO: is there a cleaner way to ensure aten.mm is always included? + return autoheuristic.get_top_k_choices_caller( + top_k, always_included=always_included + ) + + return autoheuristic.get_choice_caller() + + +def get_size_hints(mat1, mat2, m, n, k): + if not isinstance(m, int) or not isinstance(k, int): + (m, k) = V.graph.sizevars.size_hints( + mat1.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + + if not isinstance(n, int) or not isinstance(k, int): + (k, n) = V.graph.sizevars.size_hints( + mat2.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + return m, n, k + + +def get_size_hints_strides(mat1, mat2): + mat1_stride = mat1.layout.stride + mat2_stride = mat2.layout.stride + strides = [mat1_stride, mat2_stride] + strides_hints = [] + for stride in strides: + if not isinstance(stride, int): + stride = V.graph.sizevars.size_hints( + stride, + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + strides_hints.append(stride) + return strides_hints[0], strides_hints[1] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..eb22b95af2afcef65cb4876d9c9685633e5bde70 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py @@ -0,0 +1,263 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from functools import partial +from pathlib import Path +from typing import Any + +import torch +from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn +from torch._inductor.utils import get_current_backend, sympy_product +from torch._inductor.virtualized import V +from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + +from .. import config +from ..codegen.wrapper import PythonWrapperCodegen +from ..ir import _IntLike, Layout, TensorBox +from ..utils import load_template + + +log = logging.getLogger(__name__) + + +@SymbolicGridFn +def mm_grid(m, n, meta, *, cdiv): + """ + The CUDA grid size for matmul triton templates. + """ + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) + + +@SymbolicGridFn +def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min): + """Defines the grid for persistent kernels.""" + return ( + min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])), + 1, + 1, + ) + + +@SymbolicGridFn +def persistent_grouped_mm_grid(*args): + meta = args[-1] + return (meta["NUM_SMS"], 1, 1) + + +def acc_type(dtype): + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +def mm_args( + mat1, + mat2, + *others, + layout=None, + out_dtype=None, + use_4x2_dim=False, + mat2_transposed=False, +): + """ + Common arg processing for mm,bmm,addmm,etc + """ + mat1, mat2 = realize_inputs(mat1, mat2) + *b1, m, k1 = mat1.get_size() + if mat2_transposed: + *b2, n, k2 = mat2.get_size() + else: + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.check_equals_and_simplify(k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + layout = FixedLayout( + mat1.get_device(), + out_dtype, + [*b, m, n], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + from ..lowering import expand + + others = [realize_inputs(expand(x, layout.size)) for x in others] + + return [m, n, k, layout, mat1, mat2, *others] + + +def addmm_epilogue(dtype, alpha, beta): + def epilogue(acc, bias): + if alpha != 1: + acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) + if beta != 1: + bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) + return V.ops.add(acc, bias) + + return epilogue + + +def scale_mm_epilogue(): + """ + Create an epilogue function that applies scaling to matrix multiplication result + using the given scale factors. + + Args: + dtype: The data type of the output + scale_a: Scale factor for matrix A + scale_b: Scale factor for matrix B + + Returns: + Epilogue function that takes the accumulator and applies scaling + """ + + def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): + # The epilogue function receives the accumulator (result of mat1 @ mat2) + # and applies the scaling factors + # In the original scaled_mm, we use inverse scales, so we multiply by them + mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) + mul_acc = V.ops.mul(acc, mul_scales) + if bias is not None: + return V.ops.add(mul_acc, bias) + else: + return mul_acc + + return epilogue + + +def use_native_matmul(mat1, mat2): + if not config.triton.native_matmul: + return False + + # If tma matmul is on, don't do native matmul + if ( + config.triton.enable_persistent_tma_matmul + and torch.utils._triton.has_triton_tma_device() + ): + raise AssertionError("native matmul doesn't support tma codegen yet") + + # Currently only enable native matmul for default indexing + # TODO : support block ptr + if config.triton.use_block_ptr: + raise AssertionError("native matmul doesn't support block_ptr codegen yet") + + # Currently only enable native matmul for triton on GPU. + device_type = mat1.get_device().type + if not ( + device_type in ("cuda", "xpu") and get_current_backend(device_type) == "triton" + ): + return False + + # Currently, tl.dot only supports following dtypes + triton_supported_dtype = [ + torch.int8, + torch.uint8, + torch.float16, + torch.bfloat16, + torch.float32, + ] + if mat1.dtype not in triton_supported_dtype: + return False + if mat2.dtype not in triton_supported_dtype: + return False + + # (..., M, K) @ (..., K, N) + m, k, n = mat1.get_size()[-2], mat1.get_size()[-1], mat2.get_size()[-1] + + # If the shape has unbacked symbols, don't do native matmul. + # This is related to the behavior of statically_known_multiple_of on unbacked symints. + # Since statically_known_multiple_of just returns False for unbacked symbols + # due to the expensive cost, codegen fails when there is a unbacked symbol. + # In particular, it fails at _split_iteration_ranges in codegen/simd.py. + # See this : https://github.com/pytorch/pytorch/pull/131649 + if any(map(has_free_unbacked_symbols, [m, k, n])): + return False + + # Consider the shape (m,k,n) > 1 + # TODO : support when size = 1 + if ( + V.graph.sizevars.statically_known_leq(m, 1) + or V.graph.sizevars.statically_known_leq(k, 1) + or V.graph.sizevars.statically_known_leq(n, 1) + ): + return False + + return True + + +def _is_static_problem(layout: Layout) -> tuple[bool, bool]: + """ + Check if input tensors and output layout have static shapes and non-zero sizes. + + Args: + layout: Output layout object with a 'size' attribute. + + Returns: + Tuple[bool, bool]: (is_static, is_nonzero) + is_static: True if all shapes are statically known + is_nonzero: True if all dimensions are non-zero + """ + static_shape = True + static_size = PythonWrapperCodegen.statically_known_list_of_ints_or_none( + layout.size + ) + if static_size is None: + nonzero = True + for s in layout.size: + sz = PythonWrapperCodegen.statically_known_int_or_none(s) + if sz is not None and sz == 0: + nonzero = False + break + return False, nonzero + numel = 1 + for dim in static_size: + numel *= dim + nonzero = numel > 0 + return static_shape, nonzero + + +def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: + def is_row_major(stride: Sequence[_IntLike]) -> bool: + return stride[-1] == 1 + + def is_col_major(stride: Sequence[_IntLike]) -> bool: + return stride[-2] == 1 + + def has_zero_dim(size: Sequence[_IntLike]) -> bool: + return bool(size[0] == 0 or size[1] == 0) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: + """ + Checking if the batch stride is the largest in the stride. + """ + sizes = [mat1.get_size(), mat2.get_size(), layout.size] + strides = [mat1.get_stride(), mat2.get_stride(), layout.stride] + for size, stride in zip(sizes, strides): + assert len(size) == len(stride) == 3, "Expect 3D tensors" + if stride[0] != 0 and stride[0] != sympy_product(size[1:]): + return False + + return True + + +_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_grouped.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..35ee6a541c15079d32f8291ee57e7e3909956cb4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_grouped.py @@ -0,0 +1,891 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import asdict, dataclass +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters +from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate +from torch._inductor.runtime.triton_compat import tl +from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs +from torch._inductor.virtualized import V +from torch.utils._triton import has_triton + +from ..ir import ChoiceCaller, Layout, TensorBox +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + get_gpu_shared_memory, + get_num_sms, + has_free_symbols, + use_aten_gemm_kernels, + use_blackwell_cutedsl_grouped_mm, + use_triton_template, +) +from .mm_common import ( + _is_static_problem, + check_supported_striding, + load_kernel_template, + persistent_grouped_mm_grid, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass +class Config: + kwargs: dict[str, int] + num_stages: int + num_warps: int + + +_NV_CONFIGS = [ + Config( + { + "BLOCK_M": block_size_m, + "BLOCK_N": block_size_n, + "BLOCK_K": block_size_k, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m in [16, 32, 64, 128] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] +] + + +def grouped_mm_configs(): + return _NV_CONFIGS + + +def early_config_prune(g, m, dtsize, configs, named_args): + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + config.num_warps, + getattr(config, "num_consumer_groups", 0), + ) + + # 1. Prune NV configs depending on g and m. + if not has_free_symbols((g, m)): + a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"] + m_avg = m // g if a_is_2d and not b_is_2d else m + if m_avg <= 16: + if BLOCK_M > 32: + continue + elif m_avg <= 32: + if BLOCK_M > 64: + continue + elif m_avg <= 64: + if BLOCK_M <= 16: + continue + else: + if BLOCK_M <= 32: + continue + + # 2. make sure we have enough smem + max_shared_memory = get_gpu_shared_memory() + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + use_warp_specialization = num_consumer_groups >= 1 + + # 3. make sure we can partition for ws + if use_warp_specialization: + if num_warps != 4: + continue + + # "tritongpu-warp-spec-data-partition" + m_slice = BLOCK_M // num_consumer_groups + n_slice = BLOCK_N // num_consumer_groups + if m_slice < 64 and n_slice < 256: + continue + + pruned_configs.append(config) + + return pruned_configs + + +triton_grouped_mm_source = r""" +{% macro assign_maybe_constexpr(name, value_expr) -%} + {%- set value_str = value_expr | string -%} + {%- set sentinel = "__NOT_A_NUMBER__" -%} + {%- set as_int = value_str | int(default=sentinel) -%} + {%- set as_float = value_str | float(default=sentinel) -%} + {%- set is_constexpr = (as_int != sentinel) or (as_float != sentinel) -%} + {{ name }}{{ ": tl.constexpr" if is_constexpr else "" }} = {{ value_expr }} +{%- endmacro %} + +import triton +import triton.language as tl + +@triton.jit +def do_tma_loads( + g, a_desc, b_desc, m_offset, n_offset, k_offset, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): +{%- if A_IS_2D %} +{%- if A_IS_K_MAJOR %} + a = a_desc.load([m_offset, k_offset]) +{%- else %} + a = a_desc.load([k_offset, m_offset]) +{%- endif %} +{%- else %} +{%- if A_IS_K_MAJOR %} + a = a_desc.load([g, m_offset, k_offset]).reshape(BLOCK_M, BLOCK_K) +{%- else %} + a = a_desc.load([g, k_offset, m_offset]).reshape(BLOCK_K, BLOCK_M) +{%- endif %} +{%- endif %} +{%- if B_IS_2D %} +{%- if B_IS_K_MAJOR %} + b = b_desc.load([n_offset, k_offset]) +{%- else %} + b = b_desc.load([k_offset, n_offset]) +{%- endif %} +{%- else %} +{%- if B_IS_K_MAJOR %} + b = b_desc.load([g, n_offset, k_offset]).reshape(BLOCK_N, BLOCK_K) +{%- else %} + b = b_desc.load([g, k_offset, n_offset]).reshape(BLOCK_K, BLOCK_N) +{%- endif %} +{%- endif %} + + return (a, b) + + +@triton.jit +def do_mma(a, b, accumulator): +{%- if USE_FAST_ACCUM %} +{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %} + accumulator = tl.dot(a, b.T, accumulator) +{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %} + accumulator = tl.dot(a, b, accumulator) +{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %} + accumulator = tl.dot(a.T, b.T, accumulator) +{%- else %} + accumulator = tl.dot(a.T, b, accumulator) +{%- endif %} +{%- else %} +{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %} + accumulator += tl.dot(a, b.T) +{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %} + accumulator += tl.dot(a, b) +{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %} + accumulator += tl.dot(a.T, b.T) +{%- else %} + accumulator += tl.dot(a.T, b) +{%- endif %} +{%- endif %} + + return accumulator + + +{%- if SCALED %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}} +{%- endif %} +{%- else %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr")}} +{%- endif %} +{%- endif %} + tidx = tl.program_id(0).to(INDEX_DTYPE) + +{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %} +{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %} +{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %} + +{%- if A_IS_2D %} +{%- if B_IS_2D %} + {{ assign_maybe_constexpr("G", size("offsets_ptr", 0)) }} +{%- else %} + {{ assign_maybe_constexpr("G", size("b_ptr", 0)) }} +{%- endif %} +{%- else %} +{%- if B_IS_2D %} + {{ assign_maybe_constexpr("G", size("a_ptr", 0)) }} +{%- else %} + {{ assign_maybe_constexpr("G", size("a_ptr", 0)) }} +{%- endif %} +{%- endif %} + + # the b_ptr tensor is given with its last two dims transposed, revert here + + {{ assign_maybe_constexpr("M", size("a_ptr", -2)) }} + {{ assign_maybe_constexpr("N", size("b_ptr", -1)) }} + {{ assign_maybe_constexpr("K", size("a_ptr", -1)) }} + + {{ assign_maybe_constexpr("A_STRIDE_M", stride("a_ptr", -2)) }} + {{ assign_maybe_constexpr("A_STRIDE_K", stride("a_ptr", -1)) }} +{%- if not A_IS_2D %} + {{ assign_maybe_constexpr("A_STRIDE_G", stride("a_ptr", 0)) }} +{%- if SCALED %} + {{ assign_maybe_constexpr("SCALE_A_STRIDE_G", stride("scale_a_ptr", 0)) }} +{%- endif %} +{%- endif %} + {{ assign_maybe_constexpr("B_STRIDE_N", stride("b_ptr", -1)) }} + {{ assign_maybe_constexpr("B_STRIDE_K", stride("b_ptr", -2)) }} +{%- if not B_IS_2D %} + {{ assign_maybe_constexpr("B_STRIDE_G", stride("b_ptr", 0)) }} + B_STRIDE_G = {{stride("b_ptr", 0)}} +{%- if SCALED %} + {{ assign_maybe_constexpr("SCALE_B_STRIDE_G", stride("scale_b_ptr", 0)) }} +{%- endif %} +{%- endif %} + +{%- if USE_TMA_LOAD %} +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + a_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + a_desc = tl.make_tensor_descriptor( +{%- endif %} + a_ptr, +{%- if A_IS_2D %} +{%- if A_IS_K_MAJOR %} + shape=[M, K], + strides=[A_STRIDE_M, A_STRIDE_K], + block_shape=[BLOCK_M, BLOCK_K], +{%- else %} + shape=[K, M], + strides=[A_STRIDE_K, A_STRIDE_M], + block_shape=[BLOCK_K, BLOCK_M], +{%- endif %} +{%- else %} +{%- if A_IS_K_MAJOR %} + shape=[G, M, K], + strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K], + block_shape=[1, BLOCK_M, BLOCK_K], +{%- else %} + shape=[G, K, M], + strides=[A_STRIDE_G, A_STRIDE_K, A_STRIDE_M], + block_shape=[1, BLOCK_K, BLOCK_M], +{%- endif %} +{%- endif %} + ) + +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + b_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + b_desc = tl.make_tensor_descriptor( +{%- endif %} + b_ptr, +{%- if B_IS_2D %} +{%- if B_IS_K_MAJOR %} + shape=[N, K], + strides=[B_STRIDE_N, B_STRIDE_K], + block_shape=[BLOCK_N, BLOCK_K], +{%- else %} + shape=[K, N], + strides=[B_STRIDE_K, B_STRIDE_N], + block_shape=[BLOCK_K, BLOCK_N], +{%- endif %} +{%- else %} +{%- if B_IS_K_MAJOR %} + shape=[G, N, K], + strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K], + block_shape=[1, BLOCK_N, BLOCK_K], +{%- else %} + shape=[G, K, N], + strides=[B_STRIDE_G, B_STRIDE_K, B_STRIDE_N], + block_shape=[1, BLOCK_K, BLOCK_N], +{%- endif %} +{%- endif %} + ) +{%- endif %} + +{%- if M_IS_VARYING %} + m_end_offset = 0 +{%- endif %} +{%- if N_IS_VARYING %} + n_end_offset = 0 +{%- endif %} +{%- if K_IS_VARYING %} + k_end_offset = 0 +{%- endif %} + iterated_tiles = 0 + for g in tl.range(G): +{%- if M_IS_VARYING %} + # Move across groups + m_start_offset = m_end_offset + m_end_offset = tl.load(offsets_ptr + g) + m_size = m_end_offset - m_start_offset +{%- if SCALED %} + m_scale_start_offset = m_start_offset +{%- endif %} +{%- else %} + m_start_offset = 0 + m_size = M +{%- if SCALED %} + m_scale_start_offset = g * M +{%- endif %} +{%- endif %} + +{%- if N_IS_VARYING %} + # Move across groups + n_start_offset = n_end_offset + n_end_offset = tl.load(offsets_ptr + g) + n_size = n_end_offset - n_start_offset +{%- if SCALED %} + n_scale_start_offset = n_start_offset +{%- endif %} +{%- else %} + n_start_offset = 0 + n_size = N +{%- if SCALED %} + n_scale_start_offset = g * N +{%- endif %} +{%- endif %} + + if m_size > 0 and n_size > 0: +{%- if K_IS_VARYING %} + # Move across groups + k_start_offset = k_end_offset + k_end_offset = tl.load(offsets_ptr + g) + k_size = k_end_offset - k_start_offset +{%- else %} + k_start_offset = 0 + k_size = K +{%- endif %} + + num_m_tiles = tl.cdiv(m_size, BLOCK_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_N) + num_tiles = num_m_tiles * num_n_tiles + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. + tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{%- if USE_TMA_LOAD %} + m_tile_offset = tile_m_idx * BLOCK_M + n_tile_offset = tile_n_idx * BLOCK_N + m_offset = (m_start_offset + m_tile_offset).to(tl.int32) + n_offset = (n_start_offset + n_tile_offset).to(tl.int32) + + k_block_offset = 0 + for k in range(k_size // BLOCK_K): + k_offset = k_start_offset + k_block_offset + a, b = do_tma_loads( + g, a_desc, b_desc, m_offset, n_offset, k_offset, + BLOCK_M, BLOCK_N, BLOCK_K + ) + accumulator = do_mma(a, b, accumulator) + k_block_offset += BLOCK_K + + if k_size % BLOCK_K != 0: + k_offset = k_start_offset + k_block_offset + a, b = do_tma_loads( + g, a_desc, b_desc, m_offset, n_offset, k_offset, + BLOCK_M, BLOCK_N, BLOCK_K + ) +{%- if K_IS_VARYING %} + group_offs = k_block_offset + tl.arange(0, BLOCK_K) + k_mask = group_offs < k_size +{%- if A_IS_K_MAJOR %} + a = tl.where(k_mask[None, :], a, 0) +{%- else %} + a = tl.where(k_mask[:, None], a, 0) +{%- endif %} +{%- if B_IS_K_MAJOR %} + b = tl.where(k_mask[None, :], b, 0) +{%- else %} + b = tl.where(k_mask[:, None], b, 0) +{%- endif %} +{%- endif %} + accumulator = do_mma(a, b, accumulator) +{%- else %} + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + for k_block_offset in range(0, k_size, BLOCK_K): + block_offs_k = k_block_offset + tl.arange(0, BLOCK_K) + offs_k = block_offs_k + k_start_offset + a_ptrs = ( + a_ptr +{%- if not A_IS_2D %} + + g * A_STRIDE_G +{%- endif %} + + (m_start_offset + offs_am[:, None]) * A_STRIDE_M + + offs_k[None, :] * A_STRIDE_K + ) + b_ptrs = ( + b_ptr +{%- if not B_IS_2D %} + + g * B_STRIDE_G +{%- endif %} + + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N + + offs_k[None, :] * B_STRIDE_K + ) + a_mask = (offs_am[:, None] < m_size) & (block_offs_k[None, :] < k_size) + b_mask = (offs_bn[:, None] < n_size) & (block_offs_k[None, :] < k_size) + a = tl.load(a_ptrs, mask=a_mask, other=tl.zeros((), dtype=a_ptrs.dtype.element_ty)) + b = tl.load(b_ptrs, mask=b_mask, other=tl.zeros((), dtype=b_ptrs.dtype.element_ty)) +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} + a_ptrs += BLOCK_K + b_ptrs += BLOCK_K +{%- endif %} + + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) +{%- if SCALED %} + scale_a = tl.load( + scale_a_ptr +{%- if A_IS_2D %} + + m_scale_start_offset +{%- else %} + + g * SCALE_A_STRIDE_G +{%- endif %} + + offs_am[:, None], + mask=offs_am[:, None] < m_size, + other=tl.zeros((), dtype=scale_a_ptr.dtype.element_ty), + ) + scale_b = tl.load( + scale_b_ptr +{%- if B_IS_2D %} + + n_scale_start_offset +{%- else %} + + g * SCALE_B_STRIDE_G +{%- endif %} + + offs_bn[None, :], + mask=offs_bn[None, :] < n_size, + other=tl.zeros((), dtype=scale_b_ptr.dtype.element_ty), + ) + c = accumulator.to(tl.float32) * scale_a * scale_b +{%- else %} + c = accumulator.to(tl.float32) +{%- endif %} + +{%- if M_IS_VARYING %} + idx_m = (m_start_offset + offs_am[:, None]) +{%- else %} + idx_m = offs_am[:, None] +{%- endif %} +{%- if N_IS_VARYING %} + idx_n = (n_start_offset + offs_bn[None, :]) +{%- else %} + idx_n = offs_bn[None, :] +{%- endif %} + mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size) +{%- if M_IS_VARYING or N_IS_VARYING %} + {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}} +{%- else %} + {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}} +{%- endif %} + tidx += NUM_SMS + + iterated_tiles += num_tiles +""" + + +triton_grouped_mm_template = TritonTemplate( + name="grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + +triton_scaled_grouped_mm_template = TritonTemplate( + name="scaled_grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + +cutedsl_grouped_mm_template = CuteDSLTemplate( + name="grouped_gemm_cutedsl", + source=load_kernel_template("cutedsl_mm_grouped"), +) + + +def grouped_mm_args( + mat1: TensorBox, + mat2: TensorBox, + offs: Optional[TensorBox], + layout=None, + out_dtype=None, +): + mat1, mat2 = realize_inputs(mat1, mat2) + if offs is not None: + realize_inputs(offs) + mat1_size = mat1.get_size() + mat2_size = mat2.get_size() + + m1dim, m2dim = len(mat1_size), len(mat2_size) + + assert m1dim == 2 or m1dim == 3 + assert m2dim == 2 or m2dim == 3 + + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + alignment = 16 // out_dtype.itemsize + + if m1dim == 2: + if m2dim == 2: + assert offs is not None + out_size = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + else: + out_size = [mat1_size[0], mat2_size[-1]] + else: + if m2dim == 2: + out_size = [mat1_size[1], mat2_size[1]] + else: + out_size = [mat1_size[0], mat1_size[1], mat2_size[-1]] + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if len(out_size) == 2: + out_stride = [size_padded, 1] + else: + out_stride = [out_size[1] * size_padded, size_padded, 1] + + layout = FixedLayout( + mat1.get_device(), + out_dtype, + out_size, + out_stride, + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + return (mat1_size, mat2_size, layout, mat1, mat2, offs) + + +aten__grouped_mm = ExternKernelChoice( + torch._grouped_mm, + "at::_grouped_mm", + op_overload=aten._grouped_mm.default, + has_out_variant=False, +) + + +aten__scaled_grouped_mm = ExternKernelChoice( + torch._scaled_grouped_mm, + "at::_scaled_grouped_mm", + op_overload=aten._scaled_grouped_mm.default, + has_out_variant=False, +) + + +def can_use_triton_kernel( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox], + bias: Optional[TensorBox], + scale_result: Optional[TensorBox], +) -> bool: + if not ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + return False + if not has_triton(): + return False + + # The _grouped_mm()/_scaled_grouped_mm() operator do not support + # bias nor scale_result yet. + if bias is not None: + return False + if scale_result is not None: + return False + + if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2: + return offs is not None + else: + return offs is None + + +def create_offsets(x, m1_size, m2_size, offs_size): + m1_is_2d = len(m1_size) == 2 + m2_is_2d = len(m2_size) == 2 + if m1_is_2d: + if m2_is_2d: + k = V.graph.sizevars.size_hint(m1_size[1]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = k / noffs + return torch.linspace( + step, k, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + + else: + m = V.graph.sizevars.size_hint(m1_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = m / noffs + return torch.linspace( + step, m, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + if m2_is_2d: + n = V.graph.sizevars.size_hint(m2_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = n / noffs + return torch.linspace( + step, n, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + return None + + +def _tuned_grouped_mm_common( + operator_name: str, + algorithm_name: str, + extern_kernel_choice: ExternKernelChoice, + kernel_template: TritonTemplate, + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: Optional[TensorBox] = None, + scale_b: Optional[TensorBox] = None, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: Optional[bool] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + assert (scale_a is None) == (scale_b is None) + assert scale_result is None or scale_a is not None + + m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args( + mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype + ) + counters["aten_mm_info"][operator_name] += 1 + log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s" + log.info( + log_message, + m1_size, + m2_size, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + if scale_a is not None and scale_b is not None: + check_supported_striding(mat_a, mat_b) + + # workaround for Inductor not supporting optional tensor input arguments + input_nodes: list[Any] = [mat_a, mat_b] + if scale_a is not None: + input_nodes.append(realize_inputs(scale_a)) + if scale_b is not None: + input_nodes.append(realize_inputs(scale_b)) + if offs is not None: + input_nodes.append(realize_inputs(offs)) + + if use_fast_accum is None: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + ) + else: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + if use_fast_accum is None: + use_fast_accum = False + + choices: list[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + # Checking only for the equality of corresponding dims of + # multiplicands here, relying on meta function checks for + # everything else. + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + + if ( + is_nonzero + and use_triton_template(layout) + and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) + ): + scaled = scale_a is not None + + a_is_k_major = mat_a.get_stride()[-1] == 1 + b_is_k_major = mat_b.get_stride()[-2] == 1 + + triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor") + triton_has_experimental_make_tensor_descriptor = hasattr( + tl, "_experimental_make_tensor_descriptor" + ) + use_tma_load = ( + triton_has_make_tensor_descriptor + or triton_has_experimental_make_tensor_descriptor + ) + kwargs = { + "SCALED": scaled, + "A_IS_2D": a_is_2d, + "B_IS_2D": b_is_2d, + "A_IS_K_MAJOR": a_is_k_major, + "B_IS_K_MAJOR": b_is_k_major, + "USE_FAST_ACCUM": use_fast_accum, + "NUM_SMS": get_num_sms(), + "USE_TMA_LOAD": use_tma_load, + "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor, + } + + for config in early_config_prune( + g, m, mat_a.dtype.itemsize, grouped_mm_configs(), kwargs + ): + kernel_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + num_stages=config.num_stages, + num_warps=config.num_warps, + **kwargs, + **config.kwargs, + ) + + if use_blackwell_cutedsl_grouped_mm( + mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result + ): + for config in get_groupgemm_configs(): + kwargs = dict( + ACC_DTYPE="cutlass.Float32", + ) + + cutedsl_grouped_mm_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + **asdict(config), + ) + + input_gen_fns = { + 4: lambda x: create_offsets( + x, m1_size, m2_size, offs.get_size() if offs is not None else None + ), + } + return autotune_select_algorithm( + algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns + ) + + +@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) +def tuned_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._grouped_mm.default", + "grouped_mm", + aten__grouped_mm, + triton_grouped_mm_template, + mat_a, + mat_b, + None, + None, + offs, + bias, + None, + out_dtype, + None, + layout, + ) + + +@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None) +def tuned_scaled_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _scaled_grouped_mm() operator.""" + + # matching _scaled_grouped_mm_cuda Blas.cpp implementation + out_dtype = out_dtype or torch.bfloat16 + + return _tuned_grouped_mm_common( + "aten._scaled_grouped_mm.default", + "scaled_grouped_mm", + aten__scaled_grouped_mm, + triton_scaled_grouped_mm_template, + mat_a, + mat_b, + scale_a, + scale_b, + offs, + bias, + scale_result, + out_dtype, + use_fast_accum, + layout, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..aef8dfb2168f4e9f410310f898ff3ae08bae02ee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs + +import logging +from typing import TYPE_CHECKING, Union + +import torch + +from .. import config as inductor_config +from ..kernel_inputs import MMKernelInputs +from ..lowering import lowerings +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from ..virtualized import V +from .mm_common import mm_args, mm_grid + + +if TYPE_CHECKING: + from torch._inductor.ir import ChoiceCaller + from torch._inductor.select_algorithm import KernelTemplate + +log = logging.getLogger(__name__) + +aten = torch.ops.aten + +aten_mm_plus_mm = ExternKernelChoice( + torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm" +) + +mm_plus_mm_template = TritonTemplate( + name="mm_plus_mm", + grid=mm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C", "D")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K1 = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + # K2 = {{size("C", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + stride_cm = {{stride("C", 0)}} + stride_ck = {{stride("C", 1)}} + stride_dk = {{stride("D", 0)}} + stride_dn = {{stride("D", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1)) + and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + + if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1)) + and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck) + D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k1 in range(K1, 0, -BLOCK_K): + # First matmul with A @ B + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k1, other=0.) + b = tl.load(B, mask=rk[:, None] < k1, other=0.) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + for k2 in range(K1, 0, -BLOCK_K): + + # Second matmul with C @ D + if EVEN_K: + c = tl.load(C) + d = tl.load(D) + else: + c = tl.load(C, mask=rk[None, :] < k2, other=0.) + d = tl.load(D, mask=rk[:, None] < k2, other=0.) + acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + C += BLOCK_K * stride_ck + D += BLOCK_K * stride_dk + + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} +""", + cache_codegen_enabled_for_template=True, +) + + +def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): + """ + Computes mm(mat1, mat2) + mm(mat3, mat4) + """ + # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + + # Optimization is optional, because we can always just not do the fusion + if ( + m1 * n1 == 0 + or m2 * n2 == 0 + or not V.graph.sizevars.statically_known_list_equals( + mat1.get_size(), mat3.get_size() + ) + or not V.graph.sizevars.statically_known_list_equals( + mat2.get_size(), mat4.get_size() + ) + or inductor_config.triton.native_matmul + ): + # TODO(jansel): support different K values when this is fixed: + # https://github.com/triton-lang/triton/issues/967 + return lowerings[aten.add]( + lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) + ) + + # Create MMKernelInputs for MM Plus MM (matrices are at indices 0, 1 for first pair) + # Note: This is a special case with 4 matrices, but we use the first pair for M, N, K extraction + kernel_inputs = MMKernelInputs([mat1, mat2, mat3, mat4], mat1_idx=0, mat2_idx=1) + + assert layout1 == layout2 + # options to tune from + choices: list[ChoiceCaller] = [] + + # Collect all templates for unified call + templates_to_use: list[Union[ExternKernelChoice, KernelTemplate]] = [] + if use_aten_gemm_kernels(): + templates_to_use.append(aten_mm_plus_mm) + + if use_triton_template(layout1, check_max_autotune=False): + templates_to_use.append(mm_plus_mm_template) + + # Single unified call for all templates + choices.extend( + V.choices.get_template_configs(kernel_inputs, templates_to_use, "mm_plus_mm") + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, kernel_inputs.nodes(), layout1 + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73f847180a944d7368474f4d984c6511453634be Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/cutedsl_grouped_gemm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/cutedsl_grouped_gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75c4416bea45d8702d27d3dc5ee565e78833fc2c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/cutedsl_grouped_gemm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebb1d5618bfa5eb44cfa9c4bc3921e887f54374 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__init__.py @@ -0,0 +1,32 @@ +""" +Template lookup table system for PyTorch Inductor. + +This package provides functionality for: +- Loading pre-configured template choices from lookup tables +- Managing template configurations and choices + +All functionality is contained within the LookupTableChoices class. +You can customize any aspect by subclassing LookupTableChoices and overriding methods. + +Usage: + # Basic usage + choices = LookupTableChoices() + V.set_choices_handler(choices) + + # Custom usage + class MyCustomChoices(LookupTableChoices): + def _get_lookup_table(self): + return my_custom_table + + def make_lookup_key(self, kernel_inputs, op_name, include_device=False): + return f"custom_{op_name}_{hash(str(kernel_inputs))}" + + V.set_choices_handler(MyCustomChoices()) +""" + +from .choices import LookupTableChoices + + +__all__ = [ + "LookupTableChoices", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/choices.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..46e54180114aba4b59f0832b6cd64a408df521c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/choices.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import copy +import logging +from functools import lru_cache +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.choices import InductorChoices +from torch._inductor.kernel_template_choice import KernelTemplateChoice +from torch._inductor.template_heuristics.params import DictKernelTemplateParams + + +log = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from collections.abc import Generator + + from torch._inductor.codegen.common import KernelTemplate + from torch._inductor.kernel_inputs import KernelInputs + from torch._inductor.select_algorithm import ExternKernelChoice + + +class LookupTableChoices(InductorChoices): + """ + InductorChoices subclass that uses lookup table when available, otherwise falls back to parent. + All lookup functionality is contained within this class and can be customized by overriding methods. + """ + + def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]: + """ + Get the template lookup table from config. + Override this method to use custom lookup table sources (database, API, etc.). + """ + if not torch.cuda.is_available() or config.lookup_table.table is None: + return {} + return config.lookup_table.table + + @staticmethod + @lru_cache + def _get_device_key(device: torch.device) -> Optional[str]: + """ + Generate a device key for lookup table indexing. + For CPU devices, returns None. + For CUDA devices, returns the props.gcnArchName string. + """ + if device.type != "cuda": + # only cuda devices are supported, this indicates that the system is not in use + # for this device + return None + + # Get CUDA device properties + props = torch.cuda.get_device_properties(device.index) + return props.gcnArchName + + @staticmethod + def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str: + """ + Generate a key based on input node properties and scalars. + The key includes dtype, size, and stride information for each input node, + plus scalar values as key=value pairs separated by & signs. + """ + # Get node information using existing methods + dtypes = kernel_inputs.dtypes() + shapes = kernel_inputs.shapes_hinted() + strides = kernel_inputs.strides_hinted() + + # Create tuple of (dtype, shape_list, stride_list) for each node + node_info = tuple( + (dtype, list(shape), list(stride)) + for dtype, shape, stride in zip(dtypes, shapes, strides) + ) + + # Create base key from node information + fmt_key = str(node_info) + # Add scalar information if present + if kernel_inputs._scalars: + # Sort scalars for consistent key generation and join with & + scalar_parts = [ + f"{key}={value}" + for key, value in sorted(kernel_inputs._scalars.items()) + ] + scalars_key = "&".join(scalar_parts) + fmt_key = f"{fmt_key}+{scalars_key}" + + return f"{fmt_key}" + + def make_lookup_key( + self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False + ) -> Optional[str]: + """ + Create a flattened lookup key from kernel inputs and operation name. + Override this method to customize key generation. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + include_device: Whether to include device key in the generated key + + Returns: + A string key combining device (optional), operation, and input information + """ + device = kernel_inputs.device() + dev_key = self._get_device_key(device) + if dev_key is None: + # The system does not run when dev_key is None, regardless of + # whether include_device is True or False + return None + if not include_device: + dev_key = None + + # Generate input key using our staticmethod + input_key = self._generate_kernel_inputs_key(kernel_inputs) + + # Create the flattened lookup key + if dev_key is not None: + key_parts = [dev_key, input_key, op_name] + else: + key_parts = [input_key, op_name] + + return "+".join(key_parts) + + def make_lookup_key_variants( + self, kernel_inputs: KernelInputs, op_name: str + ) -> tuple[Optional[str], Optional[str]]: + """ + Generate both device-specific and device-agnostic lookup keys. + Override this method to customize key variant generation. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + + Returns: + Tuple of (device_key, device_agnostic_key). Either may be None if generation fails. + """ + device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True) + device_agnostic_key = self.make_lookup_key( + kernel_inputs, op_name, include_device=False + ) + + return device_key, device_agnostic_key + + @staticmethod + def _entry_is_valid( + cfg: dict[str, Any], + template_id: str, + template_hash_map: Optional[dict[str, Optional[str]]], + ) -> bool: + """ + Check if a config entry is valid based on template hash validation. + + Args: + cfg: Configuration dictionary that may contain a template_hash field + template_id: The template identifier + template_hash_map: Optional mapping from template_uid to src_hash for validation + + Returns: + True if the config is valid and should be kept, False if it should be filtered out + """ + # If hash checking is disabled or no hash map provided, keep the config + if not config.lookup_table.check_src_hash or not template_hash_map: + return True + + template_hash = template_hash_map.get(template_id) + config_hash = cfg.get("template_hash") + + # Both hashes present - validate they match + if template_hash is not None and config_hash is not None: + if config_hash != template_hash: + log.warning( + "Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. " + "Template code may have changed. Filtering out config: %s", + template_id, + config_hash, + template_hash, + {k: v for k, v in cfg.items() if k != "template_hash"}, + ) + return False + else: + log.debug( + "Hash validation passed for template '%s': hash='%s'", + template_id, + template_hash, + ) + return True + # Config has no hash - keep it + elif config_hash is None: + log.debug( + "Config for template '%s' has no hash - keeping it (template_hash='%s')", + template_id, + template_hash, + ) + return True + # Template has no hash - keep config + else: + log.debug( + "Template '%s' has no src_hash - keeping config with hash '%s'", + template_id, + config_hash, + ) + return True + + def lookup_template_configs( + self, + kernel_inputs: KernelInputs, + op_name: str, + template_uids: list[str], + template_hash_map: Optional[dict[str, Optional[str]]] = None, + ) -> dict[str, list[dict[str, Any]]]: + """ + Unified function to look up template configurations for multiple templates. + Override this method to customize lookup logic. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"]) + template_hash_map: Optional mapping from template_uid to src_hash for validation + + Returns: + {}: No lookup table in use, or no matches found for any template + {"template_uid1": [config1, config2], ...}: Matches found, filtered configurations + """ + lookup_table = self._get_lookup_table() + if not lookup_table: + log.debug("Lookup table: no table configured or CUDA unavailable") + return {} + + # Try both key variants: device-specific first, then device-agnostic + # If both exist, device-specific takes priority + device_key, device_agnostic_key = self.make_lookup_key_variants( + kernel_inputs, op_name + ) + + config_list = [] + + for key_type, key in [ + ("device-specific", device_key), + ("device-agnostic", device_agnostic_key), + ]: + if key is not None: + config_list = lookup_table.get(key, []) + if config_list: + log.debug( + "Lookup table: found %d configs using %s key '%s' for %s", + len(config_list), + key_type, + key, + op_name, + ) + break + else: + log.debug( + "Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)", + op_name, + device_key, + device_agnostic_key, + len(lookup_table), + ) + return {} + + log.debug( + "Lookup table: found %d configs for %s templates %s", + len(config_list), + op_name, + template_uids, + ) + # Group configs by template_id + configs_by_template: dict[str, list[dict[str, Any]]] = {} + for cfg in config_list: + if not isinstance(cfg, dict): + raise ValueError( + f"Config for {op_name} operation is not a dictionary: {cfg}" + ) + if "template_id" not in cfg: + raise ValueError( + f"Config for {op_name} operation missing required 'template_id' field: {cfg}" + ) + + template_id = cfg["template_id"] + if template_id in template_uids: + if template_id not in configs_by_template: + configs_by_template[template_id] = [] + configs_by_template[template_id].append(cfg) + + # Check template hashes and clean up template_id field + result = {} + for template_id, matching_configs in configs_by_template.items(): + filtered_configs = [] + for cfg in matching_configs: + # Check template hash using helper function + if not self._entry_is_valid(cfg, template_id, template_hash_map): + continue + + # Return a copy of the config, as we don't want to modify the original + cconfig = copy.deepcopy(cfg) + # Lastly, we have to throw out the template_id, as it's not a valid kwarg + # and just used to identify which template the entry belongs to + del cconfig["template_id"] + # Similarly, the template_hash is not a valid kwarg + cconfig.pop("template_hash", None) + filtered_configs.append(cconfig) + + if filtered_configs: + result[template_id] = filtered_configs + + return result + + def _finalize_template_configs( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[KernelTemplateChoice]: + """Check lookup table for hits, use those if found, otherwise fall back to parent.""" + # 1. Collect template src_hashes for validation + template_uids = [template.uid for template in templates] + template_hash_map = {} + for template in templates: + src_hash = getattr(template, "src_hash", None) + template_hash_map[template.uid] = src_hash + + log.debug( + "Choices: attempting lookup for %s with %d templates", + op_name, + len(template_uids), + ) + + # 2. Single batch lookup for all templates + lookup_results = self.lookup_template_configs( + kernel_inputs, op_name, template_uids, template_hash_map + ) + + # 3. Early exit if no lookup table or no matches + if not lookup_results: # Empty dict + log.info("LookupChoices: lookup miss for %s, using fallback", op_name) + return self._fallback( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + + log.info( + "LookupChoices: lookup hit for %s - found %d/%d templates: %s", + op_name, + len(lookup_results), + len(template_uids), + list(lookup_results.keys()), + ) + + # 4. Create KTCs only for templates with lookup entries + return self._create_lookup_choices( + lookup_results, templates, kernel_inputs, op_name + ) + + def _fallback( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[KernelTemplateChoice]: + """Fallback to parent if no lookup table or no matches.""" + # NOTE: this is broken out, so that subclasses are able to override this + # to handle explicitly the situations where the lookup take had a miss vs + # overriding the entire logic + return super()._finalize_template_configs( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + + def _create_lookup_choices( + self, + lookup_results: dict[str, list[dict[str, Any]]], + templates: list[Union[KernelTemplate, ExternKernelChoice]], + kernel_inputs: KernelInputs, + op_name: str, + ) -> list[KernelTemplateChoice]: + """Create KernelTemplateChoice objects from lookup results using parent's get_ktc method.""" + templates_by_uid = {template.uid: template for template in templates} + lookup_choices: list[KernelTemplateChoice] = [] + + for template_uid, configs in lookup_results.items(): + template = templates_by_uid[template_uid] + + # Use parent's get_ktc method to get a generator, then get the first base KTC + ktc_generator = self.get_ktc(kernel_inputs, template, op_name) + + try: + base_ktc = next(ktc_generator) + except StopIteration: + # No configs from heuristic, skip this template + continue + + # For each lookup config, create a KTC with the override kwargs + for c in configs: + lookup_ktc = KernelTemplateChoice( + template=base_ktc.template, + # use the ones from the lookup table + params=DictKernelTemplateParams(c), + extra_kwargs=base_ktc.extra_kwargs, + layout=base_ktc.layout, + inputs=base_ktc.inputs, + ) + lookup_choices.append(lookup_ktc) + + return lookup_choices diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15587401b723581b57f94fdcddbcbc8255f73bfe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__init__.py @@ -0,0 +1 @@ +from .package import AOTICompiledModel, load_package, package_aoti diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/build_package.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/build_package.py new file mode 100644 index 0000000000000000000000000000000000000000..9205b9ced254275018472108485173eba9479f11 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/build_package.py @@ -0,0 +1,15 @@ +build_package_contents = """ +import os +from pathlib import Path + +from torch._inductor.package.package import compile_so + +curr_dir = Path(__file__).parent +aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(curr_dir) + for file in files +] + +output_so = compile_so(curr_dir, aoti_files, curr_dir) +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/package.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/package.py new file mode 100644 index 0000000000000000000000000000000000000000..bd11d033cadb3fc3cfdba8165fb42dd996284931 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/package.py @@ -0,0 +1,138 @@ +import io +import json +import logging +import os +import tempfile +from typing import IO + +import torch +from torch._inductor import config +from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder +from torch.export.pt2_archive._package import ( + AOTI_FILES, + AOTICompiledModel, + load_pt2, + package_pt2, +) +from torch.types import FileLike + + +log = logging.getLogger(__name__) + + +def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str: + def get_aoti_file_with_suffix(suffix: str) -> str: + for file in aoti_files: + if file.endswith(suffix): + return file + raise RuntimeError(f"Unable to find file with suffix {suffix}") + + # Compile all the files into a .so + cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp")) + consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o")) + + file_name = os.path.splitext(cpp_file)[0] + + # Parse compile flags and build the .o file + with open(file_name + "_compile_flags.json") as f: + compile_flags = json.load(f) + + compile_options = BuildOptionsBase( + **compile_flags, use_relative_path=config.is_fbcode() + ) + object_builder = CppBuilder( + name=file_name, + sources=cpp_file, + BuildOption=compile_options, + ) + output_o = object_builder.get_target_file_path() + object_builder.build() + + # Parse linker flags and build the .so file + with open(file_name + "_linker_flags.json") as f: + linker_flags = json.load(f) + + linker_options = BuildOptionsBase( + **linker_flags, use_relative_path=config.is_fbcode() + ) + so_builder = CppBuilder( + name=os.path.split(so_path)[-1], + sources=[output_o, consts_o], + BuildOption=linker_options, + output_dir=so_path, + ) + output_so = so_builder.get_target_file_path() + so_builder.build() + + # mmapped weights + serialized_weights_filename = file_name + "_serialized_weights.bin" + if serialized_weights_filename in aoti_files: + with open(serialized_weights_filename, "rb") as f_weights: + serialized_weights = f_weights.read() + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(serialized_weights) + + return output_so + + +def package_aoti( + archive_file: FileLike, + aoti_files: AOTI_FILES, +) -> FileLike: + """ + Saves the AOTInductor generated files to the PT2Archive format. + + Args: + archive_file: The file name to save the package to. + aoti_files: This can either be a singular path to a directory containing + the AOTInductor files, or a dictionary mapping the model name to the + path to its AOTInductor generated files. + """ + + return package_pt2( + archive_file, + aoti_files=aoti_files, + ) + + +def load_package( + path: FileLike, + model_name: str = "model", + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, +) -> AOTICompiledModel: + try: + pt2_contents = load_pt2( + path, + run_single_threaded=run_single_threaded, + num_runners=num_runners, + device_index=device_index, + ) + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in package") + return pt2_contents.aoti_runners[model_name] + except RuntimeError: + log.warning("Loading outdated pt2 file. Please regenerate your package.") + + if isinstance(path, (io.IOBase, IO)): + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + # TODO(angelayi): We shouldn't need to do this -- miniz should + # handle reading the buffer. This is just a temporary workaround + path.seek(0) + f.write(path.read()) + log.debug("Writing buffer to tmp file located at %s.", f.name) + loader = torch._C._aoti.AOTIModelPackageLoader( + f.name, model_name, run_single_threaded, num_runners, device_index + ) + return AOTICompiledModel(loader) + + path = os.fspath(path) # AOTIModelPackageLoader expects (str, str) + loader = torch._C._aoti.AOTIModelPackageLoader( + path, model_name, run_single_threaded, num_runners, device_index + ) + return AOTICompiledModel(loader) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..0034a6a8feb3de9d6c3052a4bd5b5cc18ac112e0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py @@ -0,0 +1,649 @@ +""" +PyTorch Inductor Autotuning Cache System + +This module implements a caching system for autotuning configurations in PyTorch's Inductor compiler. +It provides mechanisms to store and retrieve optimal kernel configurations both locally and remotely, +which significantly speeds up compilation by reusing previously discovered optimal parameters. + +The caching system includes: +- Local filesystem caching for individual machine reuse +- Remote caching for sharing optimizations across machines +- Bundled caching to efficiently store multiple related configurations +- Cache invalidation based on PyTorch versions and backend changes +- Serialization/deserialization support for worker processes + +Key components: +- AutotuneCache: Main class for managing cache access and storage +- AutotuneCacheBundler: Bundles multiple cache entries for efficient storage +- LocalAutotuneCache: Handles filesystem-based caching +- _LocalAutotuneCacheBackend: Low-level file operations for cache storage +- AutotuneCacheArtifact: Integration with PyTorch's artifact system + +This caching system is critical for performance as it eliminates the need to re-run +expensive autotuning operations when the same kernels are compiled multiple times. +""" + +from __future__ import annotations + +import dataclasses +import hashlib +import logging +import os +import os.path +import re +from typing import Any, TYPE_CHECKING +from typing_extensions import override + +import torch +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.utils._triton import has_triton + +from ..remote_cache import ( + create_cache, + JsonDataTy, + RemoteCache, + RemoteCacheBackend, + RemoteCacheJsonSerde, +) +from .triton_compat import Config, HAS_WARP_SPEC + + +if TYPE_CHECKING: + from ..remote_cache import Sample + +log = logging.getLogger(__name__) + + +_InductorMetaTy = dict[str, object] + + +def inductor_meta_from_config() -> _InductorMetaTy: + from torch._inductor import config + + backend_hash = None + if has_triton(): + try: + backend_hash = torch.utils._triton.triton_hash_with_backend() + except RuntimeError: + # This can get the error: + # RuntimeError: 0 active drivers ([]). There should only be one. + pass + + is_hip = None + if torch.version.hip is not None: + is_hip = True + + return { + "autotune_local_cache": config.autotune_local_cache, + "autotune_remote_cache": config.autotune_remote_cache, + "backend_hash": backend_hash, + "bundled_autotune_remote_cache": config.bundled_autotune_remote_cache, + "coordinate_descent_tuning": config.coordinate_descent_tuning, + "is_fbcode": config.is_fbcode(), + "is_hip": is_hip, + } + + +@CacheArtifactFactory.register +class AutotuneCacheArtifact(CacheArtifact): + @override + def populate_cache(self) -> None: + autotune_cache = _LocalAutotuneCacheBackend() + key = os.path.join(cache_dir(), self.key) + autotune_cache._put(key, self.content) + + @override + @staticmethod + def type() -> str: + return "autotune" + + @override + @staticmethod + def encode(content: JsonDataTy) -> bytes: + assert not isinstance(content, bytes) + serde = RemoteCacheJsonSerde() + content_bytes = serde.encode(content) + assert isinstance(content_bytes, bytes) + return content_bytes + + +@dataclasses.dataclass +class AutotuneCache: + configs_hash: str + local_cache: tuple[RemoteCache[JsonDataTy], str] | None = None + remote_cache: tuple[RemoteCache[JsonDataTy], str] | None = None + + # Create a AutotuneCache. Returns None if none of the caches can be used. + @staticmethod + def create( + inductor_meta: _InductorMetaTy, filename: str, configs_hash: str + ) -> AutotuneCache | None: + cache = AutotuneCache(configs_hash) + key = AutotuneCache._prepare_key(filename) + + cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) + cache._setup_remote_autotune_cache(inductor_meta, key) + if cache.local_cache or cache.remote_cache: + return cache + else: + return None + + @staticmethod + def _prepare_key(filename: str) -> str: + from torch.compiler import config as cconfig + + # base of filename is already sha256 hash the source contents + key = f"{os.path.basename(filename)}:{cconfig.cache_key_tag}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + # Read the best config options from the most local cache and return it. + def _read(self) -> dict[str, JsonDataTy] | None: + if local_cache := self.local_cache: + cache, key = local_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + if remote_cache := self.remote_cache: + cache, key = remote_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + return None + + # Read the best config options from the most local cache and figure out + # which `configs` represents that option. + def read_best( + self, inductor_meta: _InductorMetaTy, configs: list[Config] + ) -> Config | None: + if best := self._read(): + return _load_cached_autotuning( + best, self.configs_hash, configs, inductor_meta + ) + return None + + # Set up local filesystem caching information + def _setup_local_cache( + self, inductor_meta: _InductorMetaTy, dirname: str, cache_key: str + ) -> None: + if not inductor_meta.get("autotune_local_cache", True): + return + + from ..codecache import torch_key + + """ + [Note: torch_key in autotune cache key] + Include torch_key() in the cache key so that different versions + of torch result in cache invalidation. This is important in case + of changes to the best_config format or other code changes that + are not backward compatible w.r.t. the cache. + """ + hasher = hashlib.sha256() + hasher.update(cache_key.encode("utf-8")) + hasher.update(torch_key()) + updated_cache_key = hasher.hexdigest() + + cache_filename = f"{dirname}/{updated_cache_key}.best_config" + local_cache = LocalAutotuneCache() + self.local_cache = (local_cache, cache_filename) + + # Set up remote caching information + def _setup_remote_autotune_cache( + self, inductor_meta: _InductorMetaTy, cache_key: str + ) -> None: + if not _should_use_remote_autotune_cache(inductor_meta): + return + + if (backend_hash := inductor_meta.get("backend_hash", None)) is None: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + return + assert isinstance(backend_hash, str) + + from ..codecache import torch_key + + is_fbcode = bool(inductor_meta.get("is_fbcode", False)) + + salt = "autotune-best-config-v2" + # re: torch_key - see [Note: torch_key in autotune cache key] + key = torch_key().hex() + backend_hash + self.configs_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + remote_cache = create_cache( + key, + is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if not remote_cache: + return + + # Save the args passed to create_cache + # in case AutotuneCache needs to be pickled + self.remote_cache_full_key = key + self.is_fbcode = is_fbcode + self.remote_cache = (remote_cache, cache_key) + + # The AutotuneCache may be serialized/deserialized if we're using + # AsyncCompile worker processes to run triton compilation. + # This is because AutotuneCache instances are created on the worker + # process, but we need to run AutotuneCache.save on the parent process + # when actually doing autotuning. + def __getstate__(self) -> dict[str, Any]: + # The remote cache handles themselves may not be serializable + # So clear it and reconstruct it on setstate + remote_cache = getattr(self, "remote_cache", None) + return { + **self.__dict__, + # Save the cache_key portion + "remote_cache": remote_cache and remote_cache[1], + } + + def __setstate__(self, state: dict[str, Any]) -> None: + # Reconstruct the remote cache on the parent class + self.__dict__.update(state) + if self.remote_cache is not None: + assert isinstance(self.remote_cache, str) + assert hasattr(self, "remote_cache_full_key") + assert hasattr(self, "is_fbcode") + cache_key = self.remote_cache + remote_cache = create_cache( + self.remote_cache_full_key, + self.is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if remote_cache is not None: + self.remote_cache = (remote_cache, cache_key) + else: + log.warning("Warning, failed to recreate remote cache after pickling") + self.remote_cache = None + + # Save the config in the caches + def save( + self, + config: Config, + time_taken_ns: int, + found_by_coordesc: bool = False, + triton_cache_hash: str | None = None, + ) -> None: + data = { + # pyrefly: ignore [missing-attribute] + **config.kwargs, + # pyrefly: ignore [missing-attribute] + "num_warps": config.num_warps, + # pyrefly: ignore [missing-attribute] + "num_stages": config.num_stages, + "configs_hash": self.configs_hash, + "found_by_coordesc": found_by_coordesc, + "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + "triton_cache_hash": triton_cache_hash, + } + if HAS_WARP_SPEC: + data.update( + { + "num_consumer_groups": getattr(config, "num_consumer_groups", 0), + "num_buffers_warp_spec": getattr( + config, "num_buffers_warp_spec", 0 + ), + } + ) + + if local_cache := self.local_cache: + cache, key = local_cache + cache.put(key, data) + AutotuneCacheBundler.put(key, data) + autotune_artifact_key = os.path.join(*key.split(os.sep)[-2:]) + CacheArtifactManager.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, data + ) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, key) + + if remote_cache := self.remote_cache: + cache, key = remote_cache + cache.put(key, data) + + +class _AutotuneCacheBundlerImpl: + """ + Caches a set of LocalAutotuneCacheBackend entries together in a single + cache. + """ + + _key: str + _cache: RemoteCache[JsonDataTy] + + # All known entries from LocalAutotuneCache.put() + _entries: dict[str, JsonDataTy] + + def end_compile(self) -> None: + # TODO: Do we need to compute time_taken_ms and encode that somehow? + if self._entries: + self._cache.put(self._key, self._entries) + + def put(self, basename: str, data: JsonDataTy) -> None: + # Do we need to worry about duplicates? We only have a single local fs + # entry - so probably not. + self._entries[basename] = data + + def __init__(self, key: str, cache: RemoteCache[JsonDataTy]) -> None: + self._key = key + self._cache = cache + self._entries = {} + + def sync(self) -> None: + # We don't currently use this - but we could async load starting at + # `begin_compile` and wait for the load to be finished here. + pass + + @classmethod + def _should_use_bundled_autotune_remote_cache( + cls, inductor_meta: _InductorMetaTy + ) -> bool: + # The bundled autotune cache is only available if you've also got local + # caching enabled (because we feed the bundled data to the local cache). + if not inductor_meta.get("autotune_local_cache", True): + return False + + # Check if the we're enabled via config + if ( + bundled_autotune_remote_cache := inductor_meta.get( + "bundled_autotune_remote_cache" + ) + ) is not None: + return bool(bundled_autotune_remote_cache) + + if not cls._get_is_fbcode(inductor_meta): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:bundled_autotune_remote_cache_version" + ) + return REMOTE_CACHE_VERSION >= jk + + def _load_cache(self) -> bool: + from torch._inductor import codecache + + # The single key is defined on construction of the cache. + entries = self._cache.get(self._key) + if entries is None or not isinstance(entries, dict): + # We couldn't load the cache - so mark _entries as non-None so we + # store local cache values. + return False + + # Go through the entries we got from the cache and save them locally. + time_saved_ns = 0 + for basename, data in entries.items(): + # Reconstruct the final filename (see put()) + root, ext = _splitext_nodot(basename) + _, _, filename = codecache.get_path(root, ext) + if isinstance(data, dict) and (tsns := data.get("time_saved_ns")): + time_saved_ns += int(tsns) # type: ignore[arg-type] + local_cache = LocalAutotuneCache() + local_cache.put(filename, data) + + codecache.add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + + return True + + @staticmethod + def _get_is_fbcode(inductor_meta: _InductorMetaTy) -> bool: + return bool(inductor_meta.get("is_fbcode", False)) + + @staticmethod + def _get_backend_hash(inductor_meta: _InductorMetaTy) -> str: + backend_hash = inductor_meta["backend_hash"] + assert isinstance(backend_hash, str) + return backend_hash + + +class AutotuneCacheBundler: + _bundler: _AutotuneCacheBundlerImpl | None = None + + def __init__(self) -> None: + pass + + # Call this before we start any autotune computation for an inductor python + # file. On a cache hit it copies the individual results into the local + # autotune caches. + @classmethod + def begin_compile( + cls, + inductor_meta: _InductorMetaTy, + *, + code: str | None = None, + code_hash: str | None = None, + ) -> None: + assert cls._bundler is None + + if code is not None: + assert code_hash is None, "Cannot specify both code and code_hash" + code_hash = _comment_stripped_hash(code) + assert code_hash is not None + + if not _AutotuneCacheBundlerImpl._should_use_bundled_autotune_remote_cache( + inductor_meta + ): + return + + cache = create_cache( + "bundled-autotune-v1", + _AutotuneCacheBundlerImpl._get_is_fbcode(inductor_meta), + "FbRemoteBundledAutotuneCache", + "RemoteBundledAutotuneCache", + ) + if not cache: + return + + # We're starting a compilation phase. We have a cache key for the code + # we're compiling. We'll get the individual autotune bundles later (via + # self.put()). For now create the AutotuneCacheBundler and try to load + # from the cache. + + salt = "bundled-autotune-best-configs-v1" + backend_hash = _AutotuneCacheBundlerImpl._get_backend_hash(inductor_meta) + # TODO: The autotune cache includes configs_hash in the key. The problem + # is that the configs_hash includes info from the individual pointwise() + # calls (size_hints, for example) which we can't know yet. I *think* + # that info is basically present in the `code_hash` (since it's a + # parameter to the pointwise decorator) - but is there other info we + # need to include from inductor_meta? + key = code_hash + backend_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + bundler = _AutotuneCacheBundlerImpl(key, cache) + if not bundler._load_cache(): + # We couldn't load from the cache - so save the data so we can store + # the saved autotunes. + cls._bundler = bundler + + # If we get a cache hit don't bother saving any of the individual + # autotune results. + + # Call this after all individual autotune results are finished for a + # inductor python file. If we gathered any individual results then we bundle + # those and put it into the cache. + @classmethod + def end_compile(cls) -> None: + if bundler := cls._bundler: + cls._bundler = None + bundler.end_compile() + + @classmethod + def sync(cls) -> None: + if bundler := cls._bundler: + bundler.sync() + + @classmethod + def put(cls, filename: str, data: JsonDataTy) -> None: + if bundler := cls._bundler: + # The filename comes in as something like + # "/tmp/tmp{random}/{aa}/{basename}.py" (where aa is + # basename[1:3]). Strip it down and make sure that it looks like a path + # we could reconstruct (because it's possible for the caller to + # customize the path). + basename = os.path.basename(filename) + + # TODO: check cache_dir() vs filename, then strip dirname + bundler.put(basename, data) + + +# Remove the comments from the code (which include things like run ids and file +# paths) and then hash the result. +def _comment_stripped_hash(code: str) -> str: + code = re.sub(r"#.*$", "", code, count=0, flags=re.MULTILINE) + return torch._inductor.codecache.code_hash(code) + + +def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool: + if (config := inductor_meta.get("autotune_remote_cache")) is not None: + return bool(config) + if not inductor_meta.get("is_fbcode"): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:autotune_memcache_version" + ) + + +def _load_cached_autotuning( + best_config: dict[str, JsonDataTy], + configs_hash: str, + configs: list[Config], + inductor_meta: _InductorMetaTy, +) -> Config | None: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + best_config.pop("triton_cache_hash", None) + + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( + "found_by_coordesc", False + ): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + + # Extract common arguments + config_args = { + "num_warps": num_warps, + "num_stages": num_stages, + } + + if HAS_WARP_SPEC: + config_args.update( + { + "num_consumer_groups": best_config.pop("num_consumer_groups", 0), + "num_buffers_warp_spec": best_config.pop( + "num_buffers_warp_spec", 0 + ), + } + ) + + # Create the triton_config with the appropriate arguments + # pyrefly: ignore [bad-argument-count] + triton_config = Config(best_config, **config_args) + # pyrefly: ignore [missing-attribute] + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + # pyrefly: ignore [missing-attribute] + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + # pyrefly: ignore [missing-attribute] + and cfg.num_warps == best_config.get("num_warps") + # pyrefly: ignore [missing-attribute] + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]): + @override + def _get(self, key: str) -> bytes | None: + try: + with open(key, "rb") as fd: + return fd.read() + except FileNotFoundError: + return None + + @override + def _put(self, key: str, data: bytes) -> None: + os.makedirs(os.path.dirname(key), exist_ok=True) + from torch._inductor import codecache + + codecache.write_atomic(key, data) + + +class LocalAutotuneCache(RemoteCache[JsonDataTy]): + def __init__(self) -> None: + backend = _LocalAutotuneCacheBackend() + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + @override + def _get(self, key: str, sample: Sample | None) -> JsonDataTy | None: + AutotuneCacheBundler.sync() + result = super()._get(key, sample) + if result is not None: + assert isinstance(result, dict) + # What? Why are we doing a put() here? Imagine we have a new model + # that reuses some existing kernels that have already been + # compiled. If we didn't do a `put` here (on cache hit) then the new + # model would only bundle *newly* compiled kernels, not existing + # kernels that were already compiled and cached. + AutotuneCacheBundler.put(key, result) + autotune_artifact_key = os.path.join(*key.split(os.sep)[-2:]) + CacheArtifactManager.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, result + ) + return result + + @override + def _put(self, key: str, value: JsonDataTy, sample: Sample | None) -> None: + AutotuneCacheBundler.put(key, value) + super()._put(key, value, sample) + + +def _splitext_nodot(basename: str) -> tuple[str, str]: + root, ext = os.path.splitext(basename) + if ext: + ext = ext[1:] + return root, ext diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa33f66ef3a4441613eedabe25731ca2edc25fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py @@ -0,0 +1,441 @@ +import functools +import inspect +import time +from collections.abc import Callable +from functools import cached_property, wraps +from itertools import chain +from statistics import median +from typing import Any, Concatenate, Optional, Union +from typing_extensions import ParamSpec, Self, TypeVar + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.config import use_experimental_benchmarker +from torch.utils._debug_mode import DebugMode + + +logger = torch._logging.getArtifactLogger(__name__, "benchmarking") +use_experimental_benchmarker = ( + use_experimental_benchmarker and torch.cuda.is_available() +) + + +MILLISECONDS_PER_SECOND = 1000 + +P = ParamSpec("P") +T = TypeVar("T") + + +def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]: + from torch._inductor import config + + if config.test_configs.distort_benchmarking_result == "": + return fn + + def distort( + ms: list[float] | tuple[float, ...] | float, + ) -> list[float] | tuple[float, ...] | float: + if isinstance(ms, (list, tuple)): + return type(ms)(distort(val) for val in ms) # type: ignore[misc] + + distort_method = config.test_configs.distort_benchmarking_result + assert isinstance(ms, float) + if distort_method == "inverse": + return 1.0 / ms if ms else 0.0 + elif distort_method == "random": + import random + + return random.random() + else: + raise RuntimeError(f"Unrecognized distort method {distort_method}") + + @functools.wraps(fn) + def wrapper( + *args: list[Any], **kwargs: dict[str, Any] + ) -> list[float] | tuple[float, ...] | float: + ms = fn(*args, **kwargs) + + return distort(ms) + + return wrapper + + +def may_ban_benchmarking() -> None: + if torch._inductor.config.deterministic: + raise RuntimeError("""In the deterministic mode of Inductor, we will avoid those + benchmarkings that would cause non deterministic results. Only benchmarkings in the vetted + scenarios are allowed. Example include autotuning for triton configs of pointwise kernels. + + When you see this exception, you can do one of the following two things: + 1. if the benchmarking you are doing does not introduce any non-determinism, you can just + add is_vetted_benchmarking=True to you benchmark_gpu call. That would solve the issue. + + 2. if the benchmarking you are doing indeed introduces non-determinism, you'll need to disable + such feature in deterministic mode or find an alternative implementation that is deterministic. + """) + + +def time_and_count( + fn: Callable[Concatenate[Any, P], T], +) -> Callable[Concatenate[Any, P], T]: + """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo + counters. It is expected that `fn` is a method of `Benchmarker` or one of its + subclasses; typing limitations prevent us from declaring this directly. + """ + + @wraps(fn) + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: + fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}" + counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1 + with dynamo_timed(fn_qual_name, log_pt2_compile_event=False): + return fn(self, *args, **kwargs) + + return wrapper + + +class Benchmarker: + """ + A device-agnostic benchmarking utility for measuring the runtime of + inductor generated callables. + """ + + def __init__(self: Self) -> None: + pass + + def infer_device(self, *fn_args: Any, **fn_kwargs: Any) -> torch.device: + inferred_device: Optional[torch.device] = None + for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): + # Some callables take nested structures as arguments so use the + # flattened form to find any tensors + for arg_or_kwarg_leaf in pytree.tree_leaves(arg_or_kwarg): + if not isinstance(arg_or_kwarg_leaf, torch.Tensor): + continue + if inferred_device is None: + inferred_device = arg_or_kwarg_leaf.device + elif arg_or_kwarg_leaf.device != inferred_device: + raise ValueError( + "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + ) + + if inferred_device is None: + raise ValueError( + "Can't safely infer the device type of `fn` with no device types" + " in `fn_args` or `fn_kwargs`. Use a direct benchmarking method instead e.g. " + "`Benchmarker.benchmark_cpu` or `Benchmarker.benchmark_gpu`." + ) + + return inferred_device + + @time_and_count + def benchmark( + self: Self, + fn: Callable[..., Any], + fn_args: Optional[tuple[Any, ...]] = None, + fn_kwargs: Optional[dict[str, Any]] = None, + device: Optional[Union[str, torch.device]] = None, + **kwargs: Any, + ) -> float: + """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the + actual runtime calculation is dictated by the benchmarking implementation, but may be + one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around + device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises + `ValueError(...)` if we can't safely infer the device type of `fn`; for example, + if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device + types are found. To bypass device inference, provide the device to the `device` + parameter. + + WARNING: if `fn` mutates `fn_args` or `fn_kwargs`, benchmarking may fail unexpectedly. + For example, if `fn` clears a mutable object, subsequent invocations of `fn` during + benchmarking will fail. In such cases, `fn` should handle cloning its arguments internally. + If device inference is required, `Benchmarker.infer_device` can be used prior to calling + this method without any arguments for `fn_args` and `fn_kwargs`. + + Arguments: + - fn: The function to benchmark. + - fn_args: The function's arguments. + - fn_kwargs: The function's kwargs. + + Keyword Arguments: + - device: Which device to use for benchmarking. If not provided the device will be attempted + to be inferred from `fn_args` and `fn_kwargs`. + - **kwargs: The benchmarking implementation's kwargs. + + Returns: + - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. + """ + inferred_device: Optional[torch.device] = None + if device is not None: + inferred_device = ( + torch.device(device) if isinstance(device, str) else device + ) + else: + if fn_args is None and fn_kwargs is None: + raise ValueError( + "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided." + ) + + fn_args = fn_args or tuple() + fn_kwargs = fn_kwargs or {} + inferred_device = self.infer_device(*fn_args, **fn_kwargs) + + assert isinstance(inferred_device, torch.device) + + fn_args = fn_args or tuple() + fn_kwargs = fn_kwargs or {} + + # No need to wrap if the callable takes no arguments + if len(fn_args) == 0 and len(fn_kwargs) == 0: + _callable = fn + else: + _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 + + # Surfacing all kernels during autotuning is super noisy; filtering these out. + with DebugMode._benchmarking_inductor(): + if inferred_device == torch.device("cpu"): + return self.benchmark_cpu(_callable, **kwargs) + # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking + # implementation which was written specifically with CUDA devices in mind, we may want to + # explore alternate implementations for other device types. + return self.benchmark_gpu(_callable, **kwargs) + + @time_and_count + def benchmark_cpu( + self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100 + ) -> float: + """Benchmark the CPU callable, `_callable`, and return the median runtime, + in milliseconds. + + Arguments: + - _callable: The CPU callable to benchmark. + + Keyword Arguments: + - warmup: Optionally, the duration, in milliseconds, to run `_callable` + before benchmarking starts. + - rep: Optionally, the duration, in milliseconds, to run `_callable` + during benchmarking. + + Returns: + - The median runtime of `_callable`, in milliseconds. + """ + + def run_for(ms: int) -> list[float]: + timings = [] + run_start_t = time.perf_counter() + while True: + start_t = time.perf_counter() + _callable() + end_t = time.perf_counter() + timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND) + if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms: + break + return timings + + run_for(warmup) + return median(run_for(rep)) + + @time_and_count + def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float: + raise NotImplementedError + + +class TritonBenchmarker(Benchmarker): + @cached_property + def triton_do_bench(self: Self) -> Callable[..., Any]: + """Lazily import Triton's `do_bench`.""" + try: + from triton.testing import do_bench + except ImportError as e: + raise NotImplementedError("requires Triton") from e + return do_bench + + @may_distort_benchmarking_result + @time_and_count + # pyrefly: ignore [bad-override] + def benchmark_gpu( + self: Self, + _callable: Callable[[], Any], + is_vetted_benchmarking: bool = False, + **kwargs: Any, + ) -> float: + """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds. + + Arguments: + - _callable: The GPU callable to benchmark. + + Keyword Arguments: + - quantiles: Optionally, a tuple of floats denoting the requested quantiles. + - return_mode: Optionally, the requested return mode. Currently, Triton's + `do_bench` supports min, max, mean, and median return modes. + - **kwargs: Additional kwargs passed to Triton's `do_bench`. + + Returns: + - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified, + this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified, + this is the requested return mode. Otherwise, this is the median. + """ + if not is_vetted_benchmarking: + may_ban_benchmarking() + + do_bench_params = inspect.signature(self.triton_do_bench).parameters + for kwarg in list(kwargs.keys()): + if kwarg not in do_bench_params: + del kwargs[kwarg] + if "quantiles" in kwargs: + return self.triton_do_bench(_callable, **kwargs)[0] + elif "return_mode" in kwargs: + return self.triton_do_bench(_callable, **kwargs) + return self.triton_do_bench(_callable, **kwargs, return_mode="median") + + +class InductorBenchmarker(TritonBenchmarker): # noqa: docstring_linter + @cached_property + def L2_cache_size(self: Self) -> int: + """Get the L2 cache size, in bytes, of the current device.""" + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + return props.L2_cache_size + + def get_event_pairs( + self: Self, iters: int + ) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]: + """Get `iters` pairs of CUDA events.""" + return [ + ( + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + ) + for _ in range(iters) + ] + + def get_event_pairs_min_timing( + self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]] + ) -> float: + """Get the minimum timing, in milliseconds, for a group of CUDA event pairs.""" + return min( + [ + start_event.elapsed_time(end_event) + for start_event, end_event in event_pairs + ] + ) + + @may_distort_benchmarking_result + @time_and_count + def benchmark_gpu( # type: ignore[override] + self: Self, + _callable: Callable[[], Any], + estimation_iters: int = 5, + memory_warmup_iters: int = 100, + benchmark_iters: int = 100, + max_benchmark_duration: int = 25, + return_mode: str = "min", + grad_to_none: list[torch.Tensor] | None = None, + is_vetted_benchmarking: bool = False, + **kwargs: Any, + ) -> float | list[float]: + """Benchmark a GPU callable using a custom benchmarking implementation. + + Arguments: + - _callable: The callable to benchmark. + + Keyword Arguments: + - estimation_iters: Optionally, the number of iterations to run `_callable` + during runtime estimation. + - memory_warmup_iters: Optionally, the number of iterations to flush the L2 + cache before starting benchmarking. + - benchmark_iters: Optionally, the number of iterations to run `_callable` + during the benchmarking. + - max_benchmark_duration: Optionally, the maximum duration of the benchmarking, + in milliseconds. An estimated duration is calculated based on the values + of `memory_warmup_iters` and `benchmark_iters`, along with the estimated + runtime of `_callable` and various other factors, and we then shrink + `benchmark_iters` to fit in the allotted maximum duration. + - return_mode: Return mode for benchmark results. Options are "min" (default), + "all" (returns all measurements). + - grad_to_none: Optionally, a list of tensors whose gradients should be cleared + before each benchmark iteration. + - is_vetted_benchmarking: in deterministic mode, we only allow + benchmarking in vetted cases. + - **kwargs: Additional kwargs that may be passed to the fallback. + + Returns: + - If return_mode="min": The minimum runtime of `_callable`, in milliseconds. + - If return_mode="all": List of all runtime measurements, in milliseconds. + """ + + if not is_vetted_benchmarking: + may_ban_benchmarking() + + # we don't want any outside errors propagating into benchmarking + torch.cuda.synchronize() + + # warmup `_callable` (and catches any failures in the process) + _callable() + torch.cuda.synchronize() + + # see https://github.com/triton-lang/triton/pull/840 for why `dtype=torch.int` + buffer = torch.empty(self.L2_cache_size // 4, dtype=torch.int, device="cuda") + buffer.zero_() + + # estimate the runtime of `_callable` + event_pairs = self.get_event_pairs(estimation_iters) + for start_event, end_event in event_pairs: + # Clear gradients before timing (matches triton.testing.do_bench) + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + buffer.zero_() + start_event.record() + _callable() + end_event.record() + torch.cuda.synchronize() + estimated_timing = self.get_event_pairs_min_timing(event_pairs) + + # adjust `benchmark_iters` to fit in the maximum benchmarking duration + benchmark_iters = max( + min(benchmark_iters, int(max_benchmark_duration // estimated_timing)), 1 + ) + + # do the memory warmup + for _ in range(memory_warmup_iters): + buffer.zero_() + + # benchmark `_callable` + event_pairs = self.get_event_pairs(benchmark_iters) + for start_event, end_event in event_pairs: + # Clear gradients before timing (matches triton.testing.do_bench) + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + buffer.zero_() + start_event.record() + _callable() + end_event.record() + torch.cuda.synchronize() + + # explicitly delete the buffer, sometimes helps memory + # footprint metrics in OSS Inductor performance benchmarks + del buffer + + # Return based on the requested mode + if return_mode == "all": + # Get all timings from event pairs + all_timings = [ + start_event.elapsed_time(end_event) + for start_event, end_event in event_pairs + ] + return all_timings + elif return_mode == "min": + benchmarked_timing = self.get_event_pairs_min_timing(event_pairs) + # return the minimum of `estimated_timing` and `benchmarked_timing`, + # we just want the minimum timing overall so we might as well check both + return min(estimated_timing, benchmarked_timing) + else: + raise ValueError( + f"Unsupported return_mode: {return_mode}. Use 'min' or 'all'." + ) + + +benchmarker = ( + InductorBenchmarker() if use_experimental_benchmarker else TritonBenchmarker() +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34b84a68f6300c1709593e303ff2a07e1f50bc46 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py @@ -0,0 +1,54 @@ +import getpass +import os +import re +import tempfile +from collections.abc import Generator +from contextlib import contextmanager + +from torch._environment import is_fbcode + + +# Factoring out to file without torch dependencies + + +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def default_cache_dir() -> str: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + return os.path.join( + tempfile.gettempdir() if not is_fbcode() else "/var/tmp", + "torchinductor_" + sanitized_username, + ) + + +def triton_cache_dir(device: int) -> str: + if (directory := os.getenv("TRITON_CACHE_DIR")) is not None: + return directory + return os.path.join( + cache_dir(), + "triton", + str(device), + ) + + +@contextmanager +def temporary_cache_dir(directory: str) -> Generator[None, None, None]: + from torch._inductor.utils import clear_caches + + original = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory + try: + clear_caches() + yield + finally: + clear_caches() + if original is None: + del os.environ["TORCHINDUCTOR_CACHE_DIR"] + else: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = original diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..11801eac925848eb1e0969d587e8fcc98484a1cd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import functools +import linecache +import os +import sys +import time +import warnings +from pathlib import Path +from types import ModuleType +from typing import Any, TYPE_CHECKING + +from torch._utils_internal import log_triton_builds + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + +def _reload_python_module( + key: str, path: str, set_sys_modules: bool = True +) -> ModuleType: + with open(path) as f: + try: + code = compile(f.read(), path, "exec", dont_inherit=True) + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + mod = ModuleType(f"{__name__}.{key}") + mod.__file__ = path + mod.key = key # type: ignore[attr-defined] + exec(code, mod.__dict__, mod.__dict__) + if set_sys_modules: + sys.modules[mod.__name__] = mod + return mod + + +@functools.cache +def _set_triton_ptxas_path() -> None: + if os.environ.get("TRITON_PTXAS_PATH") is not None: + return + ptxas = Path(__file__).absolute().parents[2] / "bin" / "ptxas" + if not ptxas.exists(): + return + if ptxas.is_file() and os.access(ptxas, os.X_OK): + os.environ["TRITON_PTXAS_PATH"] = str(ptxas) + else: + warnings.warn(f"{ptxas} exists but is not an executable") + + +def _worker_compile_triton( + load_kernel: Callable[[], CachingAutotuner], + extra_env: dict[str, str], + extra_config: dict[str, Any], +) -> tuple[CachingAutotuner, int]: + _set_triton_ptxas_path() + os.environ.update(extra_env) + from torch._inductor import config + + with config.patch(extra_config): + fail = None + try: + start_ns = time.time_ns() + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + elapsed_ns = time.time_ns() - start_ns + kernel.prepare_for_pickle() + # We can release this memory in the compile subprocesses: + linecache.clearcache() + return kernel, elapsed_ns // 1000 + except Exception as e: + fail = str(e) + raise + finally: + log_triton_builds(fail=fail) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..91736febd29f61106fa5bb26d896941952582091 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -0,0 +1,412 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING + +from torch.utils._ordered_set import OrderedSet + +from ..utils import get_max_numwarps +from .hints import TRITON_MAX_BLOCK +from .runtime_utils import red_text, triton_config_to_hashable + + +if TYPE_CHECKING: + from .triton_compat import triton + + +log = logging.getLogger(__name__) + + +def get_field(config, name): + if name == "num_warps": + return config.num_warps + elif name == "num_stages": + return config.num_stages + elif name == "waves_per_eu": + return config.kwargs.get(name, int(8 // config.num_warps)) + else: + return config.kwargs.get(name, None) + + +def set_field(config, name, value): + if name == "num_warps": + config.num_warps = value + elif name == "num_stages": + config.num_stages = value + else: + config.kwargs[name] = value + + +class CoordescTuner: + """ + The coordinate descent tuner. Tune one field/coordinate at a time. + + TODO will it be necessary to tune multiple fields simultaneously. + + + TODO: what if both increasing and decreasing a field can improve perf. + i.e., there are multiple local optima.. + """ + + def __init__( + self, + is_mm=False, + is_native_matmul=False, + is_mix_order_reduction=False, + name="unknown", + size_hints=None, + inductor_meta=None, + frozen_fields=None, + ): + self.is_mm = is_mm # we will tune num_stages for mm + + # Native matmul codegen assumes ZBLOCK=1 always. + # This is because 3d tl.dot is slow and so we want to tile y and x only. + # tl.dot also does not support size smaller than 16; we put this restriction. + self.is_native_matmul = is_native_matmul + assert not (self.is_mm and self.is_native_matmul) + self.is_mix_order_reduction = is_mix_order_reduction + self.cached_benchmark_results = {} + self.name = name + self.size_hints = size_hints + self.inductor_meta = inductor_meta or {} + self.frozen_fields: OrderedSet[str] = ( + OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet() + ) + + def get_config_max(self, prefix: str) -> int: + max_block = TRITON_MAX_BLOCK[prefix.upper()] + size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None + return min(max_block, size_hint) if size_hint is not None else max_block + + def get_warpsmax(self): + # Avoid querying device directly if device properties are populated in inductor_meta + warp_size = self.inductor_meta.get("warp_size") + max_threads_per_block = self.inductor_meta.get("max_threads_per_block") + if warp_size and max_threads_per_block: + return max_threads_per_block // warp_size + else: + return get_max_numwarps() + + def cache_benchmark_result(self, config, timing): + self.cached_benchmark_results[triton_config_to_hashable(config)] = timing + + def lookup_in_cache(self, config): + return self.cached_benchmark_results.get(triton_config_to_hashable(config)) + + def call_func(self, func, config): + found = self.lookup_in_cache(config) + if found is not None: + log.debug(" CACHED") + return found + timing = func(config) + self.cache_benchmark_result(config, timing) + return timing + + @property + def tunable_fields(self): + out = [ + "XBLOCK", + "YBLOCK", + "ZBLOCK", + # NOTE: we should not tune R0_BLOCK for persistent reduction. + # We rely on the fact that persistent reduction's triton.Config + # does not have the R0_BLOCK field to guarantee that. + "R0_BLOCK", + "R1_BLOCK", + # the following 3 are for mm + "BLOCK_M", + "BLOCK_N", + "BLOCK_K", + "num_warps", + ] + if self.is_mm: + out.append("num_stages") + if self.inductor_meta.get("is_hip") is True: + out.append("waves_per_eu") + if self.is_native_matmul: + out.append("num_stages") + out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul + + if self.is_mix_order_reduction: + # unlike TritonConfig.num_stages, this one is + # put in TritonConfig.kwargs["NUM_STAGES"] and is used to + # control the stage of pipelining of tl.range. + out.append("NUM_STAGES") + + return [f for f in out if f not in self.frozen_fields] + + def value_too_large(self, name: str, val: int) -> bool: + block_suffix = "BLOCK" + if name.endswith(block_suffix): + prefix = name.strip(block_suffix).lower() + return val > self.get_config_max(prefix) + if name == "num_warps": + return val > self.get_warpsmax() + if name == "waves_per_eu": + return val > 8 + + return False + + def value_too_small(self, name: str, val: int) -> bool: + # In native matmul, block size should be >= 16 for tl.dot + if self.is_native_matmul: + if name in ["YBLOCK", "XBLOCK", "R0_BLOCK"]: + return val < 16 + + # Break if value becomes 0/neg + return val <= 0 + + def get_neighbour_values(self, name, orig_val, radius=None, include_self=False): + """ + Get neighbour values in 'radius' steps. The original value is not + returned as it's own neighbour. + """ + if radius is None: + radius = 1 + if name == "NUM_STAGES": + # we see cases that + # NUM_STAGES=1 is better than NUM_STAGES=2 + # while NUM_STAGES=1 is worse than NUM_STAGES=3 + radius = max(radius, 2) + + assert radius >= 1 + + def update(cur_val, inc=True): + if name in ["num_stages", "NUM_STAGES"]: + if inc: + return cur_val + 1 + else: + return cur_val - 1 + else: + if inc: + return cur_val * 2 + else: + return cur_val // 2 + + out = [] + # increment loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, True) + if self.value_too_large(name, cur_val): + break + out.append(cur_val) + + # decrement loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, False) + if self.value_too_small(name, cur_val): + break + out.append(cur_val) + + if include_self: + out.append(orig_val) + return out + + @staticmethod + def has_improvement(baseline, test): + threshold = 0.001 # 0.1% + return test is not None and test < baseline * (1 - threshold) + + def is_valid_config(self, config) -> bool: + if self.is_mix_order_reduction: + # Mix order reduction has an extra constraint that + # we should not tune XBLOCK beyond RSPLIT_SIZE + xblock = config.kwargs["XBLOCK"] + split_size = config.kwargs["RSPLIT_SIZE"] + return xblock <= split_size + return True + + def check_all_tuning_directions( + self, + # pyrefly: ignore [missing-attribute] + func: Callable[["triton.Config"], float], + best_config, + best_timing, + ): + """ + Check all directions. We only do this once the regular coordinate + descent tuning find no better choices any more. + We only have a few tunable fields, so this should be fine. + """ + candidate_values_list = [] + effective_fields = [] + for field in self.tunable_fields: + old_value = get_field(best_config, field) + if old_value is None: + continue + radius = self.inductor_meta.get("coordinate_descent_search_radius", 1) + candidate_values = self.get_neighbour_values( + field, + old_value, + radius=radius, + include_self=True, + ) + candidate_values_list.append(candidate_values) + effective_fields.append(field) + + choices = itertools.product(*candidate_values_list) + improved = False + for choice in choices: + assert len(choice) == len(effective_fields) + candidate_config = copy.deepcopy(best_config) + for new_val, field in zip(choice, effective_fields): + set_field(candidate_config, field, new_val) + if not self.is_valid_config(candidate_config): + continue + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config = candidate_config + best_timing = candidate_timing + + return improved, best_config, best_timing + + def compare_config(self, func, candidate_config, best_config, best_timing): + """ + Check if candidate_config is better than best_config. + + Return a tuple of (compare_result, candidate_timing). + compare_result is true iff candidate_config is better. + """ + log.debug("Try config %s", candidate_config) + try: + candidate_timing = self.call_func(func, candidate_config) + except Exception as e: + log.debug("Got exception %s", e) # noqa: G200 + return False, float("inf") + + if self.has_improvement(best_timing, candidate_timing): + log.debug( + "Tune from %s %f -> %s %f", + best_config, + best_timing, + candidate_config, + candidate_timing, + ) + + return True, candidate_timing + return False, candidate_timing + + def autotune( + self, + # pyrefly: ignore [missing-attribute] + func: Callable[["triton.Config"], float], + # pyrefly: ignore [missing-attribute] + baseline_config: "triton.Config", + baseline_timing: float | None = None, + ) -> "triton.Config": # pyrefly: ignore # missing-attribute + if baseline_timing is None: + baseline_timing = self.call_func(func, baseline_config) + + log.debug("= Do coordinate descent tuning for %s =", self.name) + log.debug( + "%s: Baseline Config %s, baseline timing %f", + self.name, + baseline_config, + baseline_timing, + ) + improved = True + best_config = baseline_config + best_timing = baseline_timing + tunable_fields = self.tunable_fields + + while improved: + improved = False + + for name in tunable_fields: + cur_val = get_field(best_config, name) + # some kernel don't have R0_BLOCK/YBLOCK/ZBLOCK. So cur_val may be None + if cur_val is None: + continue + + # It's possible that candidate_values is empty. + # E.g., if XBLOCK is 1 initially and size_hint for x is also 1. + # We would not try either larger or smaller XBLOCK in this case. + candidate_values = self.get_neighbour_values(name, cur_val) + + for next_val in candidate_values: + candidate_config = copy.deepcopy(best_config) + set_field(candidate_config, name, next_val) + + if not self.is_valid_config(candidate_config): + continue + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config, best_timing = candidate_config, candidate_timing + + if not improved and self.inductor_meta.get( + "coordinate_descent_check_all_directions" + ): + old_best_timing = best_timing + improved, best_config, best_timing = self.check_all_tuning_directions( + func, best_config, best_timing + ) + + if improved: + msg = red_text( + "%s: Coordinate descend tuning found improvement of %.3fx by looking in all directions." + ) + log.debug( + msg, + self.name, + old_best_timing / best_timing, + ) + + log.debug( + "%s: Improve from %s %f -> %s %f, %.3fx", + self.name, + baseline_config, + baseline_timing, + best_config, + best_timing, + baseline_timing / best_timing, + ) + + return best_config + + @staticmethod + def autotune_single_field(fn, init_val, min_val=None, max_val=None): + """ + fn is a function that takes the field value and returns the benchmarking result + init_val is the starting point of autotuning. + + Should work well for parabola like curve. Here is a real example + for split-size of mix-order-reduction: https://github.com/pytorch/pytorch/pull/166461 + """ + cache = {} + + def _bench(val): + if val not in cache: + cache[val] = fn(val) + # print(f"split size {val} -> {cache[val]:.3f} ms") + return cache[val] + + if min_val is None: + min_val = 1 + if max_val is None: + max_val = 2**30 # some arbitrary large value + + best_val = init_val + improved = True + while improved: + improved = False + candlist = [best_val // 2, best_val * 2] + for cand in candlist: + cand = max(cand, min_val) + cand = min(cand, max_val) + + if _bench(cand) < _bench(best_val): + best_val = cand + improved = True + + return best_val diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/debug_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c15ff890dda6bc2cf9b541c1d5c8b76939c07ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/debug_utils.py @@ -0,0 +1,138 @@ +import functools +import logging +import threading +import weakref + +import torch +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + +local = threading.local() +local.memory_tracker = None + + +class BufferMemoryTracker: + """ + Tracks inductor runtime allocations and deallocations to compare against + expected behavior. + """ + + def __init__(self) -> None: + self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = ( + weakref.WeakValueDictionary() # type: ignore[assignment] + ) + self.died_since_last_step: OrderedSet[str] = OrderedSet() + self.added_since_last_step: OrderedSet[str] = OrderedSet() + self.error = ( + torch._inductor.config.test_configs.track_memory_lifecycle == "assert" + ) + + def set_tensor(self, name: str, tensor: torch.Tensor) -> None: + storage = tensor.untyped_storage() + + self.added_since_last_step.add(name) + self.tensor_tracker[name] = storage + + def on_tensor_death() -> None: + self.died_since_last_step.add(name) + + weakref.finalize(storage, on_tensor_death) + + def advance_step(self) -> None: + self.died_since_last_step.clear() + self.added_since_last_step.clear() + + def log_or_raise(self, msg: str) -> None: + if self.error: + raise RuntimeError(msg) + else: + log.info(msg) + + def check_step_delta( + self, + expected_allocated: list[str], + expected_freed: list[str], + is_final_step: bool, + ) -> None: + """Check only the delta changes since last step""" + + # Check expected deaths - we dont currently distinguish between nodes which die in last step + # and are returned as outputs, so skip if final_step. + if not is_final_step: + missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step + if missing_deaths: + self.log_or_raise( + f"Expected tensors to die but still alive: {missing_deaths}" + ) + + # Check for unexpected deaths + unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed) + if unexpected_deaths: + self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}") + + # Check newly alive tensors - separate messages like deaths + actual_allocated = self.added_since_last_step + expected_allocated_set = OrderedSet(expected_allocated) + + extra_alive = actual_allocated - expected_allocated_set + if extra_alive: + self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}") + + missing_alive = expected_allocated_set - actual_allocated + if missing_alive: + self.log_or_raise( + f"Expected allocated tensors but missing: {missing_alive}" + ) + + # Reset for next step + self.advance_step() + + if is_final_step: + local.memory_tracker = None + + +def get_mem_tracker() -> BufferMemoryTracker: + if local.memory_tracker is None: + local.memory_tracker = BufferMemoryTracker() + return local.memory_tracker + + +def track_tensor(tensor: torch.Tensor, name: str) -> None: + get_mem_tracker().set_tensor(name, tensor) + + +def tracked_empty_strided( + size: list[int], + stride: list[int], + *, + dtype: torch.dtype, + device: torch.device, + name: str, +) -> torch.Tensor: + o = torch.empty_strided(size, stride, dtype=dtype, device=device) + track_tensor(o, name) + return o + + +def check_memory_step( + allocated: list[str], freed: list[str], is_final_step: bool = False +) -> None: + tracker = get_mem_tracker() + tracker.check_step_delta(allocated, freed, is_final_step) + + +@functools.lru_cache(None) +def register_check_mem_op() -> None: + lib = torch.library.Library("_inductor_debug", "FRAGMENT") # noqa: TOR901 + lib.define( + "check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()" + ) + lib.impl("check_memory_step", check_memory_step, "BackendSelect") + from torch._higher_order_ops.effects import _EffectType, _register_effectful_op + + _register_effectful_op( + torch.ops._inductor_debug.check_memory_step.default, + _EffectType.ORDERED, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bf70fe9d8db1cb66379df11e025ad84cc0069b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +try: + import halide as hl # type: ignore[import-untyped, import-not-found] +except ImportError: + hl = None + +PHILOX_N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +if hl is not None: + PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9) + PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85) + PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53) + PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57) +else: + PHILOX_KEY_A_U32 = None + PHILOX_KEY_B_U32 = None + PHILOX_ROUND_A_U32 = None + PHILOX_ROUND_B_U32 = None + + +def _pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = hl.max(hl.f32(1.0e-7), u1) + th = hl.f32(6.283185307179586) * u2 + r = hl.sqrt(hl.f32(-2.0) * hl.log(u1)) + return r * hl.cos(th), r * hl.sin(th) + + +def _uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + + # TODO: + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132. + assert x.type() == hl.UInt(32) or x.type() == hl.Int(32) + x = hl.cast(hl.Int(32), x) + scale = hl.f64(4.6566127342e-10) + x = hl.select(x < 0, -x - 1, x) + return x * scale + + +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds): + def umulhi(a, b): + a = hl.cast(hl.UInt(64), a) + b = hl.cast(hl.UInt(64), b) + return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF)) + + for _ in range(n_rounds): + _c0, _c2 = c0, c2 + + c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0 + c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1 + c1 = PHILOX_ROUND_B_U32 * _c2 + c3 = PHILOX_ROUND_A_U32 * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A_U32 + k1 = k1 + PHILOX_KEY_B_U32 + + return c0, c1, c2, c3 + + +def halide_philox(seed, c0, c1, c2, c3, n_rounds): + seed = hl.cast(hl.UInt(64), seed) + + assert c0.type().bits() == 32 + + seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF)) + seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF)) + + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +def randint4x(seed, offset, n_rounds): + offset = hl.cast(hl.UInt(32), offset) + _0 = hl.u32(0) + return halide_philox(seed, offset, _0, _0, _0, n_rounds) + + +def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + i1, i2, i3, i4 = randint4x(seed, offset, n_rounds) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + u3 = _uint_to_uniform_float(i3) + u4 = _uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + source = randint(seed, offset, n_rounds) + return _uint_to_uniform_float(source) + + +def randn(seed, offset): + i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + n1, _ = _pair_uniform_to_normal(u1, u2) + return n1 + + +def randint64(seed, offset, low, high): + r0, r1, _r2, _r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + r0 = hl.cast(hl.UInt(64), r0) + r1 = hl.cast(hl.UInt(64), r1) + + result = r0 | (r1 << 32) + size = high - low + result = result % hl.cast(hl.UInt(64), size) + result = hl.cast(hl.Int(64), result) + low + return result diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ddf91e9a59cf805907e7ee4accecdd4c214a37 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py @@ -0,0 +1,224 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import functools +import typing +from enum import auto, Enum + +import torch +from torch.utils._triton import has_triton_package + + +# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values +# NOTE: if these fail asserts submit a PR to increase them +TRITON_MAX_BLOCK = { + "X": 8192 if torch.version.hip else 4096, + "Y": 1024, + "Z": 1024, + "R0_": 4096 * 16, # * 16 is multi-kernel only + "R1_": 2048 * 16, # * 16 is multi-kernel only +} +TRITON_MAX_RSPLIT = 64 + + +class ReductionHint(Enum): + INNER = 0 + OUTER = 1 + OUTER_TINY = 2 + DEFAULT = 3 + + +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + +# Define `AttrsDescriptorWrapper` function with clear conditional handling +if has_triton_package(): + import triton + import triton.backends.compiler + import triton.compiler.compiler + + if hasattr(triton.backends.compiler, "AttrsDescriptor"): + # Triton 3.2.0 - the second implementation + from triton.backends.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "tt.divisibility": divisible_by_16, + "tt.equal_to": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + res = AttrsDescriptor.from_dict( + {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__} + ) + assert res.property_values["tt.divisibility"] == 16 + assert res.property_values["tt.equal_to"] == 1 + return res + + elif hasattr(triton.compiler.compiler, "AttrsDescriptor"): + # Triton 3.0.0 - the original implementation + from triton.compiler.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) + + else: + # Triton in 2025: + # note: there's also a range of triton commits not currently supported + # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still + # used, but the contents are different. + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # pyrefly: ignore [not-iterable] + return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16} + +else: + # Define a namedtuple as a fallback when AttrsDescriptor is not available + AttrsDescriptorWrapper = collections.namedtuple( # type: ignore[no-redef, name-match] + # pyrefly: ignore [invalid-argument] + "AttrsDescriptor", + ["divisible_by_16", "equal_to_1"], + defaults=[(), ()], + ) + + +_NUM_THREADS_PER_WARP = 32 + + +class HeuristicType(Enum): + PERSISTENT_REDUCTION = auto() + POINTWISE = auto() + REDUCTION = auto() + SPLIT_SCAN = auto() + TEMPLATE = auto() + USER_AUTOTUNE = auto() + FIXED = auto() + + +class AutotuneHint(Enum): + ONE_ELEMENT_PER_THREAD = 0 + + # Triton codegen tries to codegen set of AutotuneHints. + # Enum.__repr__ looks like """ + # which isn't valid python. + # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32". + __repr__ = Enum.__str__ + + +class DeviceProperties(typing.NamedTuple): + """Copy device properties into a data structure not requiring torch to be imported""" + + type: str # type: ignore[assignment] + index: int # type: ignore[assignment] + multi_processor_count: int + cc: int + major: int | None = None + regs_per_multiprocessor: int | None = None + max_threads_per_multi_processor: int | None = None + max_threads_per_block: int | None = None + warp_size: int | None = None + + @classmethod + @functools.cache + def create(cls, device) -> DeviceProperties: + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + + device_interface = get_interface_for_device(device) + props = device_interface.get_device_properties(device) + try: + multi_processor_count = props.multi_processor_count + except AttributeError: + if device_type == "xpu": + multi_processor_count = props.gpu_subslice_count + elif device_type == "mtia": + multi_processor_count = 64 + else: + raise + return cls( + type=device_type, + index=device.index, + multi_processor_count=multi_processor_count, + cc=device_interface.get_compute_capability(device), + major=getattr(props, "major", None), + regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None), + max_threads_per_multi_processor=getattr( + props, "max_threads_per_multi_processor", None + ), + max_threads_per_block=getattr(props, "max_threads_per_block", 1024), + warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), + ) + + +class HalideInputSpec(typing.NamedTuple): + ctype: str + name: str + shape: list[str] | None = None + stride: list[str] | None = None + offset: str | None = None + alias_of: str | None = None + + def bindings_type(self) -> str: + if self.ctype in ("at::Half*", "at::BFloat16*"): + return "uint16_t*" # half not defined + return self.ctype + + def halide_type(self) -> str: + if self.ctype == "at::Half*": + return "halide_type_t(halide_type_float, 16)" # half not defined + if self.ctype == "at::BFloat16*": + return "halide_type_t(halide_type_bfloat, 16)" # half not defined + return f"halide_type_of<{self.ctype.replace('*', '')}>()" + + def is_scalar(self) -> bool: + return self.shape is None + + def is_buffer(self) -> bool: + return self.shape is not None + + +class HalideMeta(typing.NamedTuple): + argtypes: list[HalideInputSpec] + target: str + scheduler: str | None = None + scheduler_flags: dict[str, int | str] | None = None + cuda_device: int | None = None + + def args(self) -> list[str]: + """Command line args to pass to halide generator""" + args = [f"target={self.target}"] + if self.scheduler: + args.append(f"autoscheduler={self.scheduler}") + if self.scheduler_flags: + assert self.scheduler + for k, v in self.scheduler_flags.items(): + args.append(f"autoscheduler.{k}={v}") + return args + + def is_cuda(self) -> bool: + return self.cuda_device is not None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e66378e85aec07c60e68e48d441178db423dc2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import functools +import operator +from typing import Any, TYPE_CHECKING + +import torch + +# NOTE: other files rely on the imports below +from torch._dynamo import callback as compilation_callback # noqa: F401 +from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 + cache_dir, + default_cache_dir, + triton_cache_dir, +) + + +if TYPE_CHECKING: + from collections.abc import Hashable + + from .triton_compat import Config + + +def conditional_product(*args: int) -> int: + return functools.reduce(operator.mul, [x for x in args if x]) + + +def ceildiv(number: int, denom: int) -> int: + return -(number // -denom) + + +def is_power_of_2(n: int) -> bool: + """Returns whether n = 2 ** m for some integer m.""" + return n > 0 and n & n - 1 == 0 + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def last_power_of_2(n: int) -> int: + """Return the largest power of 2 less than or equal to n""" + next_pow2 = next_power_of_2(n) + return next_pow2 // 2 if next_pow2 > n else next_pow2 + + +def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: + """ + Return the total number of bytes the arguments of tensor type takes. + + For in/out args, tensor sizes are counted twice: once for reading and + once for writing. + + The first num_in_out_args arguments are in out tensors. + """ + return sum( + arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) + for i, arg in enumerate(args) + if isinstance(arg, torch.Tensor) + ) + + +def triton_config_to_hashable(cfg: Config) -> Hashable: + """ + Convert triton config to a tuple that can uniquely identify it. We can use + the return value as a dictionary key. + """ + # pyrefly: ignore [missing-attribute] + items = sorted(cfg.kwargs.items()) + # pyrefly: ignore [missing-attribute] + items.append(("num_warps", cfg.num_warps)) + # pyrefly: ignore [missing-attribute] + items.append(("num_stages", cfg.num_stages)) + return tuple(items) + + +def validate_triton_config(cfg: Config) -> None: + # [Note: Triton pre_hook in inductor] + # pre-hook is a lambda function, which we don't attempt to serialize. + # right now, if a pre-hook is attached to the config, it will not be saved; + # and then it won't be used when the config is loaded from cache. + # So we assert - if we do get a pre_hook, it might get ignored after caching. + assert getattr(cfg, "pre_hook", None) is None, ( + "triton configs with pre_hooks not supported" + ) + + +def create_bandwidth_info_str( + ms: float, + num_gb: float, + gb_per_s: float, + prefix: str = "", + suffix: str = "", + color: bool = True, +) -> str: + info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" + slow = ms > 0.012 and gb_per_s < 650 + return red_text(info_str) if color and slow else info_str + + +def get_max_y_grid() -> int: + return 65535 + + +try: + # pyrefly: ignore [import-error] + import colorama + + HAS_COLORAMA = True +except ModuleNotFoundError: + HAS_COLORAMA = False + colorama = None # type: ignore[assignment] + + +if HAS_COLORAMA: + + def _color_text(msg: str, color: str) -> str: + # pyrefly: ignore [missing-attribute] + return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET + +else: + + def _color_text(msg: str, color: str) -> str: + return msg + + +def green_text(msg: str) -> str: + return _color_text(msg, "green") + + +def yellow_text(msg: str) -> str: + return _color_text(msg, "yellow") + + +def red_text(msg: str) -> str: + return _color_text(msg, "red") + + +def blue_text(msg: str) -> str: + return _color_text(msg, "blue") + + +def get_first_attr(obj: Any, *attrs: str) -> Any: + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type] + + +def triton_hash_to_path_key(key: str) -> str: + # In early versions of Triton, the hash is directly used in the path name. + # Later, the hash is converted to base64 before being used in the path name. + # Later, the base64 conversion was replaced to the base32 + # + # This code tries to import _base64 and falls back to _base32 if _base64 is unavailable. + # + # To handle this, try to import the to-base64-conversion function. + # If it exists, use it; otherwise, try using _base32; if both are unavailable, use the hash directly. + try: + from triton.runtime.cache import _base64 + + return _base64(key) + except Exception: + try: + from triton.runtime.cache import _base32 + + return _base32(key) + except Exception: + return key + + +def compile_mps_shader(source: str) -> Any: + """ + Compiles shader source but raise more actionable error message when needed + """ + try: + return torch.mps.compile_shader(source) + except SyntaxError as err: + raise SyntaxError(f"failed to compile {source} with {err.msg}") from err + + +def torch_dtype_to_jax_runtime(dtype: torch.dtype) -> Any: + """ + Map PyTorch dtype to actual JAX dtype object at runtime. + + This helper is used in generated Pallas kernels at runtime to convert + PyTorch dtypes to JAX dtype objects (not string representations). + + Args: + dtype: PyTorch dtype to convert + + Returns: + JAX dtype object (e.g., jnp.float32 object itself) + """ + import jax.numpy as jnp # pyrefly: ignore [import-error] + + dtype_map = { + torch.float32: jnp.float32, + torch.float64: jnp.float64, + torch.float16: jnp.float16, + torch.bfloat16: jnp.bfloat16, + torch.int32: jnp.int32, + torch.int64: jnp.int64, + torch.int16: jnp.int16, + torch.int8: jnp.int8, + torch.uint8: jnp.uint8, + torch.bool: jnp.bool_, + torch.complex64: jnp.complex64, + torch.complex128: jnp.complex128, + } + if dtype not in dtype_map: + raise ValueError(f"Unsupported dtype for JAX conversion: {dtype}") + return dtype_map[dtype] + + +def torch_dtype_to_jax(dtype: torch.dtype) -> str: + """ + Map PyTorch dtype to JAX dtype expression string. + + This helper is used at compile time in codegen to generate + JAX dtype expressions for Pallas kernels. + + Args: + dtype: PyTorch dtype to convert + + Returns: + JAX dtype expression as string (e.g., "jnp.float32") + """ + jax_dtype = torch_dtype_to_jax_runtime(dtype) + dtype_name = jax_dtype.__name__ + if dtype_name == "bool": + dtype_name = "bool_" + return f"jnp.{dtype_name}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..f48f351ce823a325a2f15092ad964aeba09aaf82 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py @@ -0,0 +1,270 @@ +import functools +import os +from typing import Any +from typing_extensions import Unpack + +from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs +from .triton_helpers import get_constexprs + + +class StaticallyLaunchedCudaKernel: + """ + Parses the metadata of a CompiledKernel from Triton into a structure that can + launch the cuda kernel directly. Only works for triton kernels compiled to cubin. + + Doing this avoids C++ codegen and compilation during compile, since we can use a + statically compiled library to launch the kernel. To avoid mallocing for the arguments, + we have a launcher for different numbers of arguments up to a max. StaticCudaLauncher + only supports # of arguments up until 10 for now. + + Workflow: + Compile time: + 1. Compile a kernel with triton and get a CompiledKernel + 2. Instantiate kernel = StaticallyLaunchedCudaKernel(triton_kernel) + 3. Write to a cubin file: kernel.write_cubin_to_file(filepath) + 4. Call kernel.load_kernel() (CUDA should be initialized by this point) to load the cubin + Runtime: + 5. Call kernel.run(grid, stream, args) to launch the kernel + + Note that after step 3, StaticallyLaunchedCudaKernel is fully pickleable/serializable. + This allows it to be cached by FXGraphCache/TritonBundler, as well as sent from the worker + to the parent process in inductor. + + There are two main versions of triton that we wish to support: 3.3 and 3.2. Triton makes considerable changes + to how it handles constants in 3.3, so there's some special logic necessary to handle both versions. + """ + + def __init__(self, kernel: CompiledKernel) -> None: + # pyrefly: ignore [missing-attribute] + self.name = kernel.src.fn.__name__ + # pyrefly: ignore [missing-attribute] + self.cubin_raw = kernel.asm.get("cubin", None) + # pyrefly: ignore [missing-attribute] + self.cubin_path = kernel._cubin_path + + # Used by torch.compile to filter constants in older triton versions + # pyrefly: ignore [missing-attribute] + self.arg_names = kernel.src.fn.arg_names + + # Const exprs that are declared by the triton kernel directly + # Used to generate the kernel launcher's def args + # pyrefly: ignore [missing-attribute] + self.declared_constexprs = get_constexprs(kernel.src.fn) + + # pyrefly: ignore [missing-attribute] + self.hash = kernel.hash + + if triton_knobs is None: + # pyrefly: ignore [missing-attribute] + launch_enter = kernel.__class__.launch_enter_hook + # pyrefly: ignore [missing-attribute] + launch_exit = kernel.__class__.launch_exit_hook + else: + launch_enter = triton_knobs.runtime.launch_enter_hook + launch_exit = triton_knobs.runtime.launch_exit_hook + + def hook_is_empty(hook: Any) -> bool: + if hook is None: + return True + if ( + triton_knobs + and (HookChain := getattr(triton_knobs, "HookChain", None)) is not None + and isinstance(hook, HookChain) + ): + # Support hooks after https://github.com/triton-lang/triton/pull/7866 + return len(hook.calls) == 0 + return False + + if not hook_is_empty(launch_enter) or not hook_is_empty(launch_exit): + raise NotImplementedError( + "We don't support launch enter or launch exit hooks" + ) + # pyrefly: ignore [missing-attribute] + self.num_warps = kernel.metadata.num_warps + self.shared = ( + # pyrefly: ignore [missing-attribute] + kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared + ) + + def needs_scratch_arg(scratch_name: str, param_name: str) -> bool: + # pyrefly: ignore [missing-attribute] + if hasattr(kernel.metadata, param_name): + if getattr(kernel.metadata, param_name) > 0: + raise NotImplementedError( + f"{scratch_name} scratch not yet supported" + ) + return True + return False + + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. + # Inductor never uses this field or enables it, but we still have to pass + # an extra None into the set of params if its enabled + self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size") + # same situation for profile scratch - triton-lang/triton#7258 + self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size") + + # pyrefly: ignore [missing-attribute] + self.arg_tys = self.arg_ty_from_signature(kernel.src) + self.function: int | None = None # Loaded by load_kernel(on the parent process) + num_ctas = 1 + if hasattr(kernel, "num_ctas"): + num_ctas = kernel.num_ctas + elif hasattr(kernel, "metadata"): + num_ctas = kernel.metadata.num_ctas + + if num_ctas != 1: + raise NotImplementedError( + "Static cuda launcher only supports num_ctas == 1" + ) + + def reload_cubin_from_raw(self, filepath: str) -> str: + """ + If the cubin file triton generated gets deleted under us, we can + reload it from the raw cubin file. + """ + if self.cubin_path is None: + assert self.cubin_raw is not None + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "wb") as f: + f.write(self.cubin_raw) + self.cubin_path = filepath + return self.cubin_path + + def load_kernel(self, device: int) -> None: + from torch._C import _StaticCudaLauncher + + if self.function is not None: + return + + assert hasattr(self, "cubin_path") + assert self.cubin_path is not None + (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel( + self.cubin_path, self.name, self.shared, device + ) + # Don't need the cubin path anymore now that we've loaded + self.cubin_path = None + self.cubin_raw = None + + @staticmethod + @functools.lru_cache + def type_mappings() -> dict[str, str]: + return { + "i1": "i", + "i8": "b", + "i16": "h", + "i32": "i", + "i64": "l", + "u1": "I", + "u8": "B", + "u16": "H", + "u32": "I", + "u64": "K", + "fp16": "f", + "bf16": "f", + "fp32": "f", + "f32": "f", + "fp64": "d", + # TODO handle nvTmaDesc/CUtensormap + } + + def extract_type(self, ty: str) -> str: + """ + Takes a triton type from CompiledKernel.signature and + converts it into a single char encoding. _StaticCudaLauncher + will switch on this char to figure out what type the underlying + value should be passed to the triton kernel as. + """ + if ty[0] == "*": + return "O" + elif ty == "nvTmaDesc": + raise NotImplementedError("nvTmaDesc kernels are not yet supported") + return StaticallyLaunchedCudaKernel.type_mappings()[ty] + + def arg_ty_from_signature(self, src: ASTSource) -> str: + def index_key(i: Any) -> int: + if isinstance(i, str): + # pyrefly: ignore [missing-attribute] + return src.fn.arg_names.index(i) + elif isinstance(i, tuple): + # In triton 3.3, src.fn.constants has tuples as a key + return i[0] + else: + return i + + # pyrefly: ignore [missing-attribute] + signature = {index_key(key): value for key, value in src.signature.items()} + # Triton uses these as the main way to filter out constants passed to their cubin + constants = [index_key(key) for key in getattr(src, "constants", dict())] + # This value is always a superset of kernel.fn.constexprs: kernel.fn.constexprs are + # constants declared by the triton kernel directly, whereas this list can have + # constants that are unused by the triton kernel that triton figured out during + # compilation. + self.full_constexprs = constants + # Despite requiring them to be passed in, the triton CUDA launcher + # completely ignores the constexprs passed into it when generating code. + # So we can ignore them here too + params = [] + + for i in sorted(signature.keys()): + ty = signature[i] + # In newer triton versions, constants are passed in to signature with type `constexpr` + # In older triton versions, there can be constants in src.constants that are not `constexpr` in signature + # so we check both here + if ty == "constexpr" or i in constants: + pass + else: + # pyrefly: ignore [bad-argument-type] + params.append(self.extract_type(ty)) + return "".join(params) + + def __getstate__(self) -> dict[str, Any]: + # Remove objects that are no longer valid for pickling + state = self.__dict__.copy() + state["function"] = None + # Cubin paths aren't consistent across processes, so we clear + # and reload them. + state["cubin_path"] = None + return state + + def run( + self, + grid_x: int, + grid_y: int, + grid_z: int, + stream: int, + *args: Unpack[tuple[object, ...]], + ) -> None: + """Actually run the kernel at runtime. This function is the hot codepath.""" + from torch._C import _StaticCudaLauncher + + # Assert load_kernel() has been called and args match + assert self.function is not None + + # TODO: actually, if the args *don't* match, we probably should + # throw an exception. But if inductor is the only one calling this + # thing, it should always match. + # Get rid of constants before passing to cubin launcher + + # Add a None if triton wants extra parameters for scratch spaces + arg_tys = self.arg_tys + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) + # pyrefly: ignore [bad-argument-type] + assert len(args) == len(arg_tys) + + # TODO: can handle grid functions here or in C++, so + # that we don't need the grid handler above. + _StaticCudaLauncher._launch_kernel( + self.function, + grid_x, + grid_y, + grid_z, + self.num_warps, + self.shared, + arg_tys, + # pyrefly: ignore [bad-argument-type] + args, + stream, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..49ceacb50bc3d9f4b6c5c9451d6b810d7898bf20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import inspect +from typing import Any + +import torch + + +try: + import triton +except ImportError: + triton = None + + +if triton is not None: + import triton.language as tl + from triton import Config + from triton.compiler import CompiledKernel + from triton.runtime.autotuner import OutOfResources + from triton.runtime.jit import JITFunction, KernelInterface + + try: + from triton.runtime.autotuner import PTXASError + except ImportError: + + class PTXASError(Exception): # type: ignore[no-redef] + pass + + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None + + try: + from triton.backends.compiler import GPUTarget + except ImportError: + + def GPUTarget( + backend: str, + arch: int | str, + warp_size: int, + ) -> Any: + if torch.version.hip: + return [backend, arch, warp_size] + return (backend, arch) + + # In the latest triton, math functions were shuffled around into different modules: + # https://github.com/triton-lang/triton/pull/3172 + try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 + math = tl.math + except ImportError: + if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl + + try: + from triton.language.standard import _log2 + except ImportError: + + def _log2(x: Any) -> Any: + raise NotImplementedError + + def _triton_config_has(param_name: str) -> bool: + if not hasattr(triton, "Config"): + return False + if not hasattr(triton.Config, "__init__"): + return False + return param_name in inspect.signature(triton.Config.__init__).parameters + + # Drop the legacy support of autoWS + HAS_WARP_SPEC = False + + try: + from triton import knobs + except ImportError: + knobs = None + + try: + from triton.runtime.cache import triton_key # type: ignore[attr-defined] + except ImportError: + from triton.compiler.compiler import ( + triton_key, # type: ignore[attr-defined,no-redef] + ) + + builtins_use_semantic_kwarg = ( + "_semantic" in inspect.signature(triton.language.core.view).parameters + ) + HAS_TRITON = True +else: + + def _raise_error(*args: Any, **kwargs: Any) -> Any: + raise RuntimeError("triton package is not installed") + + class OutOfResources(Exception): # type: ignore[no-redef] + pass + + class PTXASError(Exception): # type: ignore[no-redef] + pass + + Config = object + CompiledKernel = object + KernelInterface = object + ASTSource = None + GPUTarget = None + _log2 = _raise_error + libdevice = None + math = None + knobs = None + builtins_use_semantic_kwarg = False + + class triton: # type: ignore[no-redef] + @staticmethod + def jit(*args: Any, **kwargs: Any) -> Any: + return _raise_error + + class tl: # type: ignore[no-redef] + @staticmethod + def constexpr(val: Any) -> Any: + return val + + tensor = Any + dtype = Any + + class JITFunction: # type: ignore[no-redef] + pass + + HAS_WARP_SPEC = False + triton_key = _raise_error + HAS_TRITON = False + + +def cc_warp_size(cc: str | int) -> int: + if torch.version.hip: + cc_str = str(cc) + if "gfx10" in cc_str or "gfx11" in cc_str: + return 32 + else: + return 64 + else: + return 32 + + +try: + autograd_profiler = torch.autograd.profiler +except AttributeError: # Compile workers only have a mock version of torch + + class autograd_profiler: # type: ignore[no-redef] + _is_profiler_enabled = False + + +__all__ = [ + "Config", + "CompiledKernel", + "OutOfResources", + "KernelInterface", + "PTXASError", + "ASTSource", + "GPUTarget", + "tl", + "_log2", + "libdevice", + "math", + "triton", + "cc_warp_size", + "knobs", + "triton_key", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7e89868e216a5156c08a6c4922f16e1897eedeeb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py @@ -0,0 +1,761 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math as pymath +import warnings +from collections.abc import Callable +from typing import Any, TypeVar + +from .triton_compat import ( # noqa: F401 + _log2, + builtins_use_semantic_kwarg, + JITFunction, + libdevice, + math, + tl, + triton, +) + + +_T = TypeVar("_T") +_LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e)) + + +def set_driver_to_cpu(): + driver = triton.runtime.driver + if backend := triton.backends.backends.get("cpu", None): + if isinstance(driver.active, backend.driver): + # Don't re-initialize backend if it is already active + return + driver.set_active(backend.driver()) + return + # This can be a hard error once triton-cpu is merged into fbcode + warnings.warn( + "Could not find an active CPU backend. Generated kernels will not be executable!" + ) + + +def set_driver_to_gpu(): + driver = triton.runtime.driver + for name, backend in triton.backends.backends.items(): + if backend.driver.is_active() and name != "cpu": + # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4, + # `driver.active` can be of `LazyProxy` type and the sign of this - `_obj` attribute. + if ( + isinstance(driver.active, backend.driver) + or hasattr(driver.active, "_obj") + and isinstance(driver.active._obj, backend.driver) + ): + # Don't re-initialize backend if it is already active + return + driver.set_active(backend.driver()) + return + raise RuntimeError("Could not find an active GPU backend") + + +def get_backend_options(): + from triton.runtime import driver + + target = driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + return options.__dict__ + + +def get_constexprs(kernel: JITFunction) -> list[int]: + return [p.num for p in kernel.params if p.is_constexpr] + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def div_floor_integer(a, b): + # NOTE: a // b is C division, but we want floor division + # Based on c10::div_floor_integer + quot = a // b + remainder = a % b + fixed = tl.where(remainder != 0, quot - 1, quot) + return tl.where((a < 0) != (b < 0), fixed, quot) + + +@triton.jit +def remainder_integer(a, b): + # NOTE: a % b matches C division, not floor division + remainder = a % b + return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def _prod_accumulate(a, b): + return a * b + + +@triton.jit +def prod(input, axis): + return tl.reduce(input, axis, _prod_accumulate) + + +@triton.jit +def minimum(a, b): + mask = a < b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def min2(a, dim): + return tl.reduce(a, dim, minimum) + + +@triton.jit +def max2(a, dim): + return tl.reduce(a, dim, maximum) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan & (not b_isnan) + # Consider NaNs as equal + equal |= a_isnan & b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan & (not b_isnan) + # Consider NaNs as equal + equal |= a_isnan & b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def min_with_index(value, index, dim): + return tl.reduce((value, index), dim, minimum_with_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +@triton.jit +def exp(x, use_fast_math: tl.constexpr): + if use_fast_math: + return math.exp(x) + else: + return libdevice.exp(x) + + +@triton.jit +def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr): + out_max = max2(lhs_max, dim) + out_max_keepdim = tl.expand_dims(out_max, dim) + delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim) + out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim) + return out_max, out_sum + + +@triton.jit +def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: tl.constexpr): + """ + When we do combine, we assume lhs is the accumulator and rhs is the next + block of data. + Then rhs_sum is always 1. With that assumption, we can save some registers + and computation. + """ + out_max = maximum(lhs_max, rhs_max) + + lhs_scale = tl.where( + out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math) + ) + rhs_scale = tl.where( + out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math) + ) + + # Should be + # out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale + # but since rhs_sum is all 1, we can simplify it. + out_sum = lhs_sum * lhs_scale + rhs_scale + return out_max, out_sum + + +@triton.jit +def welford_reduce(value, mean, m2, weight, first_iteration): + if first_iteration: + new_weight = tl.full(weight.shape, 1, weight.dtype) + new_mean = value + new_m2 = tl.zeros_like(m2) + else: + delta = value - mean + new_weight = weight + 1 + new_mean = mean + delta / new_weight + new_m2 = m2 + delta * (value - new_mean) + return new_mean, new_m2, new_weight + + +@triton.jit +def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight) + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +@triton.jit +def welford(mean, m2, weight, dim): + return tl.reduce((mean, m2, weight), dim, welford_combine) + + +@triton.jit +def device_assert_then(cond, msg, r): + tl.device_assert(cond, msg) + return r + + +@triton.jit +def randint64(seed, offset, low, high): + r0, r1, _r2, _r3 = tl.randint4x(seed, offset) + r0 = r0.to(tl.uint64) + r1 = r1.to(tl.uint64) + result = r0 | (r1 << 32) + size = high - low + result = result % size.to(tl.uint64) + result = result.to(tl.int64) + low + return result + + +@triton.jit +def _any_combine(a, b): + return a | b + + +@triton.jit +def any(a, dim): + return tl.reduce(a, dim, _any_combine) + + +@triton.jit +def bucketize_binary_search( + values: tl.tensor, + boundaries_ptr: tl.tensor, + BOUNDARIES_SIZE: int, + BOUNDARIES_UNDERLYING_NUMEL: int, + BOUNDARIES_STRIDE: int, + boundary_indices: tl.tensor, + indexing_dtype: tl.dtype, + right: "bool", # triton can't handle the unquoted bool annotation + sorter_ptr: tl.tensor, + SORTER_STRIDE: int, + sorter_indices: tl.tensor, +): + """ + See [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to bucketize. + boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D. + BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one + individual set of boundaries). + BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring + any striding. + BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor + boundary_indices: a tensor of the same size as "values"; each element is an index + into a 1-D, un-strided boundaries tensor, pointing to the first element in the set + of boundaries used for that value. + indexing_dtype: the dtype used for indexing into the boundaries tensor, and the + return dtype. + right: if true, use boundary intervals closed on the left; otherwise use intervals + closed on the right. + sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries, + but potentially different striding. If present, this allows us to treat boundaries + as sorted even if the elements of boundaries are unsorted. + SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last + dimension of the sorter tensor. + sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices". + BLOCK_SHAPE: the shape of the data block being processed. + """ + + low = tl.zeros(values.shape, dtype=indexing_dtype) + high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype) + + full_range = BOUNDARIES_SIZE + 1 + while full_range > 1: + mid = (high + low) // 2 + mask = ( + (mid * BOUNDARIES_STRIDE + boundary_indices) < BOUNDARIES_UNDERLYING_NUMEL + ).logical_and(mid < BOUNDARIES_SIZE) + mid_indices = ( + mid + if sorter_ptr is None or SORTER_STRIDE is None + else tl.load( + sorter_ptr + sorter_indices + SORTER_STRIDE * mid, + mask=mask, + other=0, + ) + ) + + bucket_upper_bound = tl.load( + boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices, + mask=mask, + other=0, + ) + if right: + is_above = values >= bucket_upper_bound + else: + is_above = values > bucket_upper_bound + + low = tl.where(is_above & mask, mid + 1, low) + high = tl.where(is_above, high, mid) + + full_range = (full_range + 1) // 2 + + return low + + +@triton.jit +def pack_value_flag( + value, + flag, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) + return flag.to(DTYPE_PACK) | (uv << bitwidth) + + +@triton.jit +def unpack_value( + pack, + DTYPE_VALUE, + DTYPE_VALUE_AS_UINT, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE) + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) + return value_uint.to(DTYPE_VALUE, bitcast=True) + + +@triton.jit +def unpack_flag(pack, DTYPE_FLAG): + return pack.to(DTYPE_FLAG) + + +@triton.jit +def exclusive_scan_decoupled_lookback( + scratch_base, + block_value, + index, + combine_fn, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value`` + DTYPE_PACK: Unsigned type twice the width of block_value + + NOTE: This function is limited to values which are 32-bits or less because + we need to pack (value, flag) into a single unsigned int. + """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + DTYPE_VALUE = block_value.dtype + pack = pack_value_flag( + block_value, + tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + if index > 0: + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], DTYPE_VALUE) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + # tl.atomic_load + flag = tl.full([], 0, DTYPE_VALUE_AS_UINT) + while flag == 0: + pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed") + flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT) + + value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + pack = pack_value_flag( + inclusive_prefix, + tl.full([], 2, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + return exclusive_prefix + + +@triton.jit +def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block, must be 64-bits wide + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + init: Scalar value equal to the identity of combine_fn + """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + if index > 0: + block_value_u64 = block_value.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 1, block_value_u64) + tl.debug_barrier() + flag_one = tl.full([], 1, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], block_value.dtype) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + flag = tl.full([], 0, tl.uint64) + while flag == 0: + flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire") + + value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32)) + value = value_u64.to(block_value.dtype, bitcast=True) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64) + tl.debug_barrier() + flag_two = tl.full([], 2, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release") + + return exclusive_prefix + + +@triton.jit +def frexp(x): + # TODO(isuruf): use inline_asm_elementwise here + y = libdevice.ilogb(x) + 1 + exponent = tl.where(x == 0, 0, y) + mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y)) + return mantissa, exponent + + +@triton.jit +def _compare_and_swap_with_index( + x, + idxs, + rnumel, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] + + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + y = tl.reshape(x, shape) + iy = y.to(idtype, bitcast=True) + # slice left/right with 'stride' 2**(n_dims - i - 1) + right_mask = tl.arange(0, 2)[None, :, None].to(idtype) + left_mask = (1 - right_mask).to(idtype) + ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape) + iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape) + ileft = tl.reshape(ileft, x.shape) + iright = tl.reshape(iright, x.shape) + left = ileft.to(x.dtype, bitcast=True) + right = iright.to(x.dtype, bitcast=True) + + # idx + y_idx = tl.reshape(idxs, shape) + left_idx = tl.broadcast_to( + tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + right_idx = tl.broadcast_to( + tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + left_idx = tl.reshape(left_idx, x.shape) + right_idx = tl.reshape(right_idx, x.shape) + + # valid + if rnumel is None: + left_valid_mask = tl.full(x.shape, True, tl.int1) + right_valid_mask = tl.full(x.shape, True, tl.int1) + else: + left_valid_mask = left_idx < rnumel + right_valid_mask = right_idx < rnumel + + # actual compare-and-swap + ix = x.to(idtype, bitcast=True) + + # sort treats nan as having the higher value. comparisons with nan always return False. + # to align with sort semantics, we need to update descending to check if right_isnan, + # and ascending to check if left_isnan. + left_isnan = left != left + right_isnan = right != right + + if descending: + cond = left < right + if is_floating(left): + if not stable: + cond = cond | right_isnan + else: + cond = cond | (right_isnan & (~left_isnan)) + + else: + cond = left > right + if is_floating(left): + if not stable: + cond = cond | left_isnan + else: + cond = cond | (left_isnan & (~right_isnan)) + + if stable: + # When stable sorting, tie break by index + eq = left == right + if is_floating(left): + eq = eq | (left_isnan & right_isnan) + cond = cond | (eq & (left_idx > right_idx)) + + cond = (right_valid_mask > left_valid_mask) | ( + (right_valid_mask == left_valid_mask) & cond + ) + cond = (cond ^ flip).to(tl.int1) + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs)) + + return ret.to(x.dtype, bitcast=True), new_idxs + + +@triton.jit +def _bitonic_merge_with_index( + x, + idxs, + rnumel, + stage: tl.constexpr, + alternating: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if alternating: + shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape( + tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape + ) + else: + flip = False + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, idxs = _compare_and_swap_with_index( + x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending + ) + return x, idxs + + +@triton.jit +def sort_with_index( + x, # value + idxs, # index + rnumel, # number of elements + dim: tl.constexpr = None, + stable: tl.constexpr = tl.constexpr(False), + descending: tl.constexpr = tl.constexpr(False), +): + x, idxs = tl.broadcast(x, idxs) + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert( + _dim == len(x.shape) - 1, "only minor dimension is currently supported" + ) + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = _log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, idxs = _bitonic_merge_with_index( + x, + idxs, + rnumel, + i, + alternating=i < n_dims, + n_dims=n_dims, + stable=stable, + descending=descending, + ) + return x, idxs + + +@triton.jit +def select_one(x, mask, dim, keep_dims=False): + idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False) + ix = x.to(idtype, bitcast=True) + iy = tl.sum(ix * mask, dim, keep_dims=keep_dims) + return iy.to(x.dtype, bitcast=True) + + +@triton.jit +def x_grid_barrier(sem): + """ + Wait for all other thread blocks in grid sharing same y/z program_id + to reach this barrier before returning. + + Args: + sem: an uint32 semaphores, zero or 0x80000000 initialized. Must be unique to each y/z program ID. + """ + # ensure stores before this are visible + tl.debug_barrier() + + one_i32 = 1 + one_u32 = one_i32.to(tl.uint32) # type: ignore[attr-defined] + expected = tl.num_programs(0).to(tl.uint32) + if tl.program_id(0) == 0: + nb = 0x80000000 - (expected - one_u32) + else: + nb = one_u32 + + old_arrive = tl.atomic_add(sem, nb, sem="release") + + bar_flipped = False + while not bar_flipped: + # want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it + current_arrive = tl.atomic_add(sem, 0, sem="acquire") + # current_arrive = tl.load(sem, volatile=True) + bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0 + + # TODO(jansel): is this needed? + tl.debug_barrier() + + +def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]: + """ + Decorator to mark a function as a Triton built-in function. These functions + are evaluated at compile time. + + Args: + f (function): The function to be marked as a Triton built-in. + + Returns: + function: The same function, marked as a Triton built-in. + """ + if builtins_use_semantic_kwarg: + # support Triton before and after https://github.com/triton-lang/triton/pull/7054 + # and after https://github.com/triton-lang/triton/pull/7239 + def wrapper(*args, _semantic, **kwargs): + kwargs["_builder"] = _semantic + return f(*args, **kwargs) + else: + wrapper = f # type: ignore[assignment] + + wrapper.__triton_builtin__ = True # type: ignore[attr-defined] + return wrapper + + +@triton_builtin +def constexpr_next_power_of_2( + n: tl.constexpr, *, _builder: object = None +) -> tl.constexpr: + """ + A version triton.next_power_of_two that can be used within a kernel on constants. + """ + assert isinstance(n, tl.constexpr) + return tl.constexpr(triton.next_power_of_2(n.value)) + + +@triton_builtin +def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr: + """ + Work around triton compile error: `ValueError: `other` cannot be provided without `mask`` + A compile-time to check to return either `val` or `None` depending on the value of mask. + """ + if isinstance(mask, tl.constexpr) and mask.value is None: + return tl.constexpr(None) + return val diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..2aefc498efb3e1731c433cb9924d6520e34cc16c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py @@ -0,0 +1,3874 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import builtins +import copy +import dataclasses +import functools +import hashlib +import inspect +import itertools +import logging +import math +import operator +import os +import os.path +import re +import sys +import threading +import time +from collections import namedtuple +from typing import Any, Generic, Literal, TYPE_CHECKING, TypeVar, Union + +import torch +from torch._dynamo.utils import counters, set_feature_use +from torch._inductor import metrics +from torch._prims_common import compute_required_storage_length +from torch.utils._debug_mode import get_active_debug_mode +from torch.utils._ordered_set import OrderedSet + +from ..triton_bundler import TritonBundler +from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict +from . import triton_helpers +from .autotune_cache import AutotuneCache +from .benchmarking import benchmarker +from .coordinate_descent_tuner import CoordescTuner +from .hints import ( + _NUM_THREADS_PER_WARP, + AutotuneHint, + DeviceProperties, + HeuristicType, + ReductionHint, + TileHint, + TRITON_MAX_BLOCK, + TRITON_MAX_RSPLIT, +) +from .runtime_utils import ( + ceildiv, + conditional_product, + create_bandwidth_info_str, + dynamo_timed, + get_first_attr, + get_max_y_grid, + get_num_bytes, + next_power_of_2, + triton_cache_dir, + triton_config_to_hashable, + triton_hash_to_path_key, + validate_triton_config, +) +from .static_cuda_launcher import StaticallyLaunchedCudaKernel +from .triton_compat import ( + ASTSource, + autograd_profiler, + cc_warp_size, + CompiledKernel, + Config, + GPUTarget, + HAS_WARP_SPEC, + KernelInterface, + knobs, + OutOfResources, + PTXASError, + triton, +) +from .triton_helpers import get_constexprs + + +class InductorConfig(Config): + """Inductor-specific Triton config with additional control flags""" + + def __init__(self, *args, dynamic_scale_rblock=True, **kwargs): + super().__init__(*args, **kwargs) + self.dynamic_scale_rblock = dynamic_scale_rblock + + +class NoTritonConfigsError(RuntimeError): + pass + + +if TYPE_CHECKING: + from collections.abc import Callable, Container, Hashable + + from torch._guards import CompileId + + LauncherType = Any + +_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel] +_T = TypeVar("_T", bound=_KernelType) + +log = logging.getLogger(__name__) + +triton_name_sub = re.compile(r"^def [^(]+\(") + + +def generate_lookup_hash_from_source_code(size_hints_str: str, source_code: str) -> str: + # Name agnostic + strip white space + fn_strip_name = re.sub(triton_name_sub, "(", source_code.strip(), count=1) + hash_str = size_hints_str + fn_strip_name + fn_hash = hashlib.sha256(hash_str.encode("utf-8")).hexdigest() + + return fn_hash + + +def lookup_autotune_config(size_hints, fn) -> Config | None: + lookup_table = torch._inductor.config.autotune_lookup_table + cached_config = None + if len(lookup_table) > 0 and "_fused_" in fn.src: + fn_hash = generate_lookup_hash_from_source_code(str(size_hints), fn.src) + if fn_hash in lookup_table: + config_dict = lookup_table[fn_hash] + block_configs = {k: v for k, v in config_dict.items() if "BLOCK" in k} + cached_config = Config( + block_configs, + num_warps=config_dict["num_warps"], + num_stages=config_dict["num_stages"], + ) + + return cached_config + + +def get_total_reduction_numel(numels: dict[str, int]) -> int: + return conditional_product( + *[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)] + ) + + +def autotune_hints_to_configs( + hints: OrderedSet[AutotuneHint], + size_hints, + block_size: int, + device_props: DeviceProperties, +) -> list[Config]: + """ + AutotuneHints can be attached to the metadata of triton kernels for providing + suggestions about what to try for autotuning. One reason to do this is if there are + some configs that are only useful in specific scenarios, in which case we can avoid + wasting compile time on autotuning unless we know we are in one of those scenarios. + + Based on those hints, this function will generate a list of additional autotuning + configs to try. + """ + xyz_options: tuple[tuple[int, int | None, int | None], ...] + configs: list[Config] = [] + for hint in hints: + if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD: + if len(size_hints) == 1: + xyz_options = ((block_size // 4, None, None),) + elif len(size_hints) == 2: + xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None)) + elif len(size_hints) == 3: + xyz_options = ( + (block_size // 4, 1, 1), + (1, block_size // 4, 1), + (1, 1, block_size // 4), + ) + configs.extend( + triton_config( + size_hints, + *xyz, + num_elements_per_warp=( + device_props.warp_size if device_props.warp_size else 32 + ), + ) + for xyz in xyz_options + ) + + return configs + + +def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): + call_args = [] + call_kwargs = {} + for arg in args: + if isinstance(arg, (int, bool)): + call_args.append(str(arg)) + else: + call_args.append("T") + for k, v in kwargs.items(): + if isinstance(arg, (int, bool)): + call_kwargs[k] = v + else: + call_kwargs[k] = v + call_kwargs.update(launcher.config.kwargs) + call_kwargs["num_warps"] = launcher.config.num_warps + call_kwargs["num_stages"] = launcher.config.num_stages + if HAS_WARP_SPEC: + call_kwargs["num_consumer_groups"] = getattr( + launcher.config, "num_consumer_groups", 0 + ) + call_kwargs["num_buffers_warp_spec"] = getattr( + launcher.config, "num_buffers_warp_spec", 0 + ) + args_str = [*call_args] + args_str.extend(f"{k}={v}" for k, v in call_kwargs.items()) + args_str = ", ".join(args_str) + abs_path = os.path.abspath(sys.argv[0]) + with open(f"{abs_path}.launch_params", "a") as f: + f.write(f"{kernel_name} | {args_str} | {grid!r}\n") + + +def check_autotune_cache( + configs: list[Config], filename: str | None, inductor_meta: dict[str, Any] +) -> tuple[list[Config], AutotuneCache | None, dict[str, Any]]: + """ + Given a list of configs, checks autotune cache and return metadata + """ + autotune_cache = None + autotune_cache_info = {} + disabled = inductor_meta.get("force_disable_caches", False) + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + and os.environ.get("TRITON_INTERPRET", "0") != "1" + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): + configs = [best_config] + autotune_cache_info["best_config"] = triton_config_to_hashable( + best_config + ) + autotune_cache_info["autotune_cache_state"] = "hit" + + else: + autotune_cache_info["autotune_cache_state"] = "miss" + autotune_cache_info["num_configs"] = len(configs) + if inductor_meta.get("coordinate_descent_tuning"): + autotune_cache_info["coordesc_tuning"] = True + if len(configs) == 1: + # This is the config that coordinate descent tuning started at, which + # is not the same as the final config chosen (i.e. only_config, best_config) + autotune_cache_info["coordesc_tuning_start_config"] = ( + triton_config_to_hashable(configs[0]) + ) + else: + if len(configs) == 1: + autotune_cache_info["autotune_cache_state"] = "only 1 config" + autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0]) + + if disabled: + autotune_cache_info["autotune_cache_state"] = "force_disabled" + log.debug("autotune caching is disabled by config.force_disable_caches") + + return configs, autotune_cache, autotune_cache_info + + +class CachingAutotuner(KernelInterface): + """ + Simplified version of Triton autotuner that has no invalidation + key and caches the best config to disk to improve cold start times. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. + """ + + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: list[str], # see [Note: clone mutated buffers] + optimize_mem, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: str | None = None, + reset_to_zero_arg_names: list[str] | None = None, + autotune_cache_info: dict[str, Any] | None = None, + ): + super().__init__() + + assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + # makes sure there are no pre-hooks on any of the triton configs + for cfg in configs: + validate_triton_config(cfg) + + self.fn = fn + self.device_props: DeviceProperties = triton_meta["device"] + self.triton_meta = { + **triton_meta, + "device": self.device_props.index, + "device_type": self.device_props.type, + } + self.inductor_meta = {} if inductor_meta is None else inductor_meta + # Add device properties to inductor_meta for use by coordinate descent tuner + self.inductor_meta["warp_size"] = self.device_props.warp_size + self.inductor_meta["max_threads_per_block"] = ( + self.device_props.max_threads_per_block + ) + self.deterministic_mode = self.inductor_meta.get("deterministic", False) + + self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names + self.reset_to_zero_arg_names = ( + [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names + ) + self.optimize_mem = optimize_mem + cached_config = lookup_autotune_config(size_hints, fn) + self.configs = [cached_config] if cached_config else configs + + self.heuristic_type = heuristic_type + self.custom_kernel = custom_kernel + self.cuda_kernel_saved = False + self.autotune_cache_info = autotune_cache_info + if log.isEnabledFor(logging.DEBUG): + log.debug( + "CachingAutotuner gets %d configs for %s", + len(self.configs), + self.fn.__name__, + ) + for c in self.configs: + log.debug(c) + + self.compile_results: list[CompileResult[_KernelType]] = [] + self.launchers: list[LauncherType] = [] + self.lock = threading.Lock() + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = triton_cache_dir( + self.triton_meta.get("device", 0) + ) + log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"]) + + self.size_hints = size_hints + self.is_mix_order_reduction = self.inductor_meta.get("RSPLIT_SIZE") is not None + self.coordesc_tuner = CoordescTuner( + is_mm=False, + is_native_matmul=triton_meta.get("native_matmul", False), + is_mix_order_reduction=self.is_mix_order_reduction, + name=self.fn.__name__, + size_hints=size_hints, + inductor_meta=self.inductor_meta, + ) + self.filename = filename + + # used for profiling + self.kernel_hash: str = "" + + # Kernels are stored in the codecache with the filename as a hash of the code. + # We rely on this to obtain the kernel hash + if self.filename is not None: + base_name = os.path.basename(self.filename) + if ".py" in base_name: + self.kernel_hash = os.path.splitext(base_name)[0] + + self.precompile_time_taken_ns = 0 + self.autotune_time_taken_ns = 0 + # Dumps the launch configs after autotuning. + self.dump_launch_params = ( + os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" + ) + + self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" + + # Compile-time info included in runtime logginging + self.compile_id: CompileId | None = None + self.is_backward = False + + # Mode for launch grid calculation + self.grid_mode: Literal["python", "cpp"] = "python" + + def is_statically_launchable(self): + """ + Checks if every compiled kernel is statically launchable, which + allows us to efficiently cache it in FXGraphCache + """ + if not self.compile_results: + return False + return all( + isinstance(x, StaticTritonCompileResult) for x in self.compile_results + ) + + def recheck_autotune_cache( + self, reload_kernel_from_src: Callable[[], CachingAutotuner] + ) -> None: + """ + On cache load on static autotuner, we need to recheck the autotune cache, since + a best config could have been found from a previous run + """ + assert self.is_statically_launchable() + + configs = [result.config for result in self.compile_results] + + (cached_configs, _, autotune_cache_info) = check_autotune_cache( + configs, self.filename, self.inductor_meta + ) + self.autotune_cache_info = autotune_cache_info + # I.e. there was an autotune cache hit + if len(cached_configs) == 1 and len(configs) > 1: + best_config = cached_configs[0] + # Grab the best compiled config, if it's in the list of available ones + best_config_hash = triton_config_to_hashable(best_config) + + for compile_result in self.compile_results: + if triton_config_to_hashable(compile_result.config) == best_config_hash: + self.compile_results = [compile_result] + return + + # If the best config isn't in our list of compile results, + # it's likely because it was found by coordesc after the cache + # already saved + if best_config.found_by_coordesc: + with dynamo_timed("CachingAutotuner.slow_precompile_config"): + if self.fn.fn is None: + self.fn = reload_kernel_from_src().fn + self.compile_results = [self._precompile_config(best_config)] + + def set_compile_info(self, compile_id: CompileId | None, is_backward: bool) -> None: + self.compile_id = compile_id + self.is_backward = is_backward + + def precompile( + self, + warm_cache_only=False, + reload_kernel: Callable[[], CachingAutotuner] | None = None, + static_triton_bundle_key: str | None = None, + ): + if warm_cache_only: + self._precompile_worker() + return + with self.lock: + # Helper function for reloading a kernel generated in a worker + # in the parent class. Normally we don't need to reload the kernel + # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock), + # we need to actually run compilation on the parent process + if reload_kernel is not None: + self._reload_kernel = reload_kernel + self._precompile_worker() + if static_triton_bundle_key is not None and self.is_statically_launchable(): + TritonBundler.put_static_autotuner(static_triton_bundle_key, self) + self._make_launchers() + self._dynamic_scale_rblock() + + def _precompile_worker(self): + if self.compile_results: + for result in self.compile_results: + TritonBundler.put( + triton_hash_to_path_key(result.kernel.hash), # type: ignore[attr-defined] + self.triton_meta.get("device", 0), + ) + return + assert not self.launchers + if not self.configs: + raise NoTritonConfigsError("No triton configs are available") + + compile_results = [] + exc = None + for c in self.configs: + try: + compile_results.append(self._precompile_config(c)) + except (OutOfResources, PTXASError) as e: + exc = e + if len(compile_results) == 0: + raise NoTritonConfigsError( + f"No valid triton configs. {type(exc).__name__}: {exc}" + ) + self.compile_results = compile_results + self.configs = None + + def _dynamic_scale_rblock(self): + # TODO(jansel): we should find a way to move this extra compile into the worker process + # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg. + device_prop = self.device_props + if ( + not self.deterministic_mode + and self.inductor_meta.get("dynamic_scale_rblock", True) + and not self.inductor_meta.get("persistent_reduction") + and self.heuristic_type == HeuristicType.REDUCTION + and self.size_hints is not None + # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary. + and device_prop.type in ["cuda", "hip"] + and device_prop.major + and (device_prop.major >= 8 or torch.version.hip) + and device_prop.regs_per_multiprocessor is not None + ): + assert device_prop.regs_per_multiprocessor + assert device_prop.max_threads_per_multi_processor + assert device_prop.multi_processor_count + seen_config_hashes: OrderedSet[Hashable] | None = None + warp_size = device_prop.warp_size or 32 + for result in self.compile_results: + triton_config = result.config + compiled_binary = result.kernel + assert len(self.size_hints) >= 2 + xblock = triton_config.kwargs.get("XBLOCK", 1) + reduction_kwargs = [ + kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R") + ] + rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs] + total_block = (self.size_hints["x"] + xblock - 1) // xblock + nreg = getattr(compiled_binary, "n_regs", None) + if nreg is None: + continue + + # make sure rblocks are not too small + if conditional_product(*rblocks) <= 64: + continue + + # each SM of A100 has 65536 32-bit registers. To maximize + # the theoretical occupancy, we need run 2048 threads on each + # SM. So each thread should use no more than 65536 / 2048 + # = 32 registers. In cases where occupancy matters, and each + # thread uses too many registers, reduce R0_BLOCK to reduce + # the register usage. + # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd + # from PLBartForCausalLM, latency improve from + # 7.795ms to 4.883ms. + # + if ( + nreg + <= device_prop.regs_per_multiprocessor + // device_prop.max_threads_per_multi_processor + ): + continue + + nreg_per_warp = nreg * warp_size + nreg_per_block = nreg_per_warp * triton_config.num_warps + + # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' + # The formula below is a tighter upper bound since we have the assumption that + # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor + # due to the if condition above and: + # regs_per_multiprocessor / nreg_per_block + # = regs_per_multiprocessor / (nreg * 32 * num_warps) + # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) + # = max_threads_per_multi_processor / (32 * num_warps) + # Using a tighter upper bound can reveal more optimization opportunities. + max_blocks_per_sm = max( + device_prop.regs_per_multiprocessor // nreg_per_block, 1 + ) + + if total_block <= max_blocks_per_sm * device_prop.multi_processor_count: + # no need to improve occupancy + continue + new_config = copy.deepcopy(triton_config) + + # Reduce the largest Rn_BLOCK by a factor of 2. + largest_rkwarg: str = max( + reduction_kwargs, key=triton_config.kwargs.__getitem__ + ) + new_config.kwargs[largest_rkwarg] //= 2 + + if seen_config_hashes is None: + seen_config_hashes = OrderedSet( + [ + triton_config_to_hashable(x.config) + for x in self.compile_results + ] + ) + new_config_hash = triton_config_to_hashable(new_config) + if new_config_hash in seen_config_hashes: + continue + seen_config_hashes.add(new_config_hash) + log.debug( + "Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)", + largest_rkwarg, + triton_config, + new_config, + ) + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + self.compile_results.append(self._precompile_config(new_config)) # noqa: B909 + + self._make_launchers() + + def _make_launchers(self): + if len(self.launchers) == len(self.compile_results): + return + + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, self.triton_meta["device"]): + # need to initialize context + with dynamo_timed( + "CachingAutotuner.synchronize", + # Deliberately avoid overloading pt2_compile_events: + log_pt2_compile_event=False, + ): + device_interface.synchronize(device_interface.current_device()) + + launchers = [] + exc = None + for result in self.compile_results: + try: + launchers.append(result.make_launcher()) + + except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e: + exc = e + if len(launchers) == 0: + raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") + self.launchers = launchers + + def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]: + """Drop stuff from triton.JITFunction that does not pickle. + This must be called after precompile so that these things are no longer needed. + Returns a tuple of old values + """ + old_values = ( + self.fn.fn, + self.fn.__globals__, + self.fn.used_global_vals, + self.fn.repr, + self.launchers, + getattr(self.fn, "_hash_lock", None), + ) + self.fn.fn = None + self.fn.__globals__ = None + self.fn.used_global_vals = None + self.fn.repr = _ConstRepr(self.fn.repr(self.fn)) + self.launchers = [] + self.fn._hash_lock = None + return old_values + + def restore_after_unpickle( + self, old_values: tuple[Any, Any, Any, Any, Any, Any] | None + ) -> None: + if old_values: + ( + self.fn.fn, + self.fn.__globals__, + self.fn.used_global_vals, + self.fn.repr, + self.launchers, + self.fn._hash_lock, + ) = old_values + else: + # even if we don't need/have specific values, we do need the + # _hash_lock to be a valid RLock + self.fn._hash_lock = threading.RLock() + + def prepare_for_caching(self) -> None: + """ + Statically Launched CUDA Kernels have a raw cubin on them + that we don't need to store in the cache(since TritonBundler handles the collection for us) + """ + for result in self.compile_results: + if isinstance(result, StaticTritonCompileResult): + # Don't save this in the inductor cache, as it is very large + result.kernel.cubin_raw = None + + def __getstate__(self) -> dict[str, Any]: + assert not self.launchers, ( + "pickle should not be called with after make_launchers()" + ) + return { + **self.__dict__, + "lock": None, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + self.lock = threading.Lock() + + def get_device_interface(self): + # this code cannot run in compile workers, because it imports from torch + from torch._dynamo.device_interface import get_interface_for_device + + return get_interface_for_device(self.device_props.type.replace("hip", "cuda")) + + def _create_compile_meta(self, cfg: Config) -> dict[str, Any]: + """ + Create compilation metadata for a given autotuner config. This involves + processing the Config kwargs so that the kwargs that are not part + of the triton signature are passed in as options to triton.compile + instead + """ + compile_meta = copy.deepcopy(self.triton_meta) + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + + cfg_kwargs = cfg.kwargs + if self.device_props.type == "hip": + cfg_kwargs = {**cfg_kwargs} + for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"): + if k in cfg_kwargs: + compile_meta[k] = cfg_kwargs.pop(k) + compile_meta["constants"].update(cfg_kwargs) + + for i in get_constexprs(self.fn): + arg_name = self.fn.arg_names[i] + if arg_name not in compile_meta["constants"] and ( + arg_name == "num_warps" or arg_name == "num_stages" + ): + compile_meta["constants"][arg_name] = getattr(cfg, arg_name) + if HAS_WARP_SPEC: + compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0) + compile_meta["num_buffers_warp_spec"] = getattr( + cfg, "num_buffers_warp_spec", 0 + ) + compile_meta["debug"] = self.inductor_meta.get( + "assert_indirect_indexing", True + ) and not self.inductor_meta.get("is_hip", False) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + return compile_meta + + def _create_compile_options( + self, cfg: Config, compile_meta: dict[str, Any] + ) -> dict[str, Any]: + """ + Create options to pass to triton.compile based on the compile metadata + and the given config. + """ + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + "sanitize_overflow": False, # turn off additional asserts added for overflow checks + } + if "enable_fp_fusion" in compile_meta: + options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"] + if HAS_WARP_SPEC: + options.update( + { + "num_consumer_groups": compile_meta.get("num_consumer_groups", 0), + "num_buffers_warp_spec": compile_meta.get( + "num_buffers_warp_spec", 0 + ), + } + ) + if self.device_props.type == "cuda": + options.update( + { + "launch_cooperative_grid": compile_meta.get( + "launch_cooperative_grid", False + ), + "launch_pdl": compile_meta.get("launch_pdl", False), # True + } + ) + if self.device_props.type == "hip": + if "waves_per_eu" in compile_meta: + options["waves_per_eu"] = compile_meta["waves_per_eu"] + if "matrix_instr_nonkdim" in compile_meta: + options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"] + + return options + + def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: + """Ahead of time compile a given autotuner config.""" + compile_meta = self._create_compile_meta(cfg) + + if self.device_props.type == "cpu": + triton_helpers.set_driver_to_cpu() + else: + triton_helpers.set_driver_to_gpu() + + if not ASTSource: + raise RuntimeError("Installed triton version too old, please upgrade") + + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), + ) + + if self.device_props.type == "mtia": + from mtia.host_runtime.torch_mtia.acc_flags import ( # type: ignore[import-not-found] + build_codename, + ) + + arch = build_codename() + else: + arch = compile_meta["cc"] + + target = GPUTarget( + compile_meta["device_type"], + arch, + cc_warp_size(compile_meta["cc"]), + ) + + options = self._create_compile_options(cfg, compile_meta) + + compile_kwargs = { + "target": target, + "options": options, + } + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + + # Simulate JIT Hook call + if ( + torch._inductor.config.run_jit_post_compile_hook + and knobs + and getattr(knobs.runtime, "jit_post_compile_hook", None) + ): + try: + hook = knobs.runtime.jit_post_compile_hook + + # base args everyone should get + call_kwargs = dict( + key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)), + repr=getattr(self.fn, "src", None), + fn=self.fn, + compile=binary, + is_manual_warmup=False, + already_compiled=True, + ) + + # only add inductor_args if the hook takes it + sig = inspect.signature(hook) + params = sig.parameters + if "inductor_args" in params and "config_args" in self.inductor_meta: + call_kwargs["inductor_args"] = self.inductor_meta["config_args"] + + hook(**call_kwargs) + except Exception: + log.exception("jit_post_compile_hook failed") + + TritonBundler.put( + triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0) + ) + # If the binary has a cubin file to directly launch, save it on the binary + static_launcher = StaticTritonCompileResult.can_statically_launch( + binary, self.inductor_meta, self.triton_meta, self.heuristic_type + ) + + if static_launcher is not None: + result = StaticTritonCompileResult( + static_launcher, cfg, compile_meta, self.inductor_meta + ) + return result + + return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta) + + def bench(self, launcher, *args, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + # we don't skip configs with spilled registers when auto-tuning custom + # (user-written) Triton kernels, as (i) we don't have any knowledge or + # control over the kernel code; (ii) there is empirical evidence that + # for some (complicated) custom Triton kernels, a register-spilling + # config may yield the best latency. + if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs) + + def kernel_call(): + cloned_args, cloned_kwargs = self.maybe_clone_args( + cpu_copies, *args, **kwargs + ) + # reset to zero before evaluating any config + self.reset_to_zero_args(*args, **kwargs) + kernel_name = self.inductor_meta.get("kernel_name", "triton kernel") + if autograd_profiler._is_profiler_enabled: + profiler_kwargs = self.get_profiler_kwargs(stream, launcher) + with torch._C._profiler._RecordFunctionFast( + kernel_name, + cloned_args, + profiler_kwargs, + ): + try: + launcher( + *cloned_args, + **cloned_kwargs, + stream=stream, + ) + except Exception: + log.error("Failed during launch %s: ", kernel_name) + raise + + else: + try: + launcher( + *cloned_args, + **cloned_kwargs, + stream=stream, + ) + except Exception: + log.error("Failed during launch %s: ", kernel_name) + raise + self.restore_args_from_cpu(cpu_copies) + + # only use profiler when not already in a profiler instance + if with_profiler and not autograd_profiler._is_profiler_enabled: + from torch._inductor.utils import do_bench_using_profiling + + return do_bench_using_profiling(kernel_call, warmup=10, rep=40) + + benchmark_kwargs = ( + {} + if self.device_props.type == "cpu" + else {"rep": 40, "is_vetted_benchmarking": True} + ) + return benchmarker.benchmark( + fn=kernel_call, + device=self.device_props.type, + **benchmark_kwargs, # type: ignore[arg-type] + ) + + def copy_args_to_cpu_if_needed(self, *args, **kwargs): + """ + To support benchmarking in the presence of mutated args, we need to avoid + autotuning contanminating them. We try to pass cloned args to the kernel. + If those clones would increase the peak memory usage, however, we instead + copy to cpu and restore them after each iteration. Figure out the args + to be copied and do the copying. + """ + if not self.optimize_mem: + return {} + + copies = {} + try: + budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated() + except RuntimeError: + # Possibly a custom CUDA allocator, see https://github.com/pytorch/pytorch/issues/163257 + return {} + + def maybe_copy(name, arg): + if name in self.mutated_arg_names and arg.is_cuda: + nonlocal budget + assert isinstance(arg, torch.Tensor) + required_storage_length = compute_required_storage_length( + arg.size(), + arg.stride(), + 0, + ) + size = required_storage_length * arg.element_size() + if size > budget: + cpu_arg = torch.empty_strided( + (required_storage_length,), + (1,), + dtype=arg.dtype, + device="cpu", + pin_memory=True, + ) + cpu_arg.copy_( + arg.as_strided((required_storage_length,), (1,)), + non_blocking=True, + ) + copies[name] = (arg, cpu_arg) + else: + budget -= size + + for name, arg in zip(self.fn.arg_names, args): + maybe_copy(name, arg) + + for name, arg in kwargs.items(): + maybe_copy(name, arg) + + return copies + + def restore_args_from_cpu(self, cpu_copies): + for pair in cpu_copies.values(): + arg, cpu_arg = pair + required_storage_length = compute_required_storage_length( + arg.size(), + arg.stride(), + 0, + ) + arg.as_strided((required_storage_length,), (1,)).copy_( + cpu_arg, non_blocking=True + ) + + def reset_to_zero_args(self, *args, **kwargs): + if not self.reset_to_zero_arg_names: + return + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.reset_to_zero_arg_names: + assert isinstance( + arg, + torch.Tensor, + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) + arg.zero_() + + for name, arg in kwargs.items(): + if name in self.reset_to_zero_arg_names: + assert isinstance( + arg, + torch.Tensor, + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) + arg.zero_() + + def maybe_clone_args( + self, exclude: Container[str], *args, **kwargs + ) -> tuple[list[Any], dict[str, Any]]: + """ + Prepare new args and kwargs by cloning any in-place buffers + (that are not in the provided exclusion list), to avoid autotune + contaminating them. Avoid cloning the other buffers because it + leads to increased memory usage. + """ + from ..compile_fx import clone_preserve_strides + + def prepare_arg(name, arg): + if name in self.mutated_arg_names and name not in exclude: + assert isinstance(arg, torch.Tensor) + return clone_preserve_strides(arg) + else: + return arg + + cloned_args = [ + prepare_arg(name, arg) + for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args) + ] + cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()} + return cloned_args, cloned_kwargs + + def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: + return self.maybe_clone_args(OrderedSet(), *args, **kwargs) + + def benchmark_all_configs(self, *args, **kwargs): + with ( + dynamo_timed( + "CachingAutotuner.benchmark_all_configs", + log_pt2_compile_event=True, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ), + # Temporarily disable due to spam + # compilation_callback.callback_handler.install_callbacks( + # compilation_callback.CallbackTrigger.TRITON_AUTOTUNING, + # str(self.compile_id), + # ), + ): + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + + if metrics.is_metric_table_enabled("kernel_autotune"): + if self.fn.fn is None: + self.fn = self._reload_kernel().fn + + kernel_path = self.fn.fn.__code__.co_filename + kernel_name = self.fn.__name__ + + for k, v in timings.items(): + metrics.log_kernel_autotune_result( + kernel_path, kernel_name, k.config, v + ) + + self.reset_to_zero_args(*args, **kwargs) + return timings + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + start_time = time.time_ns() + timings = self.benchmark_all_configs(*args, **kwargs) + benchmark_time_taken_ns = time.time_ns() - start_time + self.launchers = [builtins.min(timings, key=timings.get)] + self.autotune_time_taken_ns = ( + self.precompile_time_taken_ns + benchmark_time_taken_ns + ) + + # log the best config + launcher = self.launchers[0] + log.debug( + "Best config for %s: %s: %f, nreg %d, nspill %d, #shared-mem %s", + self.fn.__name__, + launcher.config, + timings[launcher], + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + + if self.save_cache_hook: + self.save_cache_hook( + launcher.config, + self.autotune_time_taken_ns, + triton_cache_hash=launcher.cache_hash, + ) + + def save_gpu_kernel(self, stream, launcher): + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + assert key is not None, "kernel_name can not be None" + params = { + "mangled_name": ( + launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"] + ), + "num_warps": ( + launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps + ), + "shared_mem": ( + launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared + ), + "stream": stream, + # User defined triton kernels will have arbitrary kwarg names + "config": config_to_dict(launcher.config), + "inductor_meta": self.inductor_meta, + "triton_meta": self.triton_meta, + "def_args": launcher.def_args, + "call_args": launcher.call_args, + "global_scratch": launcher.global_scratch, + "profile_scratch": launcher.profile_scratch, + } + if self.device_props.type == "xpu": + # On the XPU backend, threads_per_warp is not always 32. + # For Intel GEMM Triton kernels, it can be 16. + # This information must be preserved so that the Cpp wrapper + # can launch the kernel with the correct configuration. + params["threads_per_warp"] = getattr( + launcher.bin.metadata, "threads_per_warp", 32 + ) + + from torch._inductor import config + from torch._inductor.codecache import CudaKernelParamCache + + bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") + binary = launcher.bin.asm[bin_type] + + # ROCm multi-arch: capture LLVM IR + if torch.version.hip and config.aot_inductor.emit_multi_arch_kernel: + # Multi-arch ROCm: Capture LLVM IR for cross-architecture compilation + asm_type = "ll" + + # llir is the key to obtain LLVM IR from triton + asm = launcher.bin.asm.get("llir", None) + + # CRITICAL: Multi-arch compilation cannot proceed without LLVM IR + # Fail fast with clear error message pointing to the issue + if not asm: + available_keys = list(launcher.bin.asm.keys()) + raise RuntimeError( + f"ROCm multi-arch requires LLVM IR, but none found. " + f"Available keys: {available_keys}. " + f"Triton may need to be patched to emit LLVM IR." + ) + + # Everything else: capture architecture-specific assembly + else: + asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get( + self.device_props.type, None + ) + asm = launcher.bin.asm.get(asm_type, None) + + CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type) + self.cuda_kernel_saved = True + + def coordinate_descent_tuning(self, launcher, *args, **kwargs): + """ + Coordinate descent tuning can be run with or without max-autotune. + + The only difference between these two is the starting config for coordinate_descent tuning. + E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4 + and max-autotune figure out C3 is the best. + + Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1; + while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. + """ + if self.heuristic_type in ( + HeuristicType.TEMPLATE, + HeuristicType.USER_AUTOTUNE, + HeuristicType.FIXED, + ): + # skip triton template + return launcher + + if self.deterministic_mode and self.heuristic_type in ( + HeuristicType.REDUCTION, + HeuristicType.PERSISTENT_REDUCTION, + HeuristicType.SPLIT_SCAN, + ): + # Not only RBLOCK size matters for numericals of reduction. + # num_warps also matters since that affect how much data + # is handled by each thread, how many warp-reduction we do + # in parallel and how much data is there for block + # reduction. + return launcher + + with dynamo_timed( + "CachingAutotuner.coordinate_descent_tuning", + # These generate too many pt2_compile_event logs: + log_pt2_compile_event=False, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ): + return self._coordinate_descent_tuning(launcher, *args, **kwargs) + + def _coordinate_descent_tuning(self, launcher, *args, **kwargs): + config2launcher = {launcher.config: launcher} + + # TODO: should we just load the kernels ahead of time if we know we're going to call this? + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + + def benchmark_one_config(config): + with self.lock: + launcher = self._precompile_config(config).make_launcher() + config2launcher[config] = launcher + + out = self.bench(launcher, *args, **kwargs) + counters["inductor"]["coordesc_tuning_bench"] += 1 + log.debug( + "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d", + launcher.config, + out, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + return out + + assert not ( + self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION + and "R0_BLOCK" in launcher.config.kwargs + ), ( + "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK" + ) + start_time = time.time_ns() + best_config = self.coordesc_tuner.autotune( + benchmark_one_config, launcher.config, None + ) + coordesc_time_taken_ns = time.time_ns() - start_time + best_config.found_by_coordesc = True + + if self.save_cache_hook: + self.save_cache_hook( + best_config, + self.autotune_time_taken_ns + coordesc_time_taken_ns, + found_by_coordesc=True, + ) + + if best_config not in config2launcher: + # On a Coordesc cache hit, we might not have loaded the launcher + # This can happen because PyCodeCache saves CachingAutotuners in memory, + # even for separate compile IDs (which can have different inputs without changing output code) + config2launcher[best_config] = self._precompile_config( + best_config + ).make_launcher() + + fn_hash = generate_lookup_hash_from_source_code( + str(self.size_hints), self.fn.src + ) + log.debug("Function hash %s has best config %s", fn_hash, best_config) + return config2launcher[best_config] + + def get_profiler_kwargs(self, stream, launcher): + kernel_kwargs_str = ",".join( + f"{k}={v}" for (k, v) in launcher.config.kwargs.items() + ) + + ret = { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "stream": stream, + "num_warps": launcher.config.num_warps, + "num_stages": launcher.config.num_stages, + "kernel_kwargs": kernel_kwargs_str, + } + if "kernel_name" in self.inductor_meta: + ret["kernel_name"] = self.inductor_meta["kernel_name"] + if "kernel_flop" in self.inductor_meta: + ret["kernel_flop"] = self.inductor_meta["kernel_flop"] + if "kernel_num_gb" in self.inductor_meta: + ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"] + return ret + + def run( + self, + *args, + stream, + benchmark_run=False, + **kwargs, + ): # type:ignore[override] + """Launch triton kernel call and return result.""" + debug_mode = get_active_debug_mode() + debug_call = None + if debug_mode: + arg_names = list(self.triton_meta.get("signature", {}).keys()) + kernel_kwargs = dict(zip(arg_names, args)) + kernel_kwargs.update(kwargs) + debug_call = debug_mode.record_triton_kernel( + kernel_name=self.fn.__name__, kwargs=kernel_kwargs + ) + + if hasattr(triton, "set_allocator"): + + def alloc_fn(size: int, align: int, stream: int | None): + return torch.empty( + size, dtype=torch.int8, device=self.device_props.type + ) + + triton.set_allocator(alloc_fn) + + if self.triton_interpret: + args, grid = self._interpret_args_grid(args, self.configs[0]) + return self.fn[grid]( + *args, + **kwargs, + **self.configs[0].kwargs, + ) + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + + if not getattr( + self.launchers[0].config, "found_by_coordesc", False + ) and self.inductor_meta.get("coordinate_descent_tuning", False): + self.launchers = [ + self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs) + ] + + (launcher,) = self.launchers + if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): + self.save_gpu_kernel(stream, launcher) + + # PyTorch execution trace replay calls CachingAutotuner::run() instead of calls launcher + # so _RecordFunctionFast need to capture the args into CachingAutotuner::run() + # make a copy here to avoid mutating the original args + args_without_constexprs = tuple(args) + + if self.dump_launch_params: + new_args, grid = self._interpret_args_grid(args, launcher.config) + _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) + + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. + if autograd_profiler._is_profiler_enabled: + profiler_kwargs = self.get_profiler_kwargs(stream, launcher) + + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args_without_constexprs, + profiler_kwargs, + ): + result = launcher( + *args, + **kwargs, + stream=stream, + ) + else: + result = launcher( + *args, + **kwargs, + stream=stream, + ) + + if debug_call: + debug_call.finalize(self.get_device_interface()) + return result + + def _interpret_args_grid( + self, args: tuple[Any, ...], cfg: Config + ) -> tuple[tuple[Any, ...], tuple[int, int, int]]: + if triton_version_uses_attrs_dict(): + + def filtered_signature() -> list[str]: + # constexprs are not passed in as args + new_signature: list[str] = [] + from triton.runtime.interpreter import InterpretedFunction + + for i, x in enumerate(self.triton_meta["signature"].keys()): + if isinstance(self.fn, InterpretedFunction): + # These are torch compiled triton kernels that definitely + # have block size configs. Dynamo does not currently + # trace user defined triton kernels when TRITON_INTERPRET=1 + if x not in cfg.kwargs: + new_signature.append(x) + elif i not in get_constexprs(self.fn): + # use constexprs rather than just configs since user + # defined triton kernels may not have any configs + new_signature.append(x) + + return new_signature + + else: + + def filtered_signature() -> list[str]: + return list(self.triton_meta["signature"].keys()) + + grid = GridExpr.from_meta( + self.inductor_meta, cfg, mode=self.grid_mode + ).eval_slow( + dict( + zip( + [ + *filtered_signature(), + *self.inductor_meta.get("extra_launcher_args", ()), + ], + args, + ) + ) + ) + if self.inductor_meta.get("extra_launcher_args"): + args = args[: -len(self.inductor_meta["extra_launcher_args"])] + return args, grid + + +class _ConstRepr: + def __init__(self, value: str): + self.value = value + + def __call__(self, _=None) -> str: + return self.value + + +class CompileResult(Generic[_T]): + """ + Base class representing compiled result. + """ + + def __init__( + self, + kernel: _T, + config: Config, + compile_meta: dict[str, Any], + inductor_meta: dict[str, Any], + ): + self.kernel = kernel + self.config = config + self.compile_meta = compile_meta + self.inductor_meta = inductor_meta + + def make_launcher(self) -> LauncherType: ... + + def _gen_launcher_code(self, scope, def_args, runner_args) -> LauncherType: + grid = GridExpr.from_meta(self.inductor_meta, self.config) + # grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)` + lines = [ + f"def launcher({', '.join(def_args)}, stream):", + *[f" {line}" for line in grid.prefix], + f" grid_0 = {grid.x_grid}", + f" grid_1 = {grid.y_grid}", + f" grid_2 = {grid.z_grid}", + f" runner({', '.join(runner_args)})", + ] + launcher_code = "\n".join(lines) + exec(launcher_code, scope) + return scope["launcher"] + + def _get_arg_lists( + self, arg_names, constexprs + ) -> tuple[list[str], list[str], OrderedSet[str]]: + """ + Return a bunch of intermediate lists of args needed for generating + launcher code. + """ + compile_meta = self.compile_meta + cfg = self.config + known_constants = OrderedSet( + arg for i, arg in enumerate(arg_names) if i in constexprs + ) + + """ + https://github.com/pytorch/pytorch/issues/115344 + + self.fn.constexprs doesn't properly deal with None args, so when we filter out + an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well. + We also don't want to modify self.fn. + + We know that we removed something from the signature if: + 1. It's in compile_meta["constants"] + 2. It isn't a constant we already know about + Note: The value of interest has already been added to compile_meta['constants'], + so we use self.fn.constexprs instead. + 3. It isn't in the compile_meta signature + """ + none_args = OrderedSet( + k + for k, v in compile_meta["constants"].items() + if v is None and k not in known_constants + ) + none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) + + def _convert_constant(constant): + if isinstance(constant, str): + return "r'" + constant + "'" + else: + return repr(constant) + + if triton_version_uses_attrs_dict(): + call_args = arg_names + def_args = arg_names + implicit_constants = OrderedSet( + ( + "num_warps", + "num_stages", + ) + ).union(OrderedSet(k for k in known_constants)) + if implicit_constants := implicit_constants & OrderedSet( + compile_meta["constants"].keys() + ): + # num_warps/num_stages are special implicit args that are not in the signature + # see test_triton_kernel_special_params + def_args = [arg for arg in def_args if arg not in implicit_constants] + repl = { + k: _convert_constant(compile_meta["constants"].get(k)) + for k in implicit_constants + } + call_args = [repl.get(arg, arg) for arg in call_args] + else: + call_args = [ + arg + for i, arg in enumerate(arg_names) + if i not in constexprs and arg not in none_args + ] + cfg_dict = config_to_dict(cfg) + def_args = [ + name + for name in arg_names + if name not in cfg_dict and name not in none_args + ] + + if "extra_launcher_args" in self.inductor_meta: + def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]] + + return call_args, def_args, none_args + + +class CannotStaticallyLaunchKernel(Exception): + pass + + +class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]): + """ + TritonCompileResult that uses StaticCudaLauncher, + which vastly simplifies the setup and metadata needed to be kept. + """ + + @staticmethod + def can_statically_launch( + kernel: CompiledKernel, + inductor_meta: dict[str, Any], + triton_meta: dict[str, Any], + heuristic_type: HeuristicType, + ) -> StaticallyLaunchedCudaKernel | None: + if not torch._inductor.config.use_static_cuda_launcher: + return None + + def check_can_launch() -> StaticallyLaunchedCudaKernel: + if triton_meta.get("device_type") != "cuda": + # Only cuda kernels + raise CannotStaticallyLaunchKernel("Non-cuda device") + + if torch._inductor.config.cpp_wrapper: + # If we're running with cpp wrapper, it doesn't + # make sense to statically compile since everything + # is codegenned anyway + raise CannotStaticallyLaunchKernel("Cpp wrapper enabled") + + if ( + heuristic_type == HeuristicType.USER_AUTOTUNE + and not torch._inductor.config.static_launch_user_defined_triton_kernels + ): + # Don't support user defined triton kernels yet + raise CannotStaticallyLaunchKernel("User defined triton kernel") + + if inductor_meta.get("store_cubin"): + # Requires storing the entire binary + raise CannotStaticallyLaunchKernel("store_cubin is enabled") + + if getattr(kernel.metadata, "launch_pdl", False) or getattr( + kernel.metadata, "launch_cooperative_grid", False + ): + raise CannotStaticallyLaunchKernel( + "static launch does not support launch attributes" + ) + + cubin_location = os.path.join( + triton_cache_dir(triton_meta.get("device", 0)), + triton_hash_to_path_key(kernel.hash), + f"{kernel.src.fn.__name__}.cubin", + ) + + if not os.path.exists(cubin_location): + raise CannotStaticallyLaunchKernel( + f"Cubin path not found: {cubin_location}" + ) + + else: + kernel._cubin_path = cubin_location + + try: + static_kernel = StaticallyLaunchedCudaKernel(kernel) + except NotImplementedError as e: + raise CannotStaticallyLaunchKernel(f"NotImplemented: {str(e)}") from e + + return static_kernel + + try: + result = check_can_launch() + return result + except CannotStaticallyLaunchKernel as e: + log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e)) # noqa: G200 + if torch._inductor.config.strict_static_cuda_launcher: + raise e + return None + + def reload_cubin_path(self): + """ + When loading from cache on disk, we want to reload cubin + files from their appropriate location on disc. + """ + cubin_location = os.path.join( + triton_cache_dir(self.compile_meta.get("device", 0)), + triton_hash_to_path_key(self.kernel.hash), + f"{self.kernel.name}.cubin", + ) + if not os.path.exists(cubin_location): + if self.kernel.cubin_raw is not None: + # We saved the raw cubin, so write it to he appropriate location + self.kernel.reload_cubin_from_raw(cubin_location) + else: + raise RuntimeError( + "Cubin file saved by TritonBundler not found at %s", cubin_location + ) + self.kernel.cubin_path = cubin_location + + def make_launcher(self) -> LauncherType: + # If at least one static make_launcher call occurs, + # we're sure static cuda launcher was used for this compile + set_feature_use("static_cuda_launcher", True) + # Load the binary on the parent + if not self.kernel.cubin_path: + self.reload_cubin_path() + device = self.compile_meta.get("device", 0) + if device is None: + device = 0 + self.kernel.load_kernel(device) + scope = { + "runner": self.kernel.run, + } + + # NOTE: Constexpr handling for triton and static cuda launcher + + # Triton kernels have two types of constexprs: *declared* ones, which are ones the user + # has explicitly declared as tl.constexpr, and *implied* ones, which are expressions triton + # deems constant while compiling/analyzing the code (i.e. unused parameters, for example) + + # Triton kernels handle constexprs slightly differently depending on which version of triton + # we care about (we support 3.2.0 and 3.3.0). + + # In 3.2.0, triton kernels do not require passing any declared constexprs into the kernel + # In 3.3.0, triton kernels require all declared constexprs be passed into the kernel, where + # they are subsequently ignored. + # When statically launching, since we're launching from the triton generated cubin, we actually want to + # always get rid of all const exprs, declared or implied, since the underlying cubin file has all + # of the constants stripped away anyway. + + # But CachingAutotuner.run will pass us a different number of arguments depending on + # whether or not we're in triton 3.2.0 or later, so we grab def_args with the same logic + # as the (non static) TritonCompileResult. We then generate call_args ourselves, since we + # want only a subset of the arguments passed to triton. + # Here, arg_names is exactly fn.src.arg_names and declared_constexprs is exactly fn.src.constexprs, + # which matches behavior with regular TritonCompileResult + _, def_args, none_args = self._get_arg_lists( + self.kernel.arg_names, self.kernel.declared_constexprs + ) + + call_args = [ + arg + for i, arg in enumerate(self.kernel.arg_names) + if i not in self.kernel.full_constexprs and arg not in none_args + ] + + # StaticallyLaunchedCudaKernel.run takes in order grid_0, grid_1, grid_2, stream, and call_args + runner_args = ["grid_0", "grid_1", "grid_2", "stream", *call_args] + launcher = self._gen_launcher_code(scope, def_args, runner_args) + launcher.config = self.config # type: ignore[attr-defined] + launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined] + launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined] + launcher.shared = self.kernel.shared # type: ignore[attr-defined] + launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined] + launcher.store_cubin = False # type: ignore[attr-defined] + launcher._is_static = True # type: ignore[attr-defined] + return launcher + + +class TritonCompileResult(CompileResult[CompiledKernel]): + """ + Upstream Triton CompileKernel can not be pickled. This is a wrapper + to support serialization and generate the launcher function. + """ + + @staticmethod + @functools.lru_cache(32) + def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any: + return namedtuple("KernelMetadata", sorted(fields)) + + @staticmethod + def _serialize_metadata(metadata): + """ + Triton uses a nested class called KernelMetadata to store metadata information. + Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear + in the toplevel namespace of the module. So these serialization/deser functions + are used to convert the namedtuples to a dict and back. + + As for packed_metadata, depending on the triton backend, KernelMetadata can be + a namedtuple, or a regular tuple! So the serialization function branches on whether + the metadata to be serialized is a namedtuple or regular, serializable one. + """ + + def is_namedtuple(obj) -> bool: + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) + + if is_namedtuple(metadata): + return metadata._asdict() + else: + return metadata + + @staticmethod + def _deserialize_metadata(metadata): + if isinstance(metadata, dict): + return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))( + **metadata + ) + else: + return metadata + + def __getstate__(self) -> dict[str, Any]: + kernel = self.kernel + # replace the fields that don't pickle nicely + kernel_state = { + **kernel.__dict__, + # See doc about serializing metadata above + "metadata": self._serialize_metadata(kernel.metadata), + "packed_metadata": self._serialize_metadata( + getattr(kernel, "packed_metadata", None) + ), + "module": None, # regenerated by kernel._init_handles() + "function": None, # regenerated by kernel._init_handles() + "run": None, # regenerated by kernel._init_handles() + } + return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item] + + def __setstate__(self, state: dict[str, Any]) -> None: + # src = ASTSource.__new__(ASTSource) + # src.__setstate__(state["kernel"]["src"]) + # TODO(jansel): need to fixup src.fn which is now None + kernel = CompiledKernel.__new__(CompiledKernel) + metadata = state["kernel"]["metadata"] + packed_metadata = state["kernel"]["packed_metadata"] + kernel.__dict__.update( + { + **state["kernel"], + # "src": src, + "metadata": self._deserialize_metadata(metadata), + "packed_metadata": self._deserialize_metadata(packed_metadata), + } + ) + self.__dict__.update(state) + self.kernel = kernel + + def make_launcher(self) -> LauncherType: + """ + Launching triton kernels is performance sensitive, we compile + a custom Python function get the grid() and reorder the args to + the underlying wrapper. + """ + cfg = self.config + compile_meta = self.compile_meta + binary = self.kernel + fn = binary.src.fn + binary._init_handles() + (call_args, def_args, none_args) = self._get_arg_lists( + fn.arg_names, get_constexprs(fn) + ) + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + if knobs is None: + launch_enter = binary.__class__.launch_enter_hook + launch_exit = binary.__class__.launch_exit_hook + else: + launch_enter = knobs.runtime.launch_enter_hook + launch_exit = knobs.runtime.launch_exit_hook + + import math as math_lib + + import triton as triton_lib + + import torch as torch_lib + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": launch_enter, + "launch_exit_hook": launch_exit, + "metadata": ( + binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata + ), + "shared": binary_shared, + "num_warps": ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ), + "cta_args": ( + ( + binary.num_ctas, + *get_first_attr(binary, "cluster_dims", "clusterDims"), + ) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + and hasattr(binary.metadata, "num_ctas") + and hasattr(binary.metadata, "cluster_dims") + else () + ) + ), + "function": get_first_attr(binary, "function", "cu_function"), + "runner": get_first_attr(binary, "run", "c_wrapper"), + "math": math_lib, + "torch": torch_lib, + "triton": triton_lib, + } + + if not hasattr(binary, "launch_metadata"): + # launch args before CompiledKernel.launch_metadata is added. + # TODO(jansel): delete this branch in mid-2025 + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "num_warps", + "*cta_args", + "shared", + "stream", + "function", + "launch_enter_hook", + "launch_exit_hook", + "metadata", + *call_args, + ] + else: # args after CompiledKernel.launch_metadata: https://github.com/triton-lang/triton/pull/3492 + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if launch_enter: + launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})" + else: + launch_metadata = "None" + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "stream", + "function", + "metadata", + launch_metadata, + "launch_enter_hook", + "launch_exit_hook", + *call_args, + ] + + launcher = self._gen_launcher_code(scope, def_args, runner_args) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.cache_hash = triton_hash_to_path_key(binary.hash) + launcher.store_cubin = self.inductor_meta.get("store_cubin", False) + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = fn + launcher.bin = binary + if triton_version_uses_attrs_dict(): + # arg filtering wasn't done above + cfg_dict = config_to_dict(cfg) + def_args = [x for x in def_args if x not in cfg_dict] + call_args = [ + x + for x in call_args + if compile_meta["signature"].get(x, "constexpr") != "constexpr" + and x not in none_args + ] + launcher.def_args = def_args + launcher.call_args = call_args + kernel_metadata = getattr(self.kernel, "metadata", None) + + # for the scratch arguments: None indicates that the kernel doesn't + # take any scratch argument; otherwise a number indicates the number + # of bytes of scratch that need to be provided. + + # in AMD's Triton backend, the global scratch size is never provided + # (but for AMD it's safe to pass an extra null arg, so always include it) + global_scratch: int | None = getattr( + kernel_metadata, + "global_scratch_size", + (0 if torch.version.hip else None), + ) + profile_scratch: int | None = getattr( + kernel_metadata, "profile_scratch_size", None + ) + launcher.global_scratch = global_scratch + launcher.profile_scratch = profile_scratch + return launcher + + +def _find_names(obj): + import gc + import inspect + + frame = inspect.currentframe() + while frame is not None: + frame.f_locals + frame = frame.f_back + obj_names = [] + for referrer in gc.get_referrers(obj): + if isinstance(referrer, dict): + for k, v in referrer.items(): + if v is obj: + obj_names.append(k) + return obj_names + + +collected_calls: list[Any] = [] + + +def start_graph(): + collected_calls.clear() + + +def end_graph(output_file): + if len(collected_calls) == 0: + return + overall_time = sum(call[0] for call in collected_calls) + overall_gb = sum(call[1] for call in collected_calls) + cur_file = inspect.stack()[1].filename + summary_str = ( + f"SUMMARY ({cur_file})\n" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s" + ) + log.info( + "%s", + summary_str, + ) + if output_file is not None: + # sort perf numbers in descending order, i.e. placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.info( + "Save profile bandwidth results to %s", + output_file, + ) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms / overall_time * 100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, + num_gb, + gb_per_s, + suffix=suffix, + color=False, + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception: + log.warning( + "failed to write profile bandwidth result into %s", + output_file, + exc_info=True, + ) + + +class DebugAutotuner(CachingAutotuner): + def __init__( + self, + *args, + regex_filter="", + with_profiler=False, + with_bandwidth_info=True, + **kwargs, + ): + self.regex_filter = regex_filter + self.with_profiler = with_profiler + self.with_bandwidth_info = with_bandwidth_info + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, stream, **kwargs): + if not self.with_bandwidth_info: + super().run(*args, stream=stream, **kwargs, benchmark_run=True) + return + else: + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + (launcher,) = self.launchers + + if launcher.store_cubin: + self.save_gpu_kernel(stream, launcher) + + if self.cached is None: + ms = self.bench(launcher, *args, with_profiler=self.with_profiler) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = self.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = ms, num_gb, gb_per_s, kernel_name + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + log.info( + "%s", + create_bandwidth_info_str( + ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}" + ), + ) + else: + # in AOTI, we will call the kernel and its timing info has been cached already + collected_calls.append(self.cached) + + +def hash_configs(configs: list[Config]): + """ + Hash used to check for changes in configurations + """ + hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode() + ) + return hasher.hexdigest() + + +def cached_autotune( + size_hints: list[int] | None, + configs: list[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. + """ + configs = unique_configs(configs) + assert len(configs) == 1 or filename + inductor_meta = {} if inductor_meta is None else inductor_meta + + configs, autotune_cache, autotune_cache_info = check_autotune_cache( + configs, filename, inductor_meta + ) + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + optimize_mem = inductor_meta.pop("optimize_mem", True) + + if "restore_value" in triton_meta: + mutated_arg_names += triton_meta.pop("restore_value") + + reset_to_zero_arg_names: list[str] = [] + if "reset_to_zero" in triton_meta: + reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + assert tconfig.kwargs["XBLOCK"] == 1 + tconfig.kwargs.pop("XBLOCK") + + if inductor_meta.get("profile_bandwidth"): + return DebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + with_bandwidth_info=True, + ) + return CachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + autotune_cache_info=autotune_cache_info, + ) + + return decorator + + +def unique_configs(configs: list[Config]): + """Remove duplicate configurations""" + seen: OrderedSet[Hashable] = OrderedSet() + pruned_configs = [] + + for cfg in configs: + key = triton_config_to_hashable(cfg) + if key not in seen: + seen.add(key) + pruned_configs.append(cfg) + return pruned_configs + + +def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): + for numel, label in zip((xnumel, ynumel, znumel), "XYZ"): + if numel is None: + continue + block = cfg[f"{label}BLOCK"] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1" + f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})." + ) + max_block = TRITON_MAX_BLOCK[label] + max_block_str = f'config.triton.max_block["{label}"]' + assert max_block % block == 0, ( + f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}" + f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})." + ) + + +def check_max_block(cfg: dict[str, int]): + """ + Check that block sizes are within the maximum allowed. + """ + for var, val in cfg.items(): + block_suffix = "BLOCK" + if block_suffix in var: + prefix = var.removesuffix(block_suffix) + max_block = TRITON_MAX_BLOCK[prefix] + assert val <= max_block, ( + f"'{var}' too large. Maximum: {max_block}. Actual: {val}." + ) + + +def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): + # On AMD GPU each warp has 64 lanes which is double the size on NV GPU, + # therefore using half the number of warps here correspondingly. + if torch.version.hip: + max_num_warps = (max_num_warps + 1) // 2 + min_num_warps = (min_num_warps + 1) // 2 + # persistent reduction is register intensive + if register_intensive: + max_num_warps = max_num_warps // 2 + return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps)) + + +def _check_max_grid_x(size_hints, x, num_warps): + # Check if maxGridSize is exceeded - if so then must scale XBLOCK further + max_grid_x = 2147483647 + warp_size = ( + 64 if torch.version.hip else 32 + ) # TODO: query warp size once #129663 is merged + num_blocks = (size_hints["x"] + x - 1) // x + + while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]: + x *= 2 # Scale up XBLOCK if grid exceeds limits + num_blocks = num_blocks // 2 + if (num_blocks * num_warps * warp_size) > max_grid_x: + raise AssertionError( + "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" + ) + return x, num_blocks + + +def triton_config( + size_hints, + x, + y=None, + z=None, + num_stages=1, + num_elements_per_warp=256, + min_elem_per_thread=0, + num_warps=None, + matrix_instr=None, + waves_per_eu=None, +) -> Config: + """ + Construct a pointwise triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + + num_elements_per_warp is a suggestion for controlling how many warps + the triton config should contain. e.g.: if x=16, y=8, z=4 then + num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128, + we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's + just a suggestion, and sometimes other adjustment heuristics will + override the num_elements_per_warp. + + min_elem_per_thread controls the minimum number of elements + processed by each thread. It's always enforced. + """ + # Ideally we want to read this from some device config + + maxGridSize = [2147483647, 65535, 65535] + + target = conditional_product(x, y, z) + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + if y: + y = min(y, size_hints["y"]) + if z: + z = min(z, size_hints["z"]) + + # if we are below original block size, scale up where we can; + # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension + while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and ( + x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target + ): + x *= 2 + while ( + y + and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"]) + and ( + y * maxGridSize[1] < size_hints["y"] + or conditional_product(x, y, z) < target + ) + ): + y *= 2 + while ( + z + and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"]) + and ( + z * maxGridSize[2] < size_hints["z"] + or conditional_product(x, y, z) < target + ) + ): + z *= 2 + + # Calculate num_warps if they are not hard passed to config + if num_warps is None: + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) + # we are going to arrive at 2 warps only if bs was too small due to + # numel being too small. However to workaround some ptx bugs we still + # want at least 4 warps if there's enough elements per thread + # given that this is a rare situation, don't expect this to affect perf + # in general + # see https://github.com/pytorch/pytorch/pull/97950 + if conditional_product(x, y, z) >= 128 and not torch.version.hip: + num_warps = max(num_warps, 4) + xnumel = size_hints["x"] + ynumel = size_hints.get("y") + znumel = size_hints.get("z") + + # Increase x to satisfy min_elem_per_thread requirements. + block_size = max( + conditional_product(x, y, z), + min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, + ) + x *= math.ceil(block_size / conditional_product(x, y, z)) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + x = min(x, size_hints["x"]) + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + check_max_block(cfg) + check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + + if torch.version.hip: + if matrix_instr is not None: + config.kwargs["matrix_instr_nonkdim"] = matrix_instr + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config + + +def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: + """ + Converts a linear reduction numel to ND, in row major order. + This order is often desirable as it presents opportunities to coalesce memory + accesses. + For example, if r = 64 and size_hints = [32,32], this function returns [32, 2]. + This unraveling works because both r and size_hints are powers of 2. + """ + # Shrink r to size_hints. + r = min(r, get_total_reduction_numel(size_hints)) + num_reduction_dims = len( + [prefix for prefix in size_hints if prefix_is_reduction(prefix)] + ) + + remaining = r + rnumels = {} + for idx in range(num_reduction_dims - 1, -1, -1): + prefix = f"r{idx}_" + max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()]) + dim = min(max_size, remaining) + assert remaining % dim == 0, ( + f"Expected dimension '{dim}' to divide remaining size '{remaining}'" + ) + rnumels[prefix] = dim + remaining //= dim + + # Sanity check the results. + final_numel = conditional_product(*rnumels.values()) + assert r == final_numel, ( + f"Expected ND reduction size ({rnumels}) to have {r} elements." + ) + assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), ( + f"rnumels exceed size_hints. {rnumels} > {size_hints}" + ) + + return rnumels + + +def triton_config_reduction( + size_hints, + x: int, + r: int, + num_stages=1, + num_warps=None, + register_intensive=False, + dynamic_scale_rblock=True, + reduction_hint=None, + min_num_warps=None, +) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + # Convert the linear reduction numel into a multi-dimensional block. + rnumels = _get_nd_reduction_numels(r, size_hints) + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + + def total_numel() -> int: + return conditional_product(x, *rnumels.values()) + + target = total_numel() + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # if we are below original block size, scale up where we can + while x < size_hints["x"] and total_numel() < target: + x *= 2 + for prefix in sorted(rnumels): + while rnumels[prefix] < size_hints[prefix] and total_numel() < target: + rnumels[prefix] *= 2 + + if num_warps is None: + if reduction_hint == ReductionHint.INNER: + # r is contiguous, ensure at least 8 elements per thread + # xblock is usually 1-2, default to giving each thread more work + num_warps = r // 128 + else: + num_warps = total_numel() // 128 + + max_num_warps = 16 if r <= 8192 else 32 + if min_num_warps is not None: + _num_warps_func = functools.partial(_num_warps, min_num_warps=min_num_warps) + else: + _num_warps_func = _num_warps + + num_warps = _num_warps_func( + num_warps, max_num_warps=max_num_warps, register_intensive=register_intensive + ) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + + for prefix in sorted(rnumels): + while total_numel() > target: + if rnumels[prefix] == 1: + break + rnumels[prefix] //= 2 + + cfg = _get_config({"x": x, **rnumels}) + check_max_block(cfg) + check_config(cfg, xnumel=size_hints["x"]) + return InductorConfig( + cfg, + num_warps=num_warps, + num_stages=num_stages, + dynamic_scale_rblock=dynamic_scale_rblock, + ) + + +def _get_config(numels: dict[str, int]) -> dict[str, int]: + """ + Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc. + """ + + return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()} + + +def triton_config_tiled_reduction( + size_hints, x, y, r, num_stages=1, register_intensive=False +): + """ + Construct a tile reduction triton config with some adjustment + heuristics based on size_hints. Size_hints is a tuple of numels in + each tile dimension and will be rounded up to the nearest power of 2. + """ + # Convert the linear reduction numel into a multi-dimensional block. + rnumels = _get_nd_reduction_numels(r, size_hints) + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + y = min(y, size_hints["y"]) + + def total_numel() -> int: + return conditional_product(x, y, *rnumels.values()) + + target = total_numel() + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # if we are below original block size, scale up where we can + while x < size_hints["x"] and total_numel() < target: + x *= 2 + for prefix in sorted(rnumels): + while rnumels[prefix] < size_hints[prefix] and total_numel() < target: + rnumels[prefix] *= 2 + while y < size_hints["y"] and total_numel() < target: + y *= 2 + + cfg = _get_config({"x": x, "y": y, **rnumels}) + num_warps = _num_warps(total_numel() // 256, min_num_warps=1) + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) + check_max_block(cfg) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]): + tma_min_block_sizes: dict[str, int] + if (tma_min_block_sizes := inductor_meta.get("tma_min_block_sizes")) and configs: + # Rn blocks are not provided to the kernel for persistent reductions + if inductor_meta.get("persistent_reduction"): + tma_min_block_sizes = { + block_type: block_size + for block_type, block_size in tma_min_block_sizes.items() + if not prefix_is_reduction(block_type.lower()) + } + + assert all( + block_type in configs[0].kwargs for block_type in tma_min_block_sizes + ) + + # Add a config that is guaranteed to compile + example_config = configs[0] + config_block_sizes = {**example_config.kwargs} + config_block_sizes.update(tma_min_block_sizes) + new_configs = [ + Config( + config_block_sizes, + num_warps=example_config.num_warps, + num_stages=example_config.num_stages, + maxnreg=example_config.maxnreg, + pre_hook=example_config.pre_hook, + ) + ] + # Remove configs that will not compile + for c in configs: + if all( + c.kwargs.get(block_type) >= min_block_value + for block_type, min_block_value in tma_min_block_sizes.items() + ): + new_configs.append(c) + + log.debug( + "Filtering configs for TMA API restrictions. Input configs size: %d. Output configs size: %d", + len(configs), + len(new_configs), + ) + return new_configs + return configs + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. + """ + inductor_meta = {} if inductor_meta is None else inductor_meta + assert not inductor_meta.get("no_x_dim") + + numel = functools.reduce(operator.mul, size_hints.values()) + bs = max(256, min(numel // 128, 1024)) + + hinted_configs = autotune_hints_to_configs( + inductor_meta.get("autotune_hints", OrderedSet()), + size_hints, + bs, + triton_meta["device"], + ) + + triton_config_with_settings = functools.partial( + triton_config, min_elem_per_thread=min_elem_per_thread + ) + + configs = None + if len(size_hints) == 1: + if not inductor_meta.get("autotune_pointwise", True) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + configs = [triton_config_with_settings(size_hints, bs)] + else: + configs = [ + triton_config_with_settings(size_hints, bs, num_elements_per_warp=256), + triton_config_with_settings( + size_hints, bs // 2, num_elements_per_warp=64 + ), + *hinted_configs, + ] + # Additional configs appended for ROCm builds + if torch.version.hip: + configs.extend( + [ + triton_config_with_settings( + size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 + ), + triton_config_with_settings( + size_hints, + 4096, # wrt: better than the max_block for some kernel + ), + triton_config_with_settings( + size_hints, + 2048, + num_warps=8, + num_stages=2, + waves_per_eu=1, # 20% improvement + ), + ] + ) + if inductor_meta.get("atomic_add_found"): + configs.extend( + [ + triton_config_with_settings( + size_hints, + 64, + num_warps=1, + num_stages=1, # 250% improvement + ) + ] + ) + if len(size_hints) == 2: + # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds + # ROCm has observed improvement by diverging here + if ( + not inductor_meta.get("autotune_pointwise", True) + or (torch.version.hip is None and tile_hint == TileHint.SQUARE) + ) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + configs = [triton_config_with_settings(size_hints, 32, 32)] + else: + configs = [ + triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, bs, 1), + triton_config_with_settings(size_hints, 1, bs), + *hinted_configs, + ] + # Additional configs appended for ROCm builds + if torch.version.hip: + configs.extend( + [ + triton_config_with_settings( + size_hints, 64, 32 + ), # better for some kernels + triton_config_with_settings( + size_hints, 128, 16 + ), # +10% for some kernels + triton_config_with_settings( + size_hints, 128, 32 + ), # additional 10% more + triton_config_with_settings( + size_hints, 32, 512 + ), # +30% for some kernels + ] + ) + if len(size_hints) == 3: + if not inductor_meta.get("autotune_pointwise", True): + configs = [triton_config_with_settings(size_hints, 16, 16, 16)] + else: + configs = [ + triton_config_with_settings(size_hints, 16, 16, 16), + triton_config_with_settings(size_hints, 64, 8, 8), + triton_config_with_settings(size_hints, 8, 64, 8), + triton_config_with_settings(size_hints, 8, 8, 64), + triton_config_with_settings(size_hints, bs, 1, 1), + triton_config_with_settings(size_hints, 1, bs, 1), + triton_config_with_settings(size_hints, 1, 1, bs), + *hinted_configs, + ] + + if not configs: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def make_matmul_triton_config(sizes: dict[str, int], num_warps: int, num_stages: int): + config = { + "XBLOCK": sizes.get("x"), + "YBLOCK": sizes.get("y"), + "ZBLOCK": sizes.get("z"), + "R0_BLOCK": sizes.get("r"), + } + # Remove keys with None values (i.e., missing in sizes) + config = {k: v for k, v in config.items() if v is not None} + return Config(config, num_warps=num_warps, num_stages=num_stages) + + +def _config_helper(bmm=False, persistent=False): + # Each entry is: (sizes_dict, num_warps, num_stages) + _base_mm_configs = [ + ({"x": 32, "y": 32, "r": 16}, 2, 1), + ({"x": 32, "y": 32, "r": 128}, 4, 2), + ({"x": 32, "y": 64, "r": 32}, 8, 5), + ({"x": 64, "y": 32, "r": 32}, 8, 5), + ({"x": 64, "y": 32, "r": 128}, 4, 5), + ({"x": 64, "y": 64, "r": 16}, 4, 2), + ({"x": 64, "y": 64, "r": 32}, 4, 2), + ({"x": 64, "y": 64, "r": 64}, 8, 3), + ({"x": 64, "y": 64, "r": 128}, 4, 5), + ({"x": 64, "y": 128, "r": 32}, 4, 3), + ({"x": 64, "y": 128, "r": 32}, 8, 4), + ({"x": 64, "y": 128, "r": 64}, 4, 3), + ({"x": 64, "y": 128, "r": 128}, 4, 4), + ({"x": 128, "y": 64, "r": 32}, 4, 3), + ({"x": 128, "y": 64, "r": 32}, 8, 4), + ({"x": 128, "y": 128, "r": 32}, 8, 2), + ({"x": 128, "y": 128, "r": 32}, 4, 3), + ({"x": 128, "y": 128, "r": 64}, 4, 3), + ({"x": 128, "y": 128, "r": 64}, 8, 5), + ] + out = [] + for sizes, w, s in _base_mm_configs: + d = dict(sizes) + if persistent: + d.pop("r", None) + if bmm: + d["z"] = 1 + out.append((d, w, s)) + + # Deduplicate by converting dicts to immutable frozensets + deduped = {(frozenset(d.items()), w, s): (d, w, s) for d, w, s in out} + + return list(deduped.values()) + + +triton_native_mm_configs = _config_helper(bmm=False, persistent=False) +triton_native_persistent_mm_configs = _config_helper(bmm=False, persistent=True) +triton_native_bmm_configs = _config_helper(bmm=True, persistent=False) +triton_native_persistent_bmm_configs = _config_helper(bmm=True, persistent=True) + + +def _reduction_configs( + *, + size_hints: dict[str, int], + inductor_meta: dict[str, Any], + triton_meta: dict[str, Any], + num_dynamic=0, +) -> list[Config]: + reduction_hint = inductor_meta.get("reduction_hint") + + # Convert reductions to 1D, to simplify heuristics. + rnumel = get_total_reduction_numel(size_hints) + + register_intensive = False + MAX_R0_BLOCK = 2048 + loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get( + "num_reduction", 0 + ) + if size_hints["x"] >= 1024 and loads_and_red >= 10: + # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers. + # Consider load and reduction since load need move data into registers and + # reduction needs an accumulator. + # + # The magic numbers are a bit arbitrary. + # + # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes + # triton makes it to use less registers with worse perf. Check: + # https://github.com/pytorch/pytorch/issues/126463 + # + # The heuristic is a very simple one since registers can be reused. But + # hopefully it can be a good enough indicator. + MAX_R0_BLOCK = 1024 + register_intensive = True + + if triton_meta.get("native_matmul"): + if len(size_hints) == 3: + return [ + make_matmul_triton_config(sizes, num_warps, num_stages) + for sizes, num_warps, num_stages in triton_native_mm_configs + ] + elif len(size_hints) == 4: + return [ + make_matmul_triton_config(sizes, num_warps, num_stages) + for sizes, num_warps, num_stages in triton_native_bmm_configs + ] + else: + raise NotImplementedError("native matmul only supports mm/bmm pattern") + + def make_config( + x, + r, + num_warps=None, + num_stages=1, + register_intensive=False, + dynamic_scale_rblock=True, + ): + # For 3D case with tiling scores, create an adapted version + if "y" in size_hints: + assert "tiling_scores" in inductor_meta + return adapt_config_for_tiling( + size_hints, + inductor_meta["tiling_scores"], + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + else: + # For other cases, use the original function + return triton_config_reduction( + size_hints, + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + dynamic_scale_rblock=dynamic_scale_rblock, + reduction_hint=reduction_hint, + ) + + def outer_config_opt(): + # Default to 64 for vectorized loads + max_x_block, x_block = 256, 64 + load_factor = inductor_meta.get("num_load", 0) + x = size_hints["x"] + num_warps = None + + # Try to use all SMs with small x + if x <= 1024: + x_block = max(min(x // 128, 8), 2) + outer_r_block = min(rnumel, 64) + # Lower bound x = 1024, 1024 // 16 = 128 around # of SMs + elif x // 4096 <= 8: + x_block = 16 + outer_r_block = 512 // x_block + elif num_dynamic > 1: + # Lots of compute with multiple dynamic shape per loop iteration + # Larger RBLOCK minimizes loop iteration + outer_r_block = max(min((rnumel // 64), 64), 8) + elif num_dynamic == 1: + # Dynamic shapes introduce a lot register pressure for indexing + outer_r_block = ( + 1 + if load_factor >= 3 + else min(next_power_of_2(max(rnumel, 128) // 128), 8) + ) + else: + x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block) + if load_factor < 4 or rnumel <= 128: + outer_r_block = 512 // x_block + else: + # Heavier reductions contain a lot more overhead per loop iteration + # We minimize the overhead by enlarging r block + if rnumel >= 2048: + outer_r_block = 64 + else: + outer_r_block = 32 + x_block = min(x_block, 32) + num_warps = 4 + + # Set register intensive to true by default as we try to maximize tiles with heuristic + return make_config( + x_block, + outer_r_block, + num_warps=num_warps, + register_intensive=register_intensive, + ) + + contiguous_config = make_config( + 2 if rnumel <= 2048 else 1, # 1024 or less is persistent + min(rnumel, MAX_R0_BLOCK), + register_intensive=register_intensive, + ) + tiny_config = make_config( + 2 * (256 // rnumel) if rnumel <= 256 else 1, + min(rnumel, MAX_R0_BLOCK), + register_intensive=register_intensive, + ) + + outer_config = make_config(64, 8, register_intensive=register_intensive) + # TODO (paulzhan): Test heuristic on AMD and internal testing + # for correctness + if not torch.version.hip: + outer_config = outer_config_opt() + + configs = [] + + if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8: + xnumel = max(4096 // rnumel, 1) + c = make_config( + xnumel, + min(rnumel, 32768), + register_intensive=register_intensive, + dynamic_scale_rblock=False, + ) + configs.append(c) + + # For 3d tiling, default to more autotuning initially + if "y" in size_hints: + pass + elif inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ): + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return configs + [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return configs + [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return configs + [tiny_config] + + return configs + [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] + + +def match_target_block_product( + size_hints, tiling_scores, target_block_product, min_block_size=1 +): + """ + Distribute block sizes across dimensions according to tiling scores, + aiming to match a target product of block sizes. + """ + total_score = sum(tiling_scores.values()) + if total_score == 0: + # just assume even score with no minimum block size + min_block_size = 1 + tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product) + + # First, give each coalescing dimension at least min_block_size + block_sizes = {} + relative_scores = {} + curr_block_product = 1 + + for dim, score in tiling_scores.items(): + if score == 0: + block_sizes[dim] = 1 + continue + + block_sizes[dim] = min_block_size + curr_block_product *= min_block_size + relative_scores[dim] = score / total_score + + # Scale up dimensions by their relative scores until we reach the target + while curr_block_product < target_block_product and relative_scores: + dim, score = max(relative_scores.items(), key=lambda item: item[1]) + + # Check if we've hit the max for this dimension + if ( + block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()] + or block_sizes[dim] >= size_hints[dim] + ): + del relative_scores[dim] + continue + + block_sizes[dim] *= 2 + relative_scores[dim] /= 2 + curr_block_product *= 2 + + return block_sizes + + +def adapt_config_for_tiling( + size_hints, + tiling_scores, + original_x, + original_r, + num_warps=None, + num_stages=1, + register_intensive=False, + persistent_reduction=False, +) -> Config: + """ + Create an adapted configuration based on tiling scores, + redistributing the same total block size (x * r) according to tiling scores. + """ + assert all(s in tiling_scores for s in size_hints) + target_block_product = original_x * original_r + block_sizes = match_target_block_product( + size_hints, tiling_scores, target_block_product + ) + + return triton_config_tiled_reduction( + size_hints, + block_sizes["x"], + block_sizes["y"], + block_sizes["r0_"], + num_stages=num_stages, + register_intensive=register_intensive, + ) + + +def filter_reduction_configs_for_determinism( + inductor_meta: dict[str, Any], configs: list[Config] +) -> list[Config]: + """ + Filter configs for reduction so the numerics can be deterministic. + + Heuristics: + - skip reduction configs with too small RBLOCK + - skip reduction configs with XBLOCK==1 if we are confident it will not perform well + - if there is a tie, pick the config with second largest RBLOCK + - if there is still a tie, pick the config with second largest num_warps + - if there is still a tie, pick the config with second largest XBLOCK + """ + configs = unique_configs(configs) + assert len(configs) > 0 + + def _do_filter_due_to_inductor_config(): + return ( + inductor_meta.get("deterministic", False) + or inductor_meta.get("force_filter_reduction_configs", False) + ) or inductor_meta.get("are_deterministic_algorithms_enabled") + + if not _do_filter_due_to_inductor_config() or len(configs) == 1: + # no filtering happening if NOT in deterministic mode + return configs + + if log.isEnabledFor(logging.DEBUG): + log.debug("reduction configs before filtering:") + for c in configs: + log.debug("%s", c) + log.debug("") + + def _has_too_small_rblock(config): + rblock = config.kwargs.get("R0_BLOCK") + # too small RBLOCK is likely to be bad + return rblock is not None and rblock <= 4 + + def _nonpromising_xblock_1(config): + # kernel like https://gist.github.com/shunting314/0b3281c087e79bc915fe45985ff9d7d5 + # without a load/store having contiguous rdim is unlikely to perform well with XBLOCK==1 + return config.kwargs["XBLOCK"] == 1 and not inductor_meta.get( + "has_loadstore_with_contiguous_rdim", True + ) + + newconfigs = [*filter(lambda x: not _has_too_small_rblock(x), configs)] + # accept the filtering only if there are configs left + if len(newconfigs) > 0: + configs = newconfigs + + newconfigs = [*filter(lambda x: not _nonpromising_xblock_1(x), configs)] + if len(newconfigs) > 0: + configs = newconfigs + + assert len(configs) > 0 + + def _r0_block(c): + return c.kwargs.get("R0_BLOCK", -1) + + def _xblock(c): + return c.kwargs.get("XBLOCK", -1) + + def _num_warps(c): + return c.num_warps + + def _pick_second_largest(accessor): + nonlocal configs + configs = sorted(configs, key=lambda x: accessor(x)) + if accessor(configs[0]) != accessor(configs[-1]): + max_val = accessor(configs[-1]) + configs = [*filter(lambda x: accessor(x) != max_val, configs)] + second_max_val = accessor(configs[-1]) + configs = [*filter(lambda x: accessor(x) == second_max_val, configs)] + return configs + + def _pick_config(): + nonlocal configs + assert len(configs) > 0 + if len(configs) == 1: + return configs[0] + + # break tie by R0_BLOCK + configs = _pick_second_largest(_r0_block) + if len(configs) == 1: + return configs[0] + + # break tie by num_warps + configs = _pick_second_largest(_num_warps) + if len(configs) == 1: + return configs[0] + + # break tie by XBLOCK + configs = _pick_second_largest(_xblock) + + # there is still a tie, pick the first one + return configs[0] + + configs = [_pick_config()] + + if log.isEnabledFor(logging.DEBUG): + log.debug("reduction configs after filtering:") + for c in configs: + log.debug("%s", c) + log.debug("") + return configs + + +def reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + assert triton_meta is not None + + num_dynamic = 0 + for k in triton_meta["signature"]: + if "ks" in k: + num_dynamic += 1 + + configs = _reduction_configs( + size_hints=size_hints, + inductor_meta=inductor_meta, + triton_meta=triton_meta, + num_dynamic=num_dynamic, + ) + + configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) + configs = filter_reduction_configs_for_determinism(inductor_meta, configs) + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def cooperative_reduction( + size_hints, + reduction_hint, + triton_meta, + filename, + inductor_meta, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + # Cooperative reductions currently only support a single reduction dimension. + assert len(size_hints) == 2, ( + "Cooperative reductions don't support tiling reduction dims" + ) + xnumel, rnumel = size_hints["x"], size_hints["r0_"] + + # TODO(jansel): we should base target on the SM count of the local GPU + target = 64 + split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT)) + assert rnumel >= split + assert split <= TRITON_MAX_RSPLIT + if inductor_meta["persistent_reduction"]: + configs = _persistent_reduction_configs( + {"x": xnumel, "r0_": rnumel // split}, + reduction_hint, + inductor_meta, + triton_meta, + ) + else: + configs = _reduction_configs( + size_hints={"x": xnumel, "r0_": rnumel // split}, + inductor_meta=inductor_meta, + triton_meta=triton_meta, + ) + for config in configs: + config.kwargs["RSPLIT"] = split + # TODO(jansel): add more configs in max_autotune + + configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) + configs = filter_reduction_configs_for_determinism(inductor_meta, configs) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def _persistent_reduction_configs( + size_hints, + reduction_hint=False, + inductor_meta=None, + triton_meta=None, +): + xnumel = size_hints["x"] + rnumel = get_total_reduction_numel(size_hints) + + MAX_PERSISTENT_BLOCK_NUMEL = 4096 + + if triton_meta.get("native_matmul"): + if len(size_hints) == 3: + return [ + make_matmul_triton_config(sizes, num_warps, num_stages) + for sizes, num_warps, num_stages in triton_native_persistent_mm_configs + ] + elif len(size_hints) == 4: + return [ + make_matmul_triton_config(sizes, num_warps, num_stages) + for sizes, num_warps, num_stages in triton_native_persistent_bmm_configs + ] + else: + raise NotImplementedError("native matmul only supports mm/bmm pattern") + + max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ) + + if torch.version.hip: + xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256] + else: + xblock_vals = [1, 8, 32, 128] + + if "y" not in size_hints: + configs = [ + triton_config_reduction( + size_hints, + xblock, + rnumel, + register_intensive=True, + reduction_hint=reduction_hint, + ) + for xblock in xblock_vals + if xblock == 1 + or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) + ] + else: + configs = [] + assert "tiling_scores" in inductor_meta + x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} + for target_block_size in xblock_vals: + if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: + continue + + block_sizes = match_target_block_product( + size_hints, x_y_scores, target_block_size + ) + configs.append( + triton_config_tiled_reduction( + size_hints, block_sizes["x"], block_sizes["y"], rnumel + ) + ) + + tiny_configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + + # defer to more autotuning, initially + if "y" in size_hints: + pass + # TODO(jansel): we should be able to improve these heuristics + elif not max_autotune_enabled: # Do not filter configs when tuning + if reduction_hint == ReductionHint.INNER and rnumel >= 256: + if rnumel > 1024 or xnumel // 8 < 128 or inductor_meta.get("RSPLIT_SIZE"): + configs = configs[:1] + else: + if not torch.cuda.is_available(): + # TODO(Intel): CUDA uses num_warps = 1 to disable shared memory. + # We apply different configurations from #168335. + # We currently let cost model in Triton to decide whether to use shared memory. + loads_and_stores = inductor_meta.get( + "num_load", 0 + ) + inductor_meta.get("num_store", 0) + x_block = 8 + if xnumel // x_block < 128 or loads_and_stores >= 5: + x_block = 1 + num_warps, min_num_warps, reduction_hint = None, None, None + else: + x_block = min(1024 // rnumel, 8) + num_warps, min_num_warps = 1, 1 + configs = [ + triton_config_reduction( + size_hints, + x_block, + rnumel, + register_intensive=True, + num_warps=num_warps, + min_num_warps=min_num_warps, + reduction_hint=reduction_hint, + ) + ] + + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = tiny_configs + else: + if torch.version.hip: + # If autotune is enabled append tiny configs + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) + + for c in configs: + # we don't need Rn_BLOCK for persistent reduction + for prefix in size_hints: + if prefix_is_reduction(prefix): + c.kwargs.pop(f"{prefix.upper()}BLOCK") + + return configs + + +def persistent_reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + configs = _persistent_reduction_configs( + size_hints, reduction_hint, inductor_meta, triton_meta + ) + + # This key is not added to the inductor meta as its clear from the heuristic + # choice that it is persistent. Add it and remove it below so that persistent + # configs can be filtered appropriately by _maybe_filter_configs_for_tma_restrictions + persistent_reduction_key = "persistent_reduction" + inductor_meta[persistent_reduction_key] = True + configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) + inductor_meta.pop(persistent_reduction_key) + + if inductor_meta.get("RSPLIT_SIZE"): + new_configs = [] + rsplit_size = inductor_meta.get("RSPLIT_SIZE") + rnumel_hint = size_hints["r0_"] + min_x_block = 1 + if rnumel_hint <= 512: + min_x_block = 4 + x_block = min(max(rsplit_size // 32, min_x_block), 16) + for c in configs: + c.kwargs["RSPLIT_SIZE"] = rsplit_size + # small XBLOCK to use less registers/smem + c.kwargs["XBLOCK"] = x_block + + num_iters = rsplit_size // x_block + c.kwargs["NUM_STAGES"] = min(max(num_iters // 4, 1), 3) + + if rnumel_hint <= 1024: + c.num_warps //= 2 + c.num_warps = max(c.num_warps, 1) + new_configs.append(c) + + # less warps so potentially each sm can run more thread blocks + # Inside each thread block, we handle the split sequentially, + # more thread blocks is beneficial here. + newc = copy.deepcopy(c) + newc.num_warps = 2 + new_configs.append(newc) + else: + # more warps for larger rows + new_configs.append(c) + + if c.num_warps < 32: + newc = copy.deepcopy(c) + newc.num_warps *= 2 + new_configs.append(newc) + + configs = unique_configs(new_configs) + + configs = filter_reduction_configs_for_determinism(inductor_meta, configs) + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def split_scan( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """Heuristic for TritonSplitScanKernel""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + assert triton_meta is not None + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs( + size_hints=size_hints, inductor_meta=inductor_meta, triton_meta=triton_meta + ) + + # Fixup configs to enforce the minimum Rn_BLOCK size + min_rblock = inductor_meta.get("min_split_scan_rblock", 256) + for cfg in configs: + for var in list(cfg.kwargs.keys()): + if var.startswith("R") and cfg.kwargs[var] < min_rblock: + cfg.kwargs[var] = min_rblock + + configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) + configs = filter_reduction_configs_for_determinism(inductor_meta, configs) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.SPLIT_SCAN, + filename=filename, + ) + + +def template( + num_stages, + num_warps, + triton_meta, + num_consumer_groups=0, + num_buffers_warp_spec=0, + filename=None, + inductor_meta=None, +): + """ + Compile a triton template + """ + # Prepare the base configuration + config_args = { + "num_stages": num_stages, + "num_warps": num_warps, + } + + # Conditionally add arguments based on HAS_WARP_SPEC + if HAS_WARP_SPEC: + config_args.update( + { + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + } + ) + return cached_autotune( + None, + [triton.Config({}, **config_args)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]: + """Extract triton.Config options that should become kwargs""" + popped = {} + for key in ( + "num_warps", + "num_stages", + "num_ctas", + "maxnreg", + "num_consumer_groups", + "num_buffers_warp_spec", + ): + val = config.pop(key, None) + if val is not None: + popped[key] = val + return popped + + +def config_to_dict(config: Config) -> dict[str, Any]: + config_dict = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + if HAS_WARP_SPEC: + config_dict.update( + { + "num_consumer_groups": getattr(config, "num_consumer_groups", 0), + "num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0), + } + ) + return config_dict + + +def config_from_dict(config: dict[str, Any]) -> Config: + config = {**config} + return Config(config, **_pop_config_kwargs(config)) + + +def fixed_config(config, filename, triton_meta, inductor_meta): + """ + Used when the configuration is already decided at compile time + """ + config = {**config} + return cached_autotune( + None, + [triton.Config(config, **_pop_config_kwargs(config))], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.FIXED, + filename=filename, + ) + + +def user_autotune( + configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False +): + """ + Compile a user defined triton kernel + """ + if len(configs) == 0: + configs = [triton.Config({})] + else: + configs = [*map(config_from_dict, configs)] + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel, + ) + + +def foreach(triton_meta, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + configs = [] + + # Naive autotuning path for num_warps + if not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dataclasses.dataclass +class GridExpr: + """Generate code for grid size expressions in launcher""" + + inductor_meta: dict[str, Any] + mode: Literal["python", "cpp"] = "python" + prefix: list[str] = dataclasses.field(default_factory=list) + x_grid: str | int = 1 + y_grid: str | int = 1 + z_grid: str | int = 1 + + def __post_init__(self) -> None: + assert self.mode in ("python", "cpp") + + def generate(self, meta: dict[str, int]) -> None: + raise NotImplementedError + + def ceildiv(self, numel: str | int, block: None | int | str) -> str | int: + if block is None or block == 1: + return numel + if isinstance(numel, int) and isinstance(block, int): + return ceildiv(numel, block) # constant fold + # This trick only works in python, where + # negative integer division is floored + if self.mode == "python": + return f"-(({numel}) // -({block}))" + # For cpp code gen + return f"(({numel} + ({block} - 1)) / ({block}))" + + def maximum(self, seq: list[int | str]) -> int | str: + """Codegen for max function with constant folding, constants are represented as int""" + items = self._constant_fold(max, seq) + if len(items) <= 1: + return items[0] + if self.mode == "python": + return f"max({', '.join(map(str, items))})" + return functools.reduce(lambda x, y: f"std::max({x}, {y})", items) + + def summation(self, seq: list[int | str]) -> int | str: + """Codegen for sum function with constant folding, constants are represented as int""" + items = self._constant_fold(sum, seq) + if len(items) <= 1: + return items[0] + return " + ".join(map(str, items)) + + def _constant_fold( + self, fn: Callable[[list[int]], int], seq: list[int | str] + ) -> list[int | str]: + """Constant fold through a commutative fn where ints are constants""" + items: list[int | str] = [x for x in seq if not isinstance(x, int)] + const_items = [x for x in seq if isinstance(x, int)] + if const_items: + items.append(fn(const_items)) + return items + + def assign_tmp(self, name: str, expr: str | int) -> str: + # Grid functions are one per kernel, so name collisions are fine + if self.mode == "python": + return f"{name} = {expr}" + if self.mode == "cpp": + return f"uint32_t {name} = {expr};" + raise AssertionError(f"invalid mode {self.mode}") + + @staticmethod + def from_meta( + inductor_meta: dict[str, Any], + cfg: Config | dict[str, int], + mode: Literal["python", "cpp"] = "python", + ) -> GridExpr: + grid_cls = globals()[inductor_meta["grid_type"]] + assert issubclass(grid_cls, GridExpr) + grid = grid_cls(inductor_meta=inductor_meta, mode=mode) + if isinstance(cfg, Config): + cfg = config_to_dict(cfg) + grid.generate(cfg) + return grid + + def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]: + scope = {**meta} + for line in self.prefix: + exec(line, scope) + exec(f"grid_0 = {self.x_grid}", scope) + exec(f"grid_1 = {self.y_grid}", scope) + exec(f"grid_2 = {self.z_grid}", scope) + return scope["grid_0"], scope["grid_1"], scope["grid_2"] + + +class Grid1D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + + +class Grid2D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK")) + + +class Grid3D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK")) + self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK")) + + +class Grid2DWithYZOverflow(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.prefix.extend( + [ + self.assign_tmp( + "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK")) + ), + self.assign_tmp( + "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid()) + ), + ] + ) + self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_") + self.z_grid = "y_grid_div_" + + +class MixOrderReductionGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + split_size = meta.get("RSPLIT_SIZE") + xblock = meta.get("XBLOCK") + assert split_size, "Missing RSPLIT_SIZE" + assert xblock, "Missing XBLOCK" + assert split_size % xblock == 0, f"{split_size=}, {xblock=}" + self.x_grid = self.ceildiv("xnumel", split_size) + + +class CooperativeReductionGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = str(meta["RSPLIT"]) + self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + + +class SplitScanGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + assert meta.get("XBLOCK", 1) == 1 + self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK")) + self.y_grid = "xnumel" + + +class FixedGrid(GridExpr): + @staticmethod + def setup_grid_as_args() -> dict[str, Any]: + """Inductor meta so the launcher takes three extra grid arguments""" + return { + "grid_type": FixedGrid.__name__, + "fixed_grid": ["_grid_0", "_grid_1", "_grid_2"], + "extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"], + } + + def generate(self, meta: dict[str, int]) -> None: + self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"] + + +class PrecomputedGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + for candidate in self.inductor_meta["precomputed_grids"]: + if all(meta.get(k) == v for k, v in candidate["config"].items()): + self.x_grid, self.y_grid, self.z_grid = candidate[self.mode] + return + raise AssertionError( + f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}" + ) + + +class ComboKernelGrid(GridExpr): + def generate(self, meta: dict[str, int]): + combo_meta = self.inductor_meta["combo_grid_meta"] + if combo_meta["default_config"]: + meta = {**combo_meta["default_config"], **meta} + no_x_dims = [] + xnumels = [] + ynumels = [] + znumels = [] + for num in range(combo_meta["num_kernels"]): + assert ( + combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0 + ) + no_x_dims.append(combo_meta[f"no_x_dim_{num}"]) + xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}") + if f"ynumel_{num}" in combo_meta: + ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}") + if f"znumel_{num}" in combo_meta: + znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}") + + self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta) + if combo_meta["min_blocks"]: + self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]]) + if ynumels: + self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK")) + if znumels: + self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK")) + + def combo_x_grid( + self, + xnumels: list[int | str], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> str | int: + raise NotImplementedError + + +class SequentialComboKernelGrid(ComboKernelGrid): + def combo_x_grid( + self, + xnumels: list[int | str], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> str | int: + assert len(xnumels) == len(no_x_dims) + return self.summation( + [ + self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK")) + for x, no_x_dim in zip(xnumels, no_x_dims) + ] + ) + + +class RoundRobinComboKernelGrid(ComboKernelGrid): + def combo_x_grid( + self, + xnumels: list[int | str], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> str: + assert len(xnumels) == len(no_x_dims) + num_kernels = self.inductor_meta["combo_grid_meta"]["num_kernels"] + exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim] + xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim] + if xnumels_x_dim: + exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK"))) + return f"({self.maximum(exprs)}) * {num_kernels}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c1ea2dd8a425eedf69670f5fa634f31566ed643 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/closure.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/closure.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c681d632a4a34cb132a59cc9f525502d2a23c110 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/closure.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/computation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/computation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d16468efa9d3ed3294faddacbd35ed6063a1d45 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/computation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40ef6f5fc602ca3c036000e8f78a10e89c9ad1ad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/debug.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9916dd91d7d4f0d1d05d59379b3372b3950e621e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/debug.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/device_context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/device_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4d711c990fbcd6d4e156e4dfdae8a0cfb1cd9e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/device_context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f18622070d66edf6fc59ecafae37efcd99f50fe6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd80ba1982bda0884be540d9a568afd72b6ff433 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/metrics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79b5dfbdb6ac488bb10436bfc6105e781714def1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/metrics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f6960e16c62b4210eeec96dcdcd3c4fbd25dc91 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daa12f2081465bd1c196374a965b649e4ec175cc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59187daff7bca13274e277248e14773a635e2ab9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b0129b445418b8ddfd0f445a20faa312c5a2e3c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba759024f8310e9ac91f0a7fef487c67dacaee4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..142324394abc2faf4ce86e66f2d82adc39e29373 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae9d239812dca62a712892cd322b5c02a33d8cc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims_common/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims_common/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..962fbca4d7b8d4abbd8ce644cc57ca8047034ffc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims_common/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd0b532d7cf3851f03cd496d001c70587adedc74 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..117bbecc10276c5e90e62841b4fdec584f52f092 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22809cfd5dc25792d77070c269fc8d111a12eed0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py @@ -0,0 +1,15 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "23.2" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014 %s" % __author__ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb44ccd2e3f9c2ff65a7d4ce5fe44d8932867d6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d834dbd7715394973e94a8e4426b1eb8d7f3002e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d0cc5589ef4be573cf2de1bea1a78a1863a09a6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..90a6465f9682c886363eea5327dac64bf623a6ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/version.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/version.py new file mode 100644 index 0000000000000000000000000000000000000000..5faab9bd0dcf28847960162b2b4f13a8a556ef20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_vendor/packaging/version.py @@ -0,0 +1,563 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. +""" +.. testsetup:: + + from packaging.version import parse, Version +""" + +import itertools +import re +from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"] + +LocalType = Tuple[Union[int, str], ...] + +CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]] +CmpLocalType = Union[ + NegativeInfinityType, + Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...], +] +CmpKey = Tuple[ + int, + Tuple[int, ...], + CmpPrePostDevType, + CmpPrePostDevType, + CmpPrePostDevType, + CmpLocalType, +] +VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool] + + +class _Version(NamedTuple): + epoch: int + release: Tuple[int, ...] + dev: Optional[Tuple[str, int]] + pre: Optional[Tuple[str, int]] + post: Optional[Tuple[str, int]] + local: Optional[LocalType] + + +def parse(version: str) -> "Version": + """Parse the given version string. + + >>> parse('1.0.dev1') + + + :param version: The version string to parse. + :raises InvalidVersion: When the version string is not a valid version. + """ + return Version(version) + + +class InvalidVersion(ValueError): + """Raised when a version string is not a valid version. + + >>> Version("invalid") + Traceback (most recent call last): + ... + packaging.version.InvalidVersion: Invalid version: 'invalid' + """ + + +class _BaseVersion: + _key: Tuple[Any, ...] + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. + def __lt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +_VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
                                          # pre-release
+            [-_\.]?
+            (?Palpha|a|beta|b|preview|pre|c|rc)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+        (?P                                         # post release
+            (?:-(?P[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?Ppost|rev|r)
+                [-_\.]?
+                (?P[0-9]+)?
+            )
+        )?
+        (?P                                          # dev release
+            [-_\.]?
+            (?Pdev)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+    )
+    (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+VERSION_PATTERN = _VERSION_PATTERN
+"""
+A string containing the regular expression used to match a valid version.
+
+The pattern is not anchored at either end, and is intended for embedding in larger
+expressions (for example, matching a version number as part of a file name). The
+regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
+flags set.
+
+:meta hide-value:
+"""
+
+
+class Version(_BaseVersion):
+    """This class abstracts handling of a project's versions.
+
+    A :class:`Version` instance is comparison aware and can be compared and
+    sorted using the standard Python interfaces.
+
+    >>> v1 = Version("1.0a5")
+    >>> v2 = Version("1.0")
+    >>> v1
+    
+    >>> v2
+    
+    >>> v1 < v2
+    True
+    >>> v1 == v2
+    False
+    >>> v1 > v2
+    False
+    >>> v1 >= v2
+    False
+    >>> v1 <= v2
+    True
+    """
+
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    _key: CmpKey
+
+    def __init__(self, version: str) -> None:
+        """Initialize a Version object.
+
+        :param version:
+            The string representation of a version which will be parsed and normalized
+            before use.
+        :raises InvalidVersion:
+            If the ``version`` does not conform to PEP 440 in any way then this
+            exception will be raised.
+        """
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        """A representation of the Version that shows all internal state.
+
+        >>> Version('1.0.0')
+        
+        """
+        return f""
+
+    def __str__(self) -> str:
+        """A string representation of the version that can be rounded-tripped.
+
+        >>> str(Version("1.0a5"))
+        '1.0a5'
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        """The epoch of the version.
+
+        >>> Version("2.0.0").epoch
+        0
+        >>> Version("1!2.0.0").epoch
+        1
+        """
+        return self._version.epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        """The components of the "release" segment of the version.
+
+        >>> Version("1.2.3").release
+        (1, 2, 3)
+        >>> Version("2.0.0").release
+        (2, 0, 0)
+        >>> Version("1!2.0.0.post0").release
+        (2, 0, 0)
+
+        Includes trailing zeroes but not the epoch or any pre-release / development /
+        post-release suffixes.
+        """
+        return self._version.release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        """The pre-release segment of the version.
+
+        >>> print(Version("1.2.3").pre)
+        None
+        >>> Version("1.2.3a1").pre
+        ('a', 1)
+        >>> Version("1.2.3b1").pre
+        ('b', 1)
+        >>> Version("1.2.3rc1").pre
+        ('rc', 1)
+        """
+        return self._version.pre
+
+    @property
+    def post(self) -> Optional[int]:
+        """The post-release number of the version.
+
+        >>> print(Version("1.2.3").post)
+        None
+        >>> Version("1.2.3.post1").post
+        1
+        """
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        """The development number of the version.
+
+        >>> print(Version("1.2.3").dev)
+        None
+        >>> Version("1.2.3.dev1").dev
+        1
+        """
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        """The local version segment of the version.
+
+        >>> print(Version("1.2.3").local)
+        None
+        >>> Version("1.2.3+abc").local
+        'abc'
+        """
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        """The public portion of the version.
+
+        >>> Version("1.2.3").public
+        '1.2.3'
+        >>> Version("1.2.3+abc").public
+        '1.2.3'
+        >>> Version("1.2.3+abc.dev1").public
+        '1.2.3'
+        """
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        """The "base version" of the version.
+
+        >>> Version("1.2.3").base_version
+        '1.2.3'
+        >>> Version("1.2.3+abc").base_version
+        '1.2.3'
+        >>> Version("1!1.2.3+abc.dev1").base_version
+        '1!1.2.3'
+
+        The "base version" is the public version of the project without any pre or post
+        release markers.
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        """Whether this version is a pre-release.
+
+        >>> Version("1.2.3").is_prerelease
+        False
+        >>> Version("1.2.3a1").is_prerelease
+        True
+        >>> Version("1.2.3b1").is_prerelease
+        True
+        >>> Version("1.2.3rc1").is_prerelease
+        True
+        >>> Version("1.2.3dev1").is_prerelease
+        True
+        """
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        """Whether this version is a post-release.
+
+        >>> Version("1.2.3").is_postrelease
+        False
+        >>> Version("1.2.3.post1").is_postrelease
+        True
+        """
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        """Whether this version is a development release.
+
+        >>> Version("1.2.3").is_devrelease
+        False
+        >>> Version("1.2.3.dev1").is_devrelease
+        True
+        """
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        """The first item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").major
+        1
+        """
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        """The second item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").minor
+        2
+        >>> Version("1").minor
+        0
+        """
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        """The third item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").micro
+        3
+        >>> Version("1").micro
+        0
+        """
+        return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
+) -> Optional[Tuple[str, int]]:
+
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[LocalType],
+) -> CmpKey:
+
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll use a reverse the list, drop all the now
+    # leading zeros until we come to something non zero, then take the rest
+    # re-reverse it back into the correct order and make it a tuple and use
+    # that for our sorting key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: CmpPrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: CmpPrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: CmpPrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: CmpLocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cd20b747d95da413c0f97bcebe494237dbdce8f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e7bfb1244a4f98712ba487d2477979477476f52
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72510c8b837eb1eb80d558b059c44c98a2390484
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..3180e56a6baf96b56c88a712a4426108d8c8e2fc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py
@@ -0,0 +1,150 @@
+# mypy: allow-untyped-defs
+import hashlib
+import json
+
+import coremltools as ct  # type: ignore[import]
+from coremltools.converters.mil.input_types import TensorType  # type: ignore[import]
+from coremltools.converters.mil.mil import types  # type: ignore[import]
+from coremltools.models.neural_network import quantization_utils  # type: ignore[import]
+
+import torch
+
+
+CT_METADATA_VERSION = "com.github.apple.coremltools.version"
+CT_METADATA_SOURCE = "com.github.apple.coremltools.source"
+
+
+class ScalarType:
+    Float = 0
+    Double = 1
+    Int = 2
+    Long = 3
+    Undefined = 4
+
+
+# Supported Tensor types in coremltools:
+# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28
+torch_to_mil_types = {
+    ScalarType.Float: types.fp32,
+    ScalarType.Double: types.fp64,
+    ScalarType.Int: types.int32,
+    ScalarType.Long: types.int64,
+}
+
+
+class CoreMLComputeUnit:
+    CPU = "cpuOnly"
+    CPUAndGPU = "cpuAndGPU"
+    ALL = "all"
+
+
+class CoreMLQuantizationMode:
+    LINEAR = "linear"
+    LINEAR_SYMMETRIC = "linear_symmetric"
+    NONE = "none"
+
+
+def TensorSpec(shape, dtype=ScalarType.Float):
+    return (shape, dtype)
+
+
+def CompileSpec(
+    inputs,
+    outputs,
+    backend=CoreMLComputeUnit.CPU,
+    allow_low_precision=True,
+    quantization_mode=CoreMLQuantizationMode.NONE,
+    mlmodel_export_path=None,
+    convert_to=None,
+):
+    return (
+        inputs,
+        outputs,
+        backend,
+        allow_low_precision,
+        quantization_mode,
+        mlmodel_export_path,
+        convert_to,
+    )
+
+
+def _check_enumerated_shape(shape):
+    for s in shape:
+        if not isinstance(s, (list, tuple)):
+            return False
+    return True
+
+
+def _convert_to_mil_type(shape, dtype, name: str):
+    mil_shape = shape
+    if _check_enumerated_shape(shape):
+        mil_shape = ct.EnumeratedShapes(shape)
+    ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
+    ml_type.name = name
+    return ml_type
+
+
+def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tuple]):
+    spec = compile_spec["forward"]
+    (
+        input_specs,
+        output_specs,
+        backend,
+        allow_low_precision,
+        quantization_mode,
+        mlmodel_export_path,
+        convert_to,
+    ) = spec
+    mil_inputs = []
+    inputs = []
+    for index, input in enumerate(input_specs):
+        shape, dtype = input
+        name = "input_" + str(index)
+        inputs.append([name, str(dtype), str(shape)])
+        ml_type = _convert_to_mil_type(shape, dtype, name)
+        mil_inputs.append(ml_type)
+    model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
+    mlmodel = ct.convert(model, inputs=mil_inputs, convert_to=convert_to)
+
+    if quantization_mode != CoreMLQuantizationMode.NONE:
+        quant_model_spec = quantization_utils.quantize_weights(
+            mlmodel, nbits=8, quantization_mode=quantization_mode
+        )
+        mlmodel = ct.models.MLModel(quant_model_spec)
+
+    spec = mlmodel.get_spec()
+    assert len(spec.description.output) == len(output_specs)  # type: ignore[attr-defined]
+    outputs = []
+    for index, output in enumerate(output_specs):
+        shape, dtype = output
+        name = spec.description.output[index].name  # type: ignore[attr-defined]
+        outputs.append([name, str(dtype), str(shape)])
+    mlmodel = ct.models.model.MLModel(spec)
+    print(mlmodel)
+
+    if mlmodel_export_path is not None:
+        print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}")
+        mlmodel.save(mlmodel_export_path)
+
+    config = {
+        "spec_ver": str(spec.specificationVersion),  # type: ignore[attr-defined]
+        "backend": backend,
+        "allow_low_precision": str(allow_low_precision),
+    }
+    metadata = {
+        "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],
+        "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE],
+    }
+    coreml_compile_spec = {
+        "inputs": inputs,
+        "outputs": outputs,
+        "config": config,
+        "metadata": metadata,
+    }
+    mlmodel = spec.SerializeToString()  # type: ignore[attr-defined]
+
+    return {
+        "model": mlmodel,
+        "hash": str(hashlib.sha256(mlmodel).hexdigest()),
+        "extra": json.dumps(coreml_compile_spec),
+    }
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69185955a29d25e74c3cc9711cde1c58152ad2c6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..630f7e3f2dd27085fa551bc22b6c9d27340c4f80
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fc48d711111ffd417fa1c544bd4b2362e75cf16
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py
@@ -0,0 +1,199 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+from typing import Optional
+
+import torch
+from torch.backends._nnapi.serializer import _NnapiSerializer
+
+
+ANEURALNETWORKS_PREFER_LOW_POWER = 0
+ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
+ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
+
+
+class NnapiModule(torch.nn.Module):
+    """Torch Module that wraps an NNAPI Compilation.
+
+    This module handles preparing the weights, initializing the
+    NNAPI TorchBind object, and adjusting the memory formats
+    of all inputs and outputs.
+    """
+
+    # _nnapi.Compilation is defined
+    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
+    weights: list[torch.Tensor]
+    out_templates: list[torch.Tensor]
+
+    def __init__(
+        self,
+        shape_compute_module: torch.nn.Module,
+        ser_model: torch.Tensor,
+        weights: list[torch.Tensor],
+        inp_mem_fmts: list[int],
+        out_mem_fmts: list[int],
+        compilation_preference: int,
+        relax_f32_to_f16: bool,
+    ):
+        super().__init__()
+        self.shape_compute_module = shape_compute_module
+        self.ser_model = ser_model
+        self.weights = weights
+        self.inp_mem_fmts = inp_mem_fmts
+        self.out_mem_fmts = out_mem_fmts
+        self.out_templates = []
+        self.comp = None
+        self.compilation_preference = compilation_preference
+        self.relax_f32_to_f16 = relax_f32_to_f16
+
+    @torch.jit.export
+    def init(self, args: list[torch.Tensor]):
+        assert self.comp is None
+        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
+        self.weights = [w.contiguous() for w in self.weights]
+        comp = torch.classes._nnapi.Compilation()
+        comp.init2(
+            self.ser_model,
+            self.weights,
+            self.compilation_preference,
+            self.relax_f32_to_f16,
+        )
+
+        self.comp = comp
+
+    def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]:
+        if self.comp is None:
+            self.init(args)
+        comp = self.comp
+        assert comp is not None
+        outs = [torch.empty_like(out) for out in self.out_templates]
+
+        assert len(args) == len(self.inp_mem_fmts)
+        fixed_args = []
+        for idx in range(len(args)):
+            fmt = self.inp_mem_fmts[idx]
+            # These constants match the values in DimOrder in serializer.py
+            # TODO: See if it's possible to use those directly.
+            if fmt == 0:
+                fixed_args.append(args[idx].contiguous())
+            elif fmt == 1:
+                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
+            else:
+                raise ValueError("Invalid mem_fmt")
+        comp.run(fixed_args, outs)
+        assert len(outs) == len(self.out_mem_fmts)
+        for idx in range(len(self.out_templates)):
+            fmt = self.out_mem_fmts[idx]
+            # These constants match the values in DimOrder in serializer.py
+            # TODO: See if it's possible to use those directly.
+            if fmt in (0, 2):
+                pass
+            elif fmt == 1:
+                outs[idx] = outs[idx].permute(0, 3, 1, 2)
+            else:
+                raise ValueError("Invalid mem_fmt")
+        return outs
+
+
+def convert_model_to_nnapi(
+    model,
+    inputs,
+    serializer=None,
+    return_shapes=None,
+    use_int16_for_qint16=False,
+    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
+    relax_f32_to_f16=False,
+):
+    (
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        retval_count,
+    ) = process_for_nnapi(
+        model, inputs, serializer, return_shapes, use_int16_for_qint16
+    )
+
+    nnapi_model = NnapiModule(
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        compilation_preference,
+        relax_f32_to_f16,
+    )
+
+    class NnapiInterfaceWrapper(torch.nn.Module):
+        """NNAPI list-ifying and de-list-ifying wrapper.
+
+        NNAPI always expects a list of inputs and provides a list of outputs.
+        This module allows us to accept inputs as separate arguments.
+        It returns results as either a single tensor or tuple,
+        matching the original module.
+        """
+
+        def __init__(self, mod):
+            super().__init__()
+            self.mod = mod
+
+    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
+    wrapper_model = torch.jit.script(wrapper_model_py)
+    # TODO: Maybe make these names match the original.
+    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
+    if retval_count < 0:
+        ret_expr = "retvals[0]"
+    else:
+        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
+    wrapper_model.define(
+        f"def forward(self, {arg_list}):\n"
+        f"    retvals = self.mod([{arg_list}])\n"
+        f"    return {ret_expr}\n"
+    )
+    return wrapper_model
+
+
+def process_for_nnapi(
+    model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
+):
+    model = torch.jit.freeze(model)
+
+    if isinstance(inputs, torch.Tensor):
+        inputs = [inputs]
+
+    serializer = serializer or _NnapiSerializer(
+        config=None, use_int16_for_qint16=use_int16_for_qint16
+    )
+    (
+        ser_model,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        shape_compute_lines,
+        retval_count,
+    ) = serializer.serialize_model(model, inputs, return_shapes)
+    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
+
+    # We have to create a new class here every time this function is called
+    # because module.define adds a method to the *class*, not the instance.
+    class ShapeComputeModule(torch.nn.Module):
+        """Code-gen-ed module for tensor shape computation.
+
+        module.prepare will mutate ser_model according to the computed operand
+        shapes, based on the shapes of args.  Returns a list of output templates.
+        """
+
+    shape_compute_module = torch.jit.script(ShapeComputeModule())
+    real_shape_compute_lines = [
+        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
+    ] + [f"    {line}\n" for line in shape_compute_lines]
+    shape_compute_module.define("".join(real_shape_compute_lines))
+
+    return (
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        retval_count,
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff09959f840c4b8c61147cc2180abc8d5d25b13
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py
@@ -0,0 +1,2231 @@
+# mypy: allow-untyped-defs
+import array
+import enum
+import functools
+import logging
+import operator
+import struct
+import sys
+from typing import NamedTuple, Optional
+
+import torch
+
+
+# TODO: Add type annotations
+# TODO: Check tensor types for ops
+
+
+LOG = logging.getLogger("nnapi_serialize")
+
+
+class NNAPI_OperandCode:
+    FLOAT32 = 0
+    INT32 = 1
+    UINT32 = 2
+    TENSOR_FLOAT32 = 3
+    TENSOR_INT32 = 4
+    TENSOR_QUANT8_ASYMM = 5
+    BOOL = 6
+    TENSOR_QUANT16_SYMM = 7
+    TENSOR_FLOAT16 = 8
+    TENSOR_BOOL8 = 9
+    FLOAT16 = 10
+    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
+    TENSOR_QUANT16_ASYMM = 12
+
+
+class NNAPI_OperationCode:
+    ADD = 0
+    AVERAGE_POOL_2D = 1
+    CONCATENATION = 2
+    CONV_2D = 3
+    DEPTHWISE_CONV_2D = 4
+    DEPTH_TO_SPACE = 5
+    DEQUANTIZE = 6
+    EMBEDDING_LOOKUP = 7
+    FLOOR = 8
+    FULLY_CONNECTED = 9
+    HASHTABLE_LOOKUP = 10
+    L2_NORMALIZATION = 11
+    L2_POOL_2D = 12
+    LOCAL_RESPONSE_NORMALIZATION = 13
+    LOGISTIC = 14
+    LSH_PROJECTION = 15
+    LSTM = 16
+    MAX_POOL_2D = 17
+    MUL = 18
+    RELU = 19
+    RELU1 = 20
+    RELU6 = 21
+    RESHAPE = 22
+    RESIZE_BILINEAR = 23
+    RNN = 24
+    SOFTMAX = 25
+    SPACE_TO_DEPTH = 26
+    SVDF = 27
+    TANH = 28
+    BATCH_TO_SPACE_ND = 29
+    DIV = 30
+    MEAN = 31
+    PAD = 32
+    SPACE_TO_BATCH_ND = 33
+    SQUEEZE = 34
+    STRIDED_SLICE = 35
+    SUB = 36
+    TRANSPOSE = 37
+    ABS = 38
+    ARGMAX = 39
+    ARGMIN = 40
+    AXIS_ALIGNED_BBOX_TRANSFORM = 41
+    BIDIRECTIONAL_SEQUENCE_LSTM = 42
+    BIDIRECTIONAL_SEQUENCE_RNN = 43
+    BOX_WITH_NMS_LIMIT = 44
+    CAST = 45
+    CHANNEL_SHUFFLE = 46
+    DETECTION_POSTPROCESSING = 47
+    EQUAL = 48
+    EXP = 49
+    EXPAND_DIMS = 50
+    GATHER = 51
+    GENERATE_PROPOSALS = 52
+    GREATER = 53
+    GREATER_EQUAL = 54
+    GROUPED_CONV_2D = 55
+    HEATMAP_MAX_KEYPOINT = 56
+    INSTANCE_NORMALIZATION = 57
+    LESS = 58
+    LESS_EQUAL = 59
+    LOG = 60
+    LOGICAL_AND = 61
+    LOGICAL_NOT = 62
+    LOGICAL_OR = 63
+    LOG_SOFTMAX = 64
+    MAXIMUM = 65
+    MINIMUM = 66
+    NEG = 67
+    NOT_EQUAL = 68
+    PAD_V2 = 69
+    POW = 70
+    PRELU = 71
+    QUANTIZE = 72
+    QUANTIZED_16BIT_LSTM = 73
+    RANDOM_MULTINOMIAL = 74
+    REDUCE_ALL = 75
+    REDUCE_ANY = 76
+    REDUCE_MAX = 77
+    REDUCE_MIN = 78
+    REDUCE_PROD = 79
+    REDUCE_SUM = 80
+    ROI_ALIGN = 81
+    ROI_POOLING = 82
+    RSQRT = 83
+    SELECT = 84
+    SIN = 85
+    SLICE = 86
+    SPLIT = 87
+    SQRT = 88
+    TILE = 89
+    TOPK_V2 = 90
+    TRANSPOSE_CONV_2D = 91
+    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
+    UNIDIRECTIONAL_SEQUENCE_RNN = 93
+    RESIZE_NEAREST_NEIGHBOR = 94
+
+
+class NNAPI_FuseCode:
+    FUSED_NONE = 0
+    FUSED_RELU = 1
+    FUSED_RELU1 = 2
+    FUSED_RELU6 = 3
+
+
+class OperandValueSourceType:
+    IMMEDIATE = 0
+    NUMBERED_BUFFER = 2
+    NUMBERED_MEMORY = 3
+
+
+# Scalar types that appear explicitly in models.
+# These must be kept in sync with
+# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
+# TODO: Expose these directly to Python to avoid maintaining this list.
+class TorchScalarTypes(enum.Enum):
+    QUINT8 = 13
+
+
+def approx_equal(lhs, rhs, tolerance=1e-6):
+    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
+
+
+def tensor_size(op_type, dims):
+    ITEM_SIZES = {
+        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
+        NNAPI_OperandCode.TENSOR_INT32: 4,
+        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
+        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
+        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
+    }
+    size = ITEM_SIZES[op_type]
+    for d in dims:
+        size *= d
+    return size
+
+
+def change_element(tup, index, value):
+    ls = list(tup)
+    ls[index] = value
+    return tuple(ls)
+
+
+class ConvPoolArgs2d(NamedTuple):
+    """Configuration arguments for a convolution."""
+
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_t: int
+    pad_b: int
+    pad_l: int
+    pad_r: int
+    dilation_h: int
+    dilation_w: int
+    group: int
+
+
+class DimOrder(enum.Enum):
+    PRESUMED_CONTIGUOUS = 0
+    CHANNELS_LAST = 1
+    SCALAR_OR_VECTOR = 2
+    UNKNOWN_CONSTANT = 999
+
+
+class Operand(NamedTuple):
+    """Representation of an NNAPI operand."""
+
+    # NNAPI operand type.  One of NNAPI_OperandCode.
+    # TODO: Make this an enum.
+    op_type: int
+
+    # This is always the PyTorch shape, which is NCHW for feature maps.
+    # The actual NNAPI operand might have a transposed shape.
+    # we use 0 for load time dynamic shapes & -1 for runtime dynamic shapes
+    shape: tuple[int, ...]
+
+    # Specifies how the shape of the operand that we define in NNAPI
+    # relates to the shape we track above.
+    # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
+    #   the shape of the PyTorch tensor.
+    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
+    #   the NNAPI operand will be represented explicitly as NHWC.
+    dim_order: DimOrder
+
+    # Quantization params
+    scale: float
+    zero_point: int
+
+    def use_nchw(self):
+        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
+            return True
+        if self.dim_order is DimOrder.CHANNELS_LAST:
+            return False
+        raise Exception("Unknown dim order")  # noqa: TRY002
+
+
+def broadcast_shapes(shape1, shape2):
+    assert len(shape1) > 0
+    assert len(shape2) > 0
+    s1 = list(shape1)
+    s2 = list(shape2)
+    # TODO: Support non-equal-rank broadcast where semantics match.
+    # This can be tricky for NHWC tensors because dimension orders
+    # don't match between PT and NNAPI, even though semantics match.
+    if len(s1) > len(s2):
+        # s2 = [1] * (len(s1) - len(s2)) + s2
+        raise Exception(  # noqa: TRY002
+            "Non-equal-rank broadcast is not supported yet."
+        )  # noqa: TRY002
+    if len(s2) > len(s1):
+        # s3 = [1] * (len(s2) - len(s1)) + s1
+        raise Exception(  # noqa: TRY002
+            "Non-equal-rank broadcast is not supported yet."
+        )  # noqa: TRY002
+    ret = []
+    for d1, d2 in zip(s1, s2):
+        if d1 == 1:
+            ret.append(d2)
+        elif d2 == 1:
+            ret.append(d1)
+        elif d1 == d2:
+            ret.append(d1)
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Cannot broadcast shapes: {shape1} and {shape2}"
+            )  # noqa: TRY002
+    return tuple(ret)
+
+
+def get_conv_pool_shape(image_shape, args, out_ch, transpose):
+    batch, _in_c, in_h, in_w = image_shape
+
+    # TODO: Handle dilation
+    if args.dilation_h != 1 or args.dilation_w != 1:
+        raise Exception("Dilation not supported yet.")  # noqa: TRY002
+
+    if transpose:
+        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
+        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_l
+    else:
+        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
+        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1
+
+    # Handle variable-sized tensors.
+    if in_h == 0:
+        out_h = 0
+    if in_w == 0:
+        out_w = 0
+
+    out_shape = (batch, out_ch, out_h, out_w)
+    return out_shape
+
+
+def fix_shape(shape, dim_order):
+    # Return the actual shape that an operand should have in NNAPI,
+    # given a PyTorch shape and dimension order.  This is where we
+    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
+    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
+        return shape
+    if dim_order is DimOrder.CHANNELS_LAST:
+        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
+    if dim_order is DimOrder.SCALAR_OR_VECTOR:
+        assert len(shape) == 0 or len(shape) == 1
+        return shape
+    if dim_order is DimOrder.UNKNOWN_CONSTANT:
+        # XXX think this through
+        return shape
+    raise Exception(f"Bad dim_order: {dim_order!r}.")  # noqa: TRY002
+
+
+def reverse_map_dim(dim_order, d):
+    # Return the original PyTorch dimension position for a given dimension.
+    # d should be the dimension that NNAPI will see.
+    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
+    # reverse_map_dim(CHANNELS_LAST, 3) == 1
+    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
+        return d
+    assert dim_order is DimOrder.CHANNELS_LAST
+    return [0, 2, 3, 1][d]
+
+
+def flex_name(op_id, dim):
+    # Return the local variable name for the computed flexible size
+    # for a given op and dimension.
+    return f"s_{op_id}_{dim}"
+
+
+class _NnapiSerializer:
+    def __init__(self, config, use_int16_for_qint16=False):
+        self.operands = []
+        self.values = []
+        self.operations = []
+        self.value_data = []
+        self.operation_args = []
+        self.inputs = []
+        self.outputs = []
+        self.flexible_shape_computation_lines = []
+
+        self.modules = {}
+        self.constants = {}
+        self.tensor_sequences = {}
+        self.jitval_operand_map = {}
+        self.cached_immediates = {}
+        self.used_weights = []
+        self.weight_offset = 0
+        self.use_int16_for_qint16 = use_int16_for_qint16
+
+        if config is None:
+            config = {}
+
+    def get_next_operand_id(self):
+        return len(self.operands)
+
+    # Add a tensor operand corresponding to a JIT Value.
+    # Returns the NNAPI operand ID.  Can be looked up later with
+    # get_tensor_operand_by_jitval.
+    def add_tensor_operand(self, jitval, oper):
+        assert isinstance(oper, Operand)
+        if jitval in self.jitval_operand_map:
+            raise Exception(f"Duplicate tensor: {jitval!r}")  # noqa: TRY002
+
+        operand_id = self.get_next_operand_id()
+        self.operands.append(oper)
+        self.jitval_operand_map[jitval] = operand_id
+        return operand_id
+
+    # Add a tensor operand that does not correspond to a JIT Value.
+    # Useful for cases where multiple NNAPI operands are required
+    # to implement one JIT IR node.  Returns the NNAPI operand ID.
+    def add_anonymous_tensor_operand(self, oper):
+        assert isinstance(oper, Operand)
+        operand_id = self.get_next_operand_id()
+        self.operands.append(oper)
+        return operand_id
+
+    def torch_tensor_to_operand(self, tensor, dim_order):
+        dtype = str(tensor.dtype).replace("torch.", "")
+        scale = 0.0
+        zero_point = 0
+        if dtype == "float32":
+            op_type = NNAPI_OperandCode.TENSOR_FLOAT32
+        elif dtype == "int32":
+            op_type = NNAPI_OperandCode.TENSOR_INT32
+        elif dtype == "quint8":
+            op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+            scale = tensor.q_scale()
+            zero_point = tensor.q_zero_point()
+        elif dtype == "qint32":
+            op_type = NNAPI_OperandCode.TENSOR_INT32
+            scale = tensor.q_scale()
+            zero_point = tensor.q_zero_point()
+            assert zero_point == 0
+        elif dtype == "int16":
+            if self.use_int16_for_qint16:
+                nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
+                op_codes = (
+                    NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
+                    NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
+                )
+                if nnapi_dtype in op_codes:
+                    op_type = nnapi_dtype
+                    scale = tensor.nnapi_scale
+                    zero_point = tensor.nnapi_zero_point
+                else:
+                    raise Exception(  # noqa: TRY002
+                        f"`nnapi_type` needs to be one of {op_codes} for `int16`"
+                    )
+            else:
+                raise Exception(  # noqa: TRY002
+                    "`int16` isn't supported. If you're trying to represent NNAPI"
+                    " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
+                )
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Can't handle input with dtype '{tensor.dtype}'"
+            )  # noqa: TRY002
+        return Operand(
+            shape=tuple(tensor.shape),
+            # pyrefly: ignore [bad-argument-type]
+            op_type=op_type,
+            dim_order=dim_order,
+            scale=scale,
+            zero_point=zero_point,
+        )
+
+    def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
+        dim_order = (
+            DimOrder.CHANNELS_LAST
+            if getattr(tensor, "nnapi_nhwc", False)
+            else DimOrder.PRESUMED_CONTIGUOUS
+        )
+        toper = self.torch_tensor_to_operand(tensor, dim_order)
+        operand_id = self.add_tensor_operand(jitval, toper)
+        self.inputs.append(operand_id)
+        for dim, size in enumerate(tensor.shape):
+            if size == 0:
+                self.compute_operand_shape(
+                    operand_id, dim, f"args[{arg_idx}].shape[{dim}]"
+                )
+        return operand_id
+
+    def add_tensor_operand_for_weight(
+        self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT
+    ):
+        toper = self.torch_tensor_to_operand(tensor, dim_order)
+        operand_id = len(self.operands)
+        self.operands.append(toper)
+        tsize = tensor_size(toper.op_type, toper.shape)
+        self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
+        buf_num = len(self.used_weights)
+        offset = 0
+        self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
+        # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor
+        if dim_order == DimOrder.CHANNELS_LAST:
+            tensor = tensor.permute(0, 2, 3, 1)
+        self.used_weights.append(tensor)
+        return operand_id
+
+    def add_immediate_operand(self, code, value, dims):
+        assert isinstance(dims, tuple)
+        cache_key = (code, value)
+        if cache_key not in self.cached_immediates:
+            operand_id = len(self.operands)
+            self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
+            self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
+            self.value_data.append(value)
+            self.cached_immediates[cache_key] = operand_id
+        return self.cached_immediates[cache_key]
+
+    def add_immediate_int_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.INT32, struct.pack("i", value), ()
+        )
+
+    def add_immediate_float_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.FLOAT32, struct.pack("f", value), ()
+        )
+
+    def add_immediate_bool_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", ()
+        )
+
+    def add_immediate_int_vector(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.TENSOR_INT32,
+            array.array("i", value).tobytes(),
+            (len(value),),
+        )
+
+    def has_operand_for_jitval(self, jitval):
+        return jitval in self.jitval_operand_map
+
+    def get_tensor_operand_by_jitval(self, jitval):
+        operand_id = self.jitval_operand_map[jitval]
+        return (operand_id, self.operands[operand_id])
+
+    def get_tensor_operand_by_jitval_fixed_size(self, jitval):
+        op_id, oper = self.get_tensor_operand_by_jitval(jitval)
+        for s in oper.shape:
+            if s == 0:
+                # TODO: Improve this error message, possibly after converting
+                # many callsites to support flexible size.
+                raise Exception(  # noqa: TRY002
+                    "Flexible size is not supported for this operand."
+                )  # noqa: TRY002
+            if s < 0:
+                # runtime flex
+                LOG.warning("Operand %s has runtime flex shape", oper)
+        return op_id, oper
+
+    def get_tensor_operand_or_constant(
+        self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+    ):
+        operand_id = self.jitval_operand_map.get(jitval)
+        if operand_id is None:
+            _, value = self.get_constant_value(jitval, "TensorType")
+            operand_id = self.add_tensor_operand_for_weight(value, dim_order)
+        return (operand_id, self.operands[operand_id])
+
+    def get_tensor_operand_for_weight(self, jitval):
+        _, value = self.get_constant_value(jitval, "TensorType")
+        operand_id = self.add_tensor_operand_for_weight(value)
+        return (operand_id, self.operands[operand_id])
+
+    def add_operation(self, opcode, inputs, outputs):
+        self.operations.append((opcode, len(inputs), len(outputs)))
+        self.operation_args.extend(inputs + outputs)
+
+    def add_tensor_sequence(self, jitval, values):
+        assert jitval not in self.tensor_sequences
+        self.tensor_sequences[jitval] = values
+
+    def add_constant_value(self, jitval, ctype, value):
+        assert jitval not in self.constants
+        self.constants[jitval] = (ctype, value)
+
+    def get_constant_value(self, jitval, typekind=None):
+        record = self.constants.get(jitval)
+        if record is None:
+            raise Exception(  # noqa: TRY002
+                f"Could not find constant value for '{jitval!r}'."
+            )  # noqa: TRY002
+        ctype, _ = record
+        if typekind is not None and ctype.kind() != typekind:
+            raise Exception(  # noqa: TRY002
+                f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'"
+            )
+        return record
+
+    def operand_to_template_torchscript(self, op_id, oper, shape=None):
+        """Return a TorchScript expression to build a template for a given operand."""
+        if shape is None:
+            shape = oper.shape
+        else:
+            assert len(shape) == len(oper.shape)
+
+        shape_parts = ["("]
+        for d, s in enumerate(shape):
+            if s > 0:
+                # Fixed shape dimension: just add the value.
+                shape_parts.append(str(s))
+            elif s == 0:
+                # Load time flexible shape dimension: it should have been computed in a variable.
+                shape_parts.append(flex_name(op_id, d))
+            elif s == -1:
+                # Runtime flexible shape
+                shape_parts.append("0")
+            else:
+                raise Exception(  # noqa: TRY002
+                    "Unknown dim value, dimensions should be >= -1"
+                )  # noqa: TRY002
+            shape_parts.append(",")
+        shape_parts.append(")")
+        shape_code = "".join(shape_parts)
+        if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
+            return f"torch.zeros({shape_code}, dtype=torch.float32)"
+        elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
+            return f"torch.zeros({shape_code}, dtype=torch.int32)"
+        elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+            return (
+                f"torch.quantize_per_tensor("
+                f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
+                f".expand({shape_code}).contiguous()"
+            )
+        elif oper.op_type in (
+            NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
+            NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
+        ):
+            if self.use_int16_for_qint16:
+                return f"torch.zeros({shape_code}, dtype=torch.int16)"
+            else:
+                raise Exception(  # noqa: TRY002
+                    "`int16` isn't supported. If you're trying to represent NNAPI"
+                    " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
+                )
+
+        raise Exception(  # noqa: TRY002
+            f"Unsupported output operand type: {oper.op_type}"
+        )  # noqa: TRY002
+
+    def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
+        self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))
+
+    def compute_operand_shape(self, op_id, dim, expr):
+        self.flexible_shape_computation_lines.append(
+            f"{flex_name(op_id, dim)} = {expr}"
+        )
+
+    def transpose_to_nhwc(self, in_id, oper):
+        if oper.shape[2:] != (1, 1):
+            raise Exception(  # noqa: TRY002
+                "Automatic transpose only supported for H,W == 1,1"
+            )  # noqa: TRY002
+
+        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
+
+        outputs = [None] * 1
+        outputs[0] = self.add_anonymous_tensor_operand(out_oper)
+
+        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
+
+        return outputs[0], out_oper
+
+    # Transpose inputs as necessary to allow broadcasting.
+    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
+        if in0_oper.dim_order == in1_oper.dim_order:
+            return in0_id, in0_oper, in1_id, in1_oper
+
+        # Assume NHWC is preferred if there is a mismatch.
+        orders = (in0_oper.dim_order, in1_oper.dim_order)
+        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
+            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
+        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
+            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+
+        raise Exception(  # noqa: TRY002
+            f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}"
+        )
+
+    def get_size_arg(self, jitval):
+        ctype, value = self.get_constant_value(jitval)
+        if ctype.kind() == "ListType":
+            assert ctype.getElementType().kind() == "IntType"
+            return value
+        raise Exception(  # noqa: TRY002
+            f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'"
+        )  # noqa: TRY002
+
+    def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
+        pc = [i.item() for i in packed_config]
+        assert pc[0] == 2
+        strides = [pc[1], pc[2]]
+        paddings = [pc[3], pc[4]]
+        dilations = [pc[5], pc[6]]
+        output_padding = [pc[7], pc[8]]
+        group_num = pc[9]
+
+        assert len(pc) == 11
+        assert output_padding == [0, 0]
+
+        return self.get_conv_pool_args_2d_common(
+            kernel_size, strides, paddings, dilations, group_num
+        )
+
+    def get_conv_pool_args_2d_from_jit(
+        self, kernel_size, stride, padding, dilation=None, group=None
+    ):
+        strides = self.get_size_arg(stride)
+        paddings = self.get_size_arg(padding)
+        if dilation is None:
+            dilations = [1, 1]
+        else:
+            dilations = self.get_size_arg(dilation)
+        if group is not None:
+            _, group_num = self.get_constant_value(group, "IntType")
+        else:
+            group_num = None
+        return self.get_conv_pool_args_2d_common(
+            kernel_size, strides, paddings, dilations, group_num
+        )
+
+    def get_conv_pool_args_2d_common(
+        self, kernel_size, strides, paddings, dilations, group_num
+    ):
+        kernels = list(kernel_size)
+
+        assert len(kernels) == 2
+        assert len(strides) == 2
+        assert len(paddings) == 2
+        assert len(dilations) == 2
+
+        # NNAPI uses 4 values for padding.
+        ph, pw = paddings
+        real_paddings = [ph, ph, pw, pw]
+
+        return ConvPoolArgs2d(
+            *(kernels + strides + real_paddings + dilations + [group_num])
+        )
+
+    def serialize_model(self, model, inputs, return_shapes=None):
+        self.add_immediate_bool_scalar(False)
+        self.add_immediate_bool_scalar(True)
+
+        inp_dim_orders = []
+        out_dim_orders = []
+
+        self_jitval = next(model.graph.inputs())
+        self.add_constant_value(self_jitval, self_jitval.type(), model)
+
+        for arg_idx, (input_value, input_tensor) in enumerate(
+            zip(list(model.graph.inputs())[1:], inputs)
+        ):
+            op_id = self.add_tensor_operand_for_input(
+                arg_idx, input_value, input_tensor
+            )
+            inp_dim_orders.append(self.operands[op_id].dim_order.value)
+
+        for idx, node in enumerate(model.graph.nodes()):
+            LOG.debug("Processing node #%d: %r", idx, node)
+            self.add_node(node)
+
+        retn = model.graph.return_node()
+        assert retn.inputsSize() == 1
+        assert retn.outputsSize() == 0
+        retn_input = retn.inputsAt(0)
+        template_return_lines = ["return ["]
+        if retn_input.type().kind() == "TensorType":
+            return_values = [retn_input]
+            retval_count = -1
+        elif retn_input.type().kind() == "TupleType":
+            return_values = self.tensor_sequences[retn_input]
+            retval_count = len(return_values)
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Unsupported return type: {retn_input.type()}"
+            )  # noqa: TRY002
+
+        if return_shapes is not None:
+            assert len(return_shapes) == len(return_values)
+        for i, v in enumerate(return_values):
+            op_id = self.jitval_operand_map[v]
+            self.outputs.append(op_id)
+            out_dim_orders.append(self.operands[op_id].dim_order.value)
+            shape = return_shapes[i] if return_shapes else None
+            template_return_lines.append(
+                self.operand_to_template_torchscript(op_id, self.operands[op_id], shape)
+                + ","
+            )
+        template_return_lines.append("]")
+
+        model = []
+
+        version = 1
+        header = struct.pack(
+            "iiiiii",
+            version,
+            len(self.operands),
+            len(self.values),
+            len(self.operations),
+            len(self.inputs),
+            len(self.outputs),
+        )
+        model.append(header)
+
+        serialized_values, serialized_value_data = self.serialize_values()
+
+        model.extend(
+            struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands
+        )
+        model.extend(serialized_values)
+        model.extend(struct.pack("iii", *x) for x in self.operations)
+
+        # Compact the model so we can get its length so far.
+        model = [b"".join(model)]
+        model_offset = len(model[0])
+        # Model offset is the index into the model (in 32-bit words, not bytes)
+        # of the next dimension we're about to serialize.  If it's 0,
+        # generate code to mutate it before passing to NNAPI.
+        assert model_offset % 4 == 0
+        model_offset = int(model_offset / 4)
+
+        for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands):
+            shape = fix_shape(dims, dim_order)
+            for d, s in enumerate(shape):
+                if s == 0:
+                    pt_d = reverse_map_dim(dim_order, d)
+                    self.flexible_shape_computation_lines.append(
+                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}"
+                    )
+                model_offset += 1
+
+            # convert runtime flex shape from -1 to 0
+            shape = tuple(d if d != -1 else 0 for d in shape)
+            model.append(self.serialize_ints(shape))
+
+        model.extend(serialized_value_data)
+        model.append(self.serialize_ints(self.operation_args))
+        model.append(self.serialize_ints(self.inputs))
+        model.append(self.serialize_ints(self.outputs))
+
+        self.flexible_shape_computation_lines.extend(template_return_lines)
+
+        return (
+            array.array("i", b"".join(model)),
+            self.used_weights,
+            inp_dim_orders,
+            out_dim_orders,
+            self.flexible_shape_computation_lines,
+            retval_count,
+        )
+
+    def serialize_values(self):
+        serialized_values = []
+        serialized_value_data = []
+        assert len(self.values) == len(self.value_data)
+        for (op_index, source_type), data in zip(self.values, self.value_data):
+            source_length = len(data)
+
+            # Pad with 0 bytes out to a multiple of 4 for alignment.
+            physical_length = ((source_length - 1) | 0x3) + 1
+            padded_data = data + (b"\0" * (physical_length - source_length))
+
+            serialized_values.append(
+                struct.pack("iii", op_index, source_type, source_length)
+            )
+            serialized_value_data.append(padded_data)
+
+        return serialized_values, serialized_value_data
+
+    @staticmethod
+    def serialize_ints(ints):
+        return array.array("i", ints).tobytes()
+
+    ADDER_MAP = {
+        "prim::GetAttr": lambda self, node: self.add_getattr(node),
+        "prim::Constant": lambda self, node: self.add_constant_node(node),
+        "prim::ListConstruct": lambda self, node: self.add_list_construct(node),
+        "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node),
+        "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node),
+        "aten::to": lambda self, node: self.add_to(node),
+        "aten::detach": lambda self, node: self._identity(node),
+        "aten::reshape": lambda self, node: self.add_reshape(node),
+        "aten::flatten": lambda self, node: self.add_flatten(node),
+        "aten::slice": lambda self, node: self.add_slice(node),
+        "aten::size": lambda self, node: self.add_size(node),
+        "aten::cat": lambda self, node: self.add_cat(node),
+        "aten::mean": lambda self, node: self.add_mean(node),
+        "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node),
+        "aten::dequantize": lambda self, node: self.add_dequantize(node),
+        "aten::add": lambda self, node: self.add_add_sub_op(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::sub": lambda self, node: self.add_add_sub_op(
+            node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
+            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
+            node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op(
+            node, NNAPI_OperationCode.RELU
+        ),
+        "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op(
+            node, NNAPI_OperationCode.LOGISTIC
+        ),
+        "aten::softmax": lambda self, node: self.add_softmax(node),
+        "aten::hardtanh": lambda self, node: self.add_hardtanh(node),
+        "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node),
+        "aten::max_pool2d": lambda self, node: self.add_pool2d_node(
+            node, NNAPI_OperationCode.MAX_POOL_2D
+        ),
+        "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d(
+            node
+        ),
+        "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d(
+            node
+        ),
+        "aten::prelu": lambda self, node: self.add_prelu_op(node),
+        "aten::addmm": lambda self, node: self.add_addmm(node),
+        "aten::linear": lambda self, node: self.add_linear(node),
+        "aten::_convolution": lambda self, node: self.add_conv_underscore(node),
+        "aten::conv2d": lambda self, node: self.add_conv2d(node),
+        "aten::log_softmax": lambda self, node: self.add_log_softmax(node),
+        "quantized::linear": lambda self, node: self.add_qlinear(node),
+        "quantized::conv2d": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "quantized::conv2d_relu": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_RELU
+        ),
+        "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_NONE, transpose=True
+        ),
+        "quantized::add": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "quantized::add_relu": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU
+        ),
+        "quantized::mul": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
+        ),
+    }
+
+    def add_node(self, node):
+        adder = self.ADDER_MAP.get(node.kind())
+        if not adder:
+            raise Exception(  # noqa: TRY002
+                f"Unsupported node kind ({node.kind()!r}) in node {node!r}"
+            )  # noqa: TRY002
+        adder(self, node)
+
+    def _identity(self, node):
+        in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        jitval = node.outputsAt(0)
+        self.jitval_operand_map[jitval] = in_id
+
+    def add_getattr(self, node):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+        obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
+        assert str(obj_ctype).startswith("__torch__.")
+        name = node.s("name")
+        value = getattr(obj, name)
+        output = node.outputsAt(0)
+        ctype = output.type()
+        self.add_constant_value(output, ctype, value)
+
+    def add_constant_node(self, node):
+        assert node.inputsSize() == 0
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        ctype = output.type()
+        value = output.toIValue()
+        self.add_constant_value(output, ctype, value)
+
+    def add_list_construct(self, node):
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        ctype = output.type()
+        const_vals: Optional[list] = []
+        tensors: Optional[list] = []
+        for inp in node.inputs():
+            if const_vals is not None and inp in self.constants:
+                _, val = self.get_constant_value(inp)
+                const_vals.append(val)
+            else:
+                const_vals = None
+            if tensors is not None and inp.type().kind() == "TensorType":
+                tensors.append(inp)
+            else:
+                tensors = None
+
+        if const_vals is not None:
+            # NOTE: Now that TorchScript supports list constants,
+            # this code path might not be used anymore.
+            self.add_constant_value(output, ctype, const_vals)
+        if tensors is not None:
+            self.add_tensor_sequence(output, tensors)
+        if const_vals is None and tensors is None:
+            raise Exception(  # noqa: TRY002
+                f"Unable to handle ListConstruct node.  Neither all constants nor all tensors. {node!r}"
+            )
+
+    def add_tuple_construct(self, node):
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        values = list(node.inputs())
+        self.add_tensor_sequence(output, values)
+
+    def add_unsqueeze(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+
+        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
+        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS
+
+        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
+        out_shape_list = list(in_oper.shape)
+        out_shape_list.insert(real_dim, 1)
+        out_shape = tuple(out_shape_list)
+        out_oper = in_oper._replace(shape=out_shape)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_scalar(dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)
+
+    def add_to(self, node):
+        # Handle to("cpu") / to("gpu") case
+        self._identity(node)
+
+    def add_reshape(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+
+        shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
+        assert shape_ctype.kind() == "ListType"
+        assert shape_ctype.getElementType().kind() == "IntType"
+        is_trivial_reshape = len(shape) == 2 and shape[1] == -1
+
+        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
+            raise Exception(  # noqa: TRY002
+                "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]."
+            )
+
+        # Bit of a hack here.  Use a real tensor to infer the output shape.
+        out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
+        out_oper = in_oper._replace(
+            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+        )
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(shape)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
+
+    def add_flatten(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        _start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+        _end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
+
+        # channels last with channels == 1 or (height & width both 1)
+        is_trivial_flatten = len(in_oper.shape) == 4 and (
+            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)
+        )
+        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
+            raise Exception(  # noqa: TRY002
+                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1"
+            )
+
+        if start_dim < 0:
+            start_dim += len(in_oper.shape)
+        if end_dim < 0:
+            end_dim += len(in_oper.shape)
+
+        out_shape = (
+            in_oper.shape[:start_dim]
+            + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
+            + in_oper.shape[end_dim + 1 :]
+        )
+
+        if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]):
+            raise Exception(  # noqa: TRY002
+                "Flattening flexible dims is not supported yet"
+            )  # noqa: TRY002
+        non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :]
+        if non_flattened_dims.count(0) > 1:
+            raise Exception("Only 1 dim can be flexible")  # noqa: TRY002
+
+        out_oper = in_oper._replace(
+            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+        )
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))
+
+        inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape)
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(inputs_1)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
+
+    def add_slice(self, node):
+        assert node.inputsSize() == 5
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        _, dim_value = self.get_constant_value(node.inputsAt(1))
+        _, start_value = self.get_constant_value(node.inputsAt(2))
+        _, stop_value = self.get_constant_value(node.inputsAt(3))
+        _, step_value = self.get_constant_value(node.inputsAt(4))
+
+        if start_value is None:
+            start_value = 0
+        if stop_value is None:
+            stop_value = sys.maxsize
+
+        if start_value < 0:
+            start_value += in_oper.shape[dim_value]
+        elif start_value == sys.maxsize:
+            start_value = 0
+
+        if start_value == 0 and stop_value == sys.maxsize:
+            self._identity(node)
+            return
+
+        if in_oper.shape[dim_value] == 0:
+            raise Exception("Unable to slice with flexible shape")  # noqa: TRY002
+
+        if stop_value < 0:
+            stop_value += in_oper.shape[dim_value]
+        elif stop_value == sys.maxsize:
+            stop_value = in_oper.shape[dim_value]
+
+        if start_value >= stop_value:
+            raise Exception(  # noqa: TRY002
+                "Slice start value should be less than stop value"
+            )  # noqa: TRY002
+
+        out_len = (stop_value - start_value) // step_value
+        out_shape = tuple(
+            out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)
+        )
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), in_oper._replace(shape=out_shape)
+        )
+
+        # flex inputs
+        end_mask = 0
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, idx)
+                end_mask |= 1 << idx
+
+        inputs = [None] * 7
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(
+            [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]
+        )
+        inputs[2] = self.add_immediate_int_vector(
+            [
+                stop_value if i == dim_value else dim
+                for i, dim in enumerate(in_oper.shape)
+            ]
+        )
+        inputs[3] = self.add_immediate_int_vector(
+            [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]
+        )
+        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
+        inputs[5] = self.add_immediate_int_scalar(end_mask)
+        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mas
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)
+
+    def add_size(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        _, value = self.constants[node.inputsAt(1)]
+        res = in_oper.shape[value]
+        output = node.outputsAt(0)
+        self.add_constant_value(output, output.type(), res)
+
+    def add_cat(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        tensors = self.tensor_sequences[node.inputsAt(0)]
+        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
+
+        assert len(tensors) > 0
+        in_ids = []
+        out_oper = None
+        out_dim_size = 0
+        for inp in tensors:
+            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
+            if out_oper is None:
+                out_shape = change_element(in_oper.shape, dim, -1)
+                out_oper = in_oper._replace(shape=out_shape)
+            assert in_oper.op_type == out_oper.op_type
+            assert in_oper.dim_order == out_oper.dim_order
+            assert change_element(in_oper.shape, dim, -1) == change_element(
+                out_oper.shape, dim, -1
+            )
+            # TODO: Possibly check scale and zero point.
+            in_ids.append(in_id)
+            # TODO: Possibly support variable-sized inputs.
+            out_dim_size += in_oper.shape[dim]
+
+        assert out_oper is not None
+        out_oper = out_oper._replace(
+            shape=change_element(out_oper.shape, dim, out_dim_size)
+        )
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST:  # type: ignore[possibly-undefined]
+            assert len(out_oper.shape) == 4
+            nnapi_dim = [0, 3, 1, 2][dim]
+        else:
+            nnapi_dim = dim
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+        for idx, d in enumerate(out_oper.shape):
+            if d == 0:
+                if idx == dim:
+                    shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
+                    self.compute_operand_shape(out_id, idx, shape)
+                else:
+                    self.forward_operand_shape(out_id, idx, in_ids[0], idx)
+
+        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)
+
+    def add_mean(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
+        assert dim_ctype.kind() == "ListType"
+        assert dim_ctype.getElementType().kind() == "IntType"
+        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
+        # Expect None for dtype
+        self.get_constant_value(node.inputsAt(3), "NoneType")
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
+            assert len(in_oper.shape) == 4
+            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
+        else:
+            nnapi_dim = dim
+
+        collapsed_dims = set()
+        for d in dim:
+            if d < 0:
+                d += len(in_oper.shape)
+            collapsed_dims.add(d)
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
+            assert collapsed_dims.issuperset({2, 3})
+            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
+        else:
+            out_dim_order = in_oper.dim_order
+
+        out_shape = []
+        for i, s in enumerate(in_oper.shape):
+            if i not in collapsed_dims:
+                out_shape.append(s)
+            elif keep_dim:
+                out_shape.append(1)
+
+        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)
+
+        inputs = [None] * 3
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
+        inputs[2] = self.add_immediate_int_scalar(keep_dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)
+
+    def add_quantize(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        if in_oper.dim_order != DimOrder.CHANNELS_LAST:
+            raise Exception(  # noqa: TRY002
+                "Most hardware backends prefer NHWC quantized tensors.  "
+                "Try setting `t.nnapi_nhwc = True` on your tensor inputs.  "
+            )
+        _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
+        _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
+        _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
+        if scalar_type != TorchScalarTypes.QUINT8.value:
+            raise Exception(  # noqa: TRY002
+                "PyTorch NNAPI export only supports quantized tensors "
+                "with the quint8 dtype."
+            )
+        op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+
+        out_oper = in_oper._replace(
+            op_type=op_type,
+            scale=scale,
+            zero_point=zero_point,
+        )
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)
+
+    def add_dequantize(self, node):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        out_oper = in_oper._replace(
+            op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
+            scale=0.0,
+            zero_point=0,
+        )
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)
+
+    def add_pointwise_simple_unary_op(self, node, opcode):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        out_oper = in_oper
+        if opcode == NNAPI_OperationCode.LOGISTIC:
+            # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
+            # must be 1.f / 256 and the zeroPoint must be 0.
+            # https://fburl.com/h52stoog
+            if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+                out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        for idx, dim in enumerate(in_oper.shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, idx)
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):  # noqa: D401
+        """Helper for pointwise binary broadcast ops with superfluous extra args."""
+        assert node.outputsSize() == 1
+
+        assert node.inputsAt(0).type().kind() == "TensorType"
+        assert node.inputsAt(1).type().kind() == "TensorType"
+
+        if self.has_operand_for_jitval(node.inputsAt(0)):
+            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+            in1_id, in1_oper = self.get_tensor_operand_or_constant(
+                node.inputsAt(1), in0_oper.dim_order
+            )
+        elif self.has_operand_for_jitval(node.inputsAt(1)):
+            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
+            in0_id, in0_oper = self.get_tensor_operand_or_constant(
+                node.inputsAt(0), in1_oper.dim_order
+            )
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Can't do a NNAPI binary op: {opcode} on two constants"
+            )  # noqa: TRY002
+
+        assert in0_oper.op_type == in1_oper.op_type
+        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
+            in0_id, in0_oper, in1_id, in1_oper
+        )
+        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
+        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
+        out_oper = in0_oper._replace(shape=out_shape)
+        if qparams is not None:
+            scale, zp = qparams
+            out_oper = out_oper._replace(scale=scale, zero_point=zp)
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
+            if d0 == 1 and d1 == 0:
+                self.forward_operand_shape(out_id, idx, in1_id, idx)
+            elif d0 == 0 and d1 == 1:
+                self.forward_operand_shape(out_id, idx, in0_id, idx)
+            elif d0 == 0 and d1 == 0:
+                self.flexible_shape_computation_lines.append(
+                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
+                )
+                self.forward_operand_shape(out_id, idx, in0_id, idx)
+
+        inputs = [None] * 3
+        inputs[0] = in0_id
+        inputs[1] = in1_id
+        inputs[2] = self.add_immediate_int_scalar(fuse_code)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 2
+        self._do_add_binary(node, opcode, fuse_code)
+
+    def add_add_sub_op(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 3
+
+        _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
+        if alpha != 1:
+            raise Exception(  # noqa: TRY002
+                "NNAPI does not support add/sub with alpha."
+            )  # noqa: TRY002
+
+        self._do_add_binary(node, opcode, fuse_code)
+
+    def add_qadd(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 4
+
+        _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
+        _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")
+
+        self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))
+
+    def add_softmax(self, node):
+        assert node.inputsSize() == 3
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
+        for dim, size in enumerate(in_oper.shape):
+            if size == 0:
+                self.forward_operand_shape(out_id, dim, in_id, dim)
+
+        inputs = [None] * 3
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_float_scalar(
+            1.0
+        )  # positive scaling factor of exponent, beta
+        inputs[2] = self.add_immediate_int_scalar(softmax_dim)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)
+
+    def add_hardtanh(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
+        _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")
+
+        op_map = {
+            (-1, 1): NNAPI_OperationCode.RELU1,
+            (0, 6): NNAPI_OperationCode.RELU6,  # noqa: E201
+        }
+
+        opcode = op_map.get((min_val, max_val))
+        if opcode is None:
+            raise Exception(  # noqa: TRY002
+                "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)."
+            )  # noqa: TRY002
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_prelu_op(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        assert node.inputsAt(0).type().kind() == "TensorType"
+        assert node.inputsAt(1).type().kind() == "TensorType"
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
+        assert len(w_oper.shape) == 1
+        assert w_oper.shape[0] > 0
+        if w_oper.shape[0] > 1:
+            if in_oper.use_nchw():
+                # TODO: Support this by adding trailing 1 dims.
+                raise Exception(  # noqa: TRY002
+                    "Per-channel PReLU only supports channels_last right now."
+                )
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
+        for dim, size in enumerate(in_oper.shape):
+            if size > 0:
+                pass
+            elif dim <= 1:
+                raise Exception(  # noqa: TRY002
+                    "PReLU requires fixed size for dim 0 and dim 1."
+                )  # noqa: TRY002
+            else:
+                self.forward_operand_shape(out_id, dim, in_id, dim)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = w_id
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)
+
+    def add_pool2d_node(self, node, opcode):
+        assert node.inputsSize() == 6
+        assert node.outputsSize() == 1
+        image, kernel, stride, padding, dilation, _ceil_mode = node.inputs()
+
+        stride = stride or kernel
+
+        # TODO: Validate ceil_mode semantics.
+
+        args = self.get_conv_pool_args_2d_from_jit(
+            self.get_size_arg(kernel), stride, padding, dilation
+        )
+        if args.dilation_h != 1 or args.dilation_w != 1:
+            raise Exception("NNAPI does not support dilated pooling.")  # noqa: TRY002
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
+        assert len(image_oper.shape) == 4
+
+        out_shape = get_conv_pool_shape(
+            image_oper.shape, args, image_oper.shape[1], False
+        )
+        use_nchw = image_oper.use_nchw()
+
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
+        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
+        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_avg_pool2d(self, node):
+        assert node.inputsSize() == 7
+        assert node.outputsSize() == 1
+        (
+            image,
+            kernel,
+            stride,
+            padding,
+            _ceil_mode,
+            count_include_pad,
+            divisor_override,
+        ) = node.inputs()
+
+        _, count_include_pad_value = self.get_constant_value(count_include_pad)
+        _, divisor_override_value = self.get_constant_value(divisor_override)
+        if not count_include_pad_value or divisor_override_value:
+            raise Exception(  # noqa: TRY002
+                "NNAPI doesn't support count_include_pad=False or divisor_override"
+            )
+
+        args = self.get_conv_pool_args_2d_from_jit(
+            self.get_size_arg(kernel), stride, padding
+        )
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
+        assert len(image_oper.shape) == 4
+
+        out_shape = get_conv_pool_shape(
+            image_oper.shape, args, image_oper.shape[1], False
+        )
+        use_nchw = image_oper.use_nchw()
+
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
+        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
+        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+        self._handle_conv_pool_flexible_input(out_id, image, args, False)
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
+
+    def add_adaptive_avg_pool2d(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(
+            node.inputsAt(0)
+        )
+        assert len(image_oper.shape) == 4
+
+        size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
+        assert size_ctype.kind() == "ListType"
+        assert size_ctype.getElementType().kind() == "IntType"
+        if size_arg != [1, 1]:
+            raise Exception(  # noqa: TRY002
+                "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)."
+            )
+
+        out_shape = image_oper.shape[0:2] + tuple(size_arg)
+        use_nchw = image_oper.use_nchw()
+
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(0)
+        inputs[2] = self.add_immediate_int_scalar(0)
+        inputs[3] = self.add_immediate_int_scalar(0)
+        inputs[4] = self.add_immediate_int_scalar(0)
+        inputs[5] = self.add_immediate_int_scalar(1)
+        inputs[6] = self.add_immediate_int_scalar(1)
+        inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
+        inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
+
+    def add_upsample_nearest2d(self, node):
+        assert node.inputsSize() == 3 or node.inputsSize() == 4
+        assert node.outputsSize() == 1
+        if node.inputsSize() == 3:
+            image, size_jit, scale_jit = node.inputs()
+        else:
+            image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
+        size_ctype, size_arg = self.get_constant_value(size_jit)
+
+        if node.inputsSize() == 3:
+            scale_ctype, scale_arg = self.get_constant_value(scale_jit)  # type: ignore[possibly-undefined]
+        else:
+            scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)  # type: ignore[possibly-undefined]
+            scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit)  # type: ignore[possibly-undefined]
+
+            # The only way for the 4-argument overload of upsample_nearest2d to
+            # have been added to the graph without error is if the scale_h and
+            # scale_w arguments are None
+            assert scale_h_ctype.kind() == "NoneType"
+            assert scale_w_ctype.kind() == "NoneType"
+
+            scale_ctype = scale_h_ctype
+            scale_arg = scale_h_arg
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
+        assert len(image_oper.shape) == 4
+
+        if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType":
+            raise Exception("Size and scale cannot both be non-None.")  # noqa: TRY002
+        elif size_ctype.kind() != "NoneType":
+            assert size_ctype.kind() == "ListType"
+            assert size_ctype.getElementType().kind() == "IntType"
+            assert scale_ctype.kind() == "NoneType"
+            assert scale_arg is None
+            assert isinstance(size_arg, list)
+            assert size_arg
+            assert all(isinstance(val, int) for val in size_arg)
+            if len(size_arg) == 1:
+                size_arg = size_arg * 2
+            assert len(size_arg) == 2
+            out_h = size_arg[0]
+            out_w = size_arg[1]
+            arg_h = self.add_immediate_int_scalar(out_h)
+            arg_w = self.add_immediate_int_scalar(out_w)
+        elif scale_ctype.kind() != "NoneType":
+            assert scale_ctype.kind() == "ListType"
+            assert scale_ctype.getElementType().kind() == "FloatType"
+            assert size_ctype.kind() == "NoneType"
+            assert size_arg is None
+            assert isinstance(scale_arg, list)
+            assert scale_arg
+            assert all(isinstance(val, float) for val in scale_arg)
+            if len(scale_arg) == 1:
+                scale_arg = scale_arg * 2
+            assert len(scale_arg) == 2
+            out_h = int(scale_arg[0] * image_oper.shape[2])
+            out_w = int(scale_arg[1] * image_oper.shape[3])
+            arg_h = self.add_immediate_float_scalar(scale_arg[0])
+            arg_w = self.add_immediate_float_scalar(scale_arg[1])
+        else:
+            raise Exception("Size and scale cannot both be None.")  # noqa: TRY002
+
+        out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
+        use_nchw = image_oper.use_nchw()
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
+            raise Exception("Flexible batch or channels not supported")  # noqa: TRY002
+
+        # Handle variable input size
+        for dim in (2, 3):  # h, w indices
+            if image_oper.shape[dim] == 0:
+                if size_ctype.kind() != "NoneType":
+                    # pyrefly: ignore [unsupported-operation]
+                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
+                elif scale_ctype.kind() != "NoneType":
+                    self.compute_operand_shape(
+                        out_id,
+                        dim,
+                        # pyrefly: ignore [unsupported-operation]
+                        f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
+                    )
+                else:
+                    raise Exception(  # noqa: TRY002
+                        "Size and scale cannot both be None."
+                    )  # noqa: TRY002
+
+        inputs = [None] * 4
+        inputs[0] = image_id
+        inputs[1] = arg_w
+        inputs[2] = arg_h
+        inputs[3] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)
+
+    def add_addmm(self, node):
+        assert node.inputsSize() == 5
+        assert node.outputsSize() == 1
+        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()
+
+        for jitval in (jit_beta, jit_alpha):
+            scale_ctype, scale_value = self.get_constant_value(jitval)
+            assert scale_ctype.kind() in ("IntType", "FloatType")
+            if scale_value != 1:
+                raise Exception(  # noqa: TRY002
+                    "NNAPI Fully-Connected does not support alpha and beta."
+                )
+
+        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)
+
+    def add_linear(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+        jit_input, jit_weight, jit_bias = node.inputs()
+
+        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)
+
+    def add_addmm_or_linear(
+        self, node, transpose_weight, jit_input, jit_weight, jit_bias
+    ):
+        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
+        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)
+
+        assert len(input_oper.shape) == 2
+        assert len(bias_oper.shape) == 1
+
+        # TODO: Transform at load time to share weights with CPU model.
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        assert len(weight_tensor.shape) == 2
+        if transpose_weight:
+            nnapi_weight_tensor = weight_tensor.t().contiguous()
+        else:
+            nnapi_weight_tensor = weight_tensor.contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        out_shape = (input_oper.shape[0], weight_oper.shape[0])
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), input_oper._replace(shape=out_shape)
+        )
+
+        if input_oper.shape[0] == 0:
+            self.forward_operand_shape(out_id, 0, input_id, 0)
+
+        inputs = [None] * 4
+        inputs[0] = input_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
+
+    def add_qlinear(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+        (
+            jit_input,
+            jit_packed_weight,
+            jit_scale,
+            jit_zero_point,
+        ) = node.inputs()
+
+        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
+        # TODO: Support automatic reshape
+        assert len(input_oper.shape) == 2
+
+        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
+        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
+        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
+        assert weight_ctype.name() == "LinearPackedParamsBase"
+        raw_weight, raw_bias = packed_weight.__getstate__()[0]
+        assert raw_bias is not None
+
+        assert len(raw_weight.shape) == 2
+        assert len(raw_bias.shape) == 1
+        assert raw_bias.shape[0] == raw_weight.shape[0]
+        assert raw_weight.shape[1] == input_oper.shape[1]
+
+        assert raw_weight.qscheme() == torch.per_tensor_affine
+        if raw_weight.dtype == torch.quint8:
+            unsigned_weight = raw_weight
+        else:
+            assert raw_weight.dtype == torch.qint8
+            unsigned_weight = torch._make_per_tensor_quantized_tensor(
+                (raw_weight.int_repr().int() + 128).to(torch.uint8),
+                scale=raw_weight.q_scale(),
+                zero_point=raw_weight.q_zero_point() + 128,
+            )
+        weight_scale = unsigned_weight.q_scale()
+        bias_scale = input_oper.scale * weight_scale
+        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
+        bias_id = self.add_tensor_operand_for_weight(int_bias)
+
+        multiplier = input_oper.scale * weight_scale / out_scale
+        assert multiplier > 0
+        if multiplier >= 1:
+            raise Exception(  # noqa: TRY002
+                "Quantized convolution multiplier is greater than 1.  "
+                "This is supported by NNAPI, but not by most hardware backends.  "
+                "Try training a model without quantization-aware training.  "
+            )
+
+        # TODO: Transform at load time to share weights with CPU model.
+        nnapi_weight_tensor = unsigned_weight.contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        out_shape = (input_oper.shape[0], weight_oper.shape[0])
+        out_oper = input_oper._replace(
+            shape=out_shape,
+            scale=out_scale,
+            zero_point=out_zero_point,
+        )
+
+        inputs = [None] * 4
+        inputs[0] = input_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
+
+    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
+        ctype, _value = self.get_constant_value(jit_bias)
+        if ctype.kind() == "NoneType":
+            bias_idx = 1 if transpose else 0
+            nnapi_bias_tensor = torch.zeros(
+                weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype
+            )
+            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
+            bias_oper = self.operands[bias_id]
+            return bias_id, bias_oper
+        else:
+            return self.get_tensor_operand_for_weight(jit_bias)
+
+    def add_conv2d(self, node):
+        assert node.inputsSize() == 7
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_weight,
+            jit_bias,
+            jit_stride,
+            jit_pad,
+            jit_dilation,
+            jit_groups,
+        ) = node.inputs()
+
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
+        args = self.get_conv_pool_args_2d_from_jit(
+            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
+        )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            0.0,
+            0,
+            jit_image,
+            weight_tensor,
+            bias_id,
+            args,
+            False,  # transpose
+            NNAPI_FuseCode.FUSED_NONE,
+        )
+
+    def add_conv_underscore(self, node):
+        assert node.inputsSize() == 13
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_weight,
+            jit_bias,
+            jit_stride,
+            jit_pad,
+            jit_dilation,
+            jit_transpose,
+            _,
+            jit_groups,
+            _,
+            _,
+            _,
+            _,
+        ) = node.inputs()
+
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        _, transpose = self.get_constant_value(jit_transpose)
+        bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
+        args = self.get_conv_pool_args_2d_from_jit(
+            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
+        )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            0.0,
+            0,
+            jit_image,
+            weight_tensor,
+            bias_id,
+            args,
+            transpose,
+            NNAPI_FuseCode.FUSED_NONE,
+        )
+
+    def add_log_softmax(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        jit_input, jit_dim, _jit_half_to_float = node.inputs()
+        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
+        _, dim = self.get_constant_value(jit_dim, "IntType")
+
+        out_shape = input_oper.shape
+
+        inputs = [None] * 3
+        inputs[0] = input_id
+        # specifying 1 as the scaling factor for the exponent, beta
+        inputs[1] = self.add_immediate_float_scalar(1)
+        inputs[2] = self.add_immediate_int_scalar(dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), input_oper._replace(shape=out_shape)
+        )
+        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)
+
+    def add_qconv2d(self, node, fuse_code, transpose=False):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_packed_weight,
+            jit_scale,
+            jit_zero_point,
+        ) = node.inputs()
+
+        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
+        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
+        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
+        assert weight_ctype.name() == "Conv2dPackedParamsBase"
+        (
+            pack_version,
+            tensors,
+            opt_tensors,
+        ) = packed_weight.__getstate__()[0]
+        assert pack_version == "2"
+        packed_config, raw_weight = tensors
+        (raw_bias,) = opt_tensors
+        assert raw_bias is not None
+        args = self.get_conv_pool_args_2d_from_pack(
+            raw_weight.shape[2:4], packed_config
+        )
+
+        assert raw_weight.qscheme() == torch.per_tensor_affine
+        if raw_weight.dtype == torch.quint8:
+            unsigned_weight = raw_weight
+        else:
+            assert raw_weight.dtype == torch.qint8
+            unsigned_weight = torch._make_per_tensor_quantized_tensor(
+                (raw_weight.int_repr().int() + 128).to(torch.uint8),
+                scale=raw_weight.q_scale(),
+                zero_point=raw_weight.q_zero_point() + 128,
+            )
+        weight_scale = unsigned_weight.q_scale()
+        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        bias_scale = image_oper.scale * weight_scale
+        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
+        bias_id = self.add_tensor_operand_for_weight(int_bias)
+
+        multiplier = image_oper.scale * weight_scale / out_scale
+        assert multiplier > 0
+        if multiplier >= 1:
+            raise Exception(  # noqa: TRY002
+                "Quantized convolution multiplier is greater than 1.  "
+                "This is supported by NNAPI, but not by most hardware backends.  "
+                "Try training a model without quantization-aware training.  "
+            )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            out_scale,
+            out_zero_point,
+            jit_image,
+            unsigned_weight,
+            bias_id,
+            args,
+            transpose,
+            fuse_code,
+        )
+
+    def add_conv2d_common(
+        self,
+        jit_out,
+        out_scale,
+        out_zero_point,
+        jit_image,
+        weight_tensor,
+        bias_id,
+        args,
+        transpose,
+        fuse_code,
+    ):
+        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        in_c = image_oper.shape[1]
+
+        if args.group == 1:
+            # Full convolution
+            depthwise = False
+            if transpose:
+                weight_permutation = (1, 2, 3, 0)
+            else:
+                weight_permutation = (0, 2, 3, 1)
+        elif args.group == in_c:
+            # Depthwise convolution
+            depthwise = True
+            weight_permutation = (1, 2, 3, 0)
+        else:
+            raise Exception("Group convolution not supported yet.")  # noqa: TRY002
+
+        # TODO: Transform at load time to share weights with CPU model.
+        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        bias_oper = self.operands[bias_id]
+
+        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
+            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
+            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
+        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
+            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
+            assert bias_oper.zero_point == 0
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Unsupported input type for conv2d: {image_oper.op_type}"
+            )  # noqa: TRY002
+
+        assert len(image_oper.shape) == 4
+        assert len(weight_oper.shape) == 4
+        assert len(bias_oper.shape) == 1
+
+        if depthwise:
+            # Depthwise convolution
+            one, _kern_h, _kern_w, out_c = weight_oper.shape
+            assert one == 1
+            assert out_c % in_c == 0
+            channel_multiplier = out_c // in_c
+            assert channel_multiplier == 1  # Don't support multiplier
+            assert out_c == in_c
+        else:
+            # Full convolution
+            out_c, _kern_h, _kern_w, kern_d = weight_oper.shape
+            assert kern_d == in_c
+
+        assert out_c == bias_oper.shape[0]
+
+        use_nchw = image_oper.use_nchw()
+
+        if depthwise:
+            num_args = 12
+            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
+        else:
+            num_args = 11
+            if transpose:
+                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
+            else:
+                opcode = NNAPI_OperationCode.CONV_2D
+
+        inputs = [None] * num_args
+        inputs[0] = image_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
+        if depthwise:
+            inputs[9] = self.add_immediate_int_scalar(1)
+            inputs[10] = self.add_immediate_int_scalar(fuse_code)
+            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
+        else:
+            inputs[9] = self.add_immediate_int_scalar(fuse_code)
+            inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
+        out_oper = image_oper._replace(
+            shape=out_shape,
+            scale=out_scale,
+            zero_point=out_zero_point,
+        )
+        out_id = self.add_tensor_operand(jit_out, out_oper)
+        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)
+
+        outputs[0] = out_id
+        self.add_operation(opcode, inputs, outputs)
+
+    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
+        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        batch, in_ch, in_h, in_w = image_oper.shape
+
+        if batch == 0:
+            self.forward_operand_shape(out_id, 0, image_id, 0)
+        if in_ch == 0:
+            raise Exception("Input channels can't be flexible")  # noqa: TRY002
+        # H & W
+        if transpose:
+            if in_h == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    2,
+                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
+                )
+            if in_w == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    3,
+                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
+                )
+        else:
+            if in_h == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    2,
+                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
+                )
+            if in_w == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    3,
+                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
+                )
+
+
+def serialize_model(
+    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
+):
+    """Convert to NNAPI and serialize torchscript module.
+
+    Parameters:
+        module: Torchscript module to convert
+        inputs: Tensors used to specify input details for NNAPI
+        config (optional): Optional config to attach to module
+        return_shapes (optional): Specify shape of outputs if
+            your module uses runtime flexible shapes to set output
+            buffer size for NNAPI
+        use_int16_for_qint16 (optional): Use Pytorch int16 to represent NNAPI qint16 values
+    """
+    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
+        module, inputs, return_shapes
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cpu/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82dc52cd4904c1cda023c876c586550a5a33ff7a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cpu/__init__.py
@@ -0,0 +1,21 @@
+import torch
+
+
+__all__ = [
+    "get_cpu_capability",
+]
+
+
+def get_cpu_capability() -> str:
+    r"""Return cpu capability as a string value.
+
+    Possible values:
+    - "DEFAULT"
+    - "VSX"
+    - "Z VECTOR"
+    - "NO AVX"
+    - "AVX2"
+    - "AVX512"
+    - "SVE256"
+    """
+    return torch._C._get_cpu_capability()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b68f43036b01804563bfc44cc0917a07ef79f4d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cuda/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cuda/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62c2b05a1ea1f3ecc5ceb0fbc17f5a714d87941
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cuda/__init__.py
@@ -0,0 +1,593 @@
+# mypy: allow-untyped-defs
+import contextlib
+from typing import Any, Union
+from typing_extensions import deprecated
+
+import torch
+
+
+__all__ = [
+    "is_built",
+    "cuFFTPlanCacheAttrContextProp",
+    "cuFFTPlanCache",
+    "cuFFTPlanCacheManager",
+    "cuBLASModule",
+    "preferred_linalg_library",
+    "preferred_blas_library",
+    "preferred_rocm_fa_library",
+    "cufft_plan_cache",
+    "matmul",
+    "SDPAParams",
+    "enable_cudnn_sdp",
+    "cudnn_sdp_enabled",
+    "enable_flash_sdp",
+    "flash_sdp_enabled",
+    "enable_mem_efficient_sdp",
+    "mem_efficient_sdp_enabled",
+    "math_sdp_enabled",
+    "enable_math_sdp",
+    "allow_fp16_bf16_reduction_math_sdp",
+    "fp16_bf16_reduction_math_sdp_allowed",
+    "is_flash_attention_available",
+    "can_use_flash_attention",
+    "can_use_efficient_attention",
+    "can_use_cudnn_attention",
+    "sdp_kernel",
+]
+
+
+def is_built():
+    r"""
+    Return whether PyTorch is built with CUDA support.
+
+    Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch
+    binary were run on a machine with working CUDA drivers and devices, we would be able to use it.
+    """
+    return torch._C._has_cuda
+
+
+class cuFFTPlanCacheAttrContextProp:
+    # Like regular ContextProp, but uses the `.device_index` attribute from the
+    # calling object as the first argument to the getter and setter.
+    def __init__(self, getter, setter):
+        self.getter = getter
+        self.setter = setter
+
+    def __get__(self, obj, objtype):
+        return self.getter(obj.device_index)
+
+    def __set__(self, obj, val):
+        if isinstance(self.setter, str):
+            raise RuntimeError(self.setter)
+        self.setter(obj.device_index, val)
+
+
+class cuFFTPlanCache:
+    r"""
+    Represent a specific plan cache for a specific `device_index`.
+
+    The attributes `size` and `max_size`, and method `clear`, can fetch and/ or
+    change properties of the C++ cuFFT plan cache.
+    """
+
+    def __init__(self, device_index):
+        self.device_index = device_index
+
+    size = cuFFTPlanCacheAttrContextProp(
+        torch._cufft_get_plan_cache_size,
+        ".size is a read-only property showing the number of plans currently in the "
+        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
+    )
+
+    max_size = cuFFTPlanCacheAttrContextProp(
+        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
+    )
+
+    def clear(self):
+        return torch._cufft_clear_plan_cache(self.device_index)
+
+
+class cuFFTPlanCacheManager:
+    r"""
+    Represent all cuFFT plan caches, return the cuFFTPlanCache for a given device when indexed.
+
+    Finally, this object, when used directly as a `cuFFTPlanCache` object (e.g.,
+    setting the `.max_size`) attribute, the current device's cuFFT plan cache is
+    used.
+    """
+
+    __initialized = False
+
+    def __init__(self):
+        self.caches = []
+        self.__initialized = True
+
+    def __getitem__(self, device):
+        index = torch.cuda._utils._get_device_index(device)
+        if index < 0 or index >= torch.cuda.device_count():
+            raise RuntimeError(
+                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
+                f"device with index {index}"
+            )
+        if len(self.caches) == 0:
+            self.caches.extend(
+                cuFFTPlanCache(index) for index in range(torch.cuda.device_count())
+            )
+        return self.caches[index]
+
+    def __getattr__(self, name):
+        return getattr(self[torch.cuda.current_device()], name)
+
+    def __setattr__(self, name, value):
+        if self.__initialized:
+            return setattr(self[torch.cuda.current_device()], name, value)
+        else:
+            return super().__setattr__(name, value)
+
+
+class cuBLASModule:
+    @staticmethod
+    def _parse_reduction_setting(value: Any, attr_name: str) -> tuple[bool, bool]:
+        def _ensure_bool(obj: Any, which: str) -> bool:
+            if isinstance(obj, bool):
+                return obj
+            raise TypeError(
+                f"{attr_name} expects a bool for {which}, but got {type(obj)!r}"
+            )
+
+        if isinstance(value, bool):
+            return value, True
+        if isinstance(value, (list, tuple)):
+            if not value:
+                raise TypeError(f"{attr_name} expects at least one boolean argument")
+            if len(value) > 2:
+                raise TypeError(f"{attr_name} expects at most two boolean arguments")
+            allow_reduced_precision = _ensure_bool(value[0], "allow_reduced_precision")
+            if len(value) == 1:
+                return allow_reduced_precision, True
+            allow_splitk = _ensure_bool(value[1], "allow_splitk")
+            return allow_reduced_precision, allow_splitk
+        raise TypeError(
+            f"{attr_name} expects a bool or a tuple/list of bools, but got {type(value)!r}"
+        )
+
+    def __getattr__(self, name):
+        if name == "allow_tf32":
+            return torch._C._get_cublas_allow_tf32()
+        elif name == "allow_fp16_reduced_precision_reduction":
+            allow_reduced_precision, _ = (
+                torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+            )
+            return allow_reduced_precision
+        elif name == "allow_fp16_reduced_precision_reduction_split_k":
+            _, allow_splitk = (
+                torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+            )
+            return allow_splitk
+        elif name == "allow_bf16_reduced_precision_reduction":
+            allow_reduced_precision, _ = (
+                torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+            )
+            return allow_reduced_precision
+        elif name == "allow_bf16_reduced_precision_reduction_split_k":
+            _, allow_splitk = (
+                torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+            )
+            return allow_splitk
+        elif name == "allow_fp16_accumulation":
+            return torch._C._get_cublas_allow_fp16_accumulation()
+        elif name == "fp32_precision":
+            return torch._C._get_fp32_precision_getter("cuda", "matmul")
+        raise AttributeError("Unknown attribute " + name)
+
+    def __setattr__(self, name, value):
+        if name == "allow_tf32":
+            return torch._C._set_cublas_allow_tf32(value)
+        elif name == "allow_fp16_reduced_precision_reduction":
+            allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
+                value, "allow_fp16_reduced_precision_reduction"
+            )
+            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
+                allow_reduced_precision,
+                allow_splitk,
+            )
+        elif name == "allow_bf16_reduced_precision_reduction":
+            allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
+                value, "allow_bf16_reduced_precision_reduction"
+            )
+            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
+                allow_reduced_precision,
+                allow_splitk,
+            )
+        elif name == "allow_fp16_accumulation":
+            return torch._C._set_cublas_allow_fp16_accumulation(value)
+        elif name == "fp32_precision":
+            return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
+        raise AttributeError("Unknown attribute " + name)
+
+
+_LinalgBackends = {
+    "default": torch._C._LinalgBackend.Default,
+    "cusolver": torch._C._LinalgBackend.Cusolver,
+    "magma": torch._C._LinalgBackend.Magma,
+}
+_LinalgBackends_str = ", ".join(_LinalgBackends.keys())
+
+
+def preferred_linalg_library(
+    backend: Union[None, str, torch._C._LinalgBackend] = None,
+) -> torch._C._LinalgBackend:
+    r"""
+    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
+    and if both are available it decides which to use with a heuristic.
+    This flag (a :class:`str`) allows overriding those heuristics.
+
+    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
+    * If `"magma"` is set then MAGMA will be used wherever possible.
+    * If `"default"` (the default) is set then heuristics will be used to pick between
+      cuSOLVER and MAGMA if both are available.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER
+      globally.
+      This flag only sets the initial value of the preferred library and the preferred library
+      may still be overridden by this function call later in your script.
+
+    Note: When a library is preferred other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
+    for your application's inputs.
+
+    Currently supported linalg operators:
+
+    * :func:`torch.linalg.inv`
+    * :func:`torch.linalg.inv_ex`
+    * :func:`torch.linalg.cholesky`
+    * :func:`torch.linalg.cholesky_ex`
+    * :func:`torch.cholesky_solve`
+    * :func:`torch.cholesky_inverse`
+    * :func:`torch.linalg.lu_factor`
+    * :func:`torch.linalg.lu`
+    * :func:`torch.linalg.lu_solve`
+    * :func:`torch.linalg.qr`
+    * :func:`torch.linalg.eigh`
+    * :func:`torch.linalg.eighvals`
+    * :func:`torch.linalg.svd`
+    * :func:`torch.linalg.svdvals`
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _LinalgBackends:
+            raise RuntimeError(
+                f"Unknown input value. Choose from: {_LinalgBackends_str}."
+            )
+        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
+    elif isinstance(backend, torch._C._LinalgBackend):
+        torch._C._set_linalg_preferred_backend(backend)
+    else:
+        raise RuntimeError("Unknown input value type.")
+
+    return torch._C._get_linalg_preferred_backend()
+
+
+_BlasBackends = {
+    "default": torch._C._BlasBackend.Default,
+    "cublas": torch._C._BlasBackend.Cublas,
+    "hipblas": torch._C._BlasBackend.Cublas,  # alias
+    "cublaslt": torch._C._BlasBackend.Cublaslt,
+    "hipblaslt": torch._C._BlasBackend.Cublaslt,  # alias
+    "ck": torch._C._BlasBackend.Ck,
+}
+_BlasBackends_str = ", ".join(_BlasBackends.keys())
+
+
+def preferred_blas_library(
+    backend: Union[None, str, torch._C._BlasBackend] = None,
+) -> torch._C._BlasBackend:
+    r"""
+    Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only].
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
+    For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance.
+    This flag (a :class:`str`) allows overriding which BLAS library to use.
+
+    * If `"cublas"` is set then cuBLAS will be used wherever possible.
+    * If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
+    * If `"ck"` is set then CK will be used wherever possible.
+    * If `"default"` (the default) is set then heuristics will be used to pick between the other options.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt
+      globally.
+      This flag only sets the initial value of the preferred library and the preferred library
+      may still be overridden by this function call later in your script.
+
+    Note: When a library is preferred other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's library selection is incorrect
+    for your application's inputs.
+
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _BlasBackends:
+            raise RuntimeError(
+                f"Unknown input value. Choose from: {_BlasBackends_str}."
+            )
+        torch._C._set_blas_preferred_backend(_BlasBackends[backend])
+    elif isinstance(backend, torch._C._BlasBackend):
+        torch._C._set_blas_preferred_backend(backend)
+    else:
+        raise RuntimeError("Unknown input value type.")
+
+    return torch._C._get_blas_preferred_backend()
+
+
+_ROCmFABackends = {
+    "default": torch._C._ROCmFABackend.Default,
+    "aotriton": torch._C._ROCmFABackend.AOTriton,
+    "ck": torch._C._ROCmFABackend.Ck,
+}
+_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys())
+
+
+from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
+
+
+def preferred_rocm_fa_library(
+    backend: Union[None, str, torch._C._ROCmFABackend] = None,
+) -> torch._C._ROCmFABackend:
+    r"""
+    [ROCm-only]
+    Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend.
+    This flag (a :class:`str`) allows users to override this backend to use composable_kernel
+
+    * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton.
+    * If `"aotriton"` is set then AOTriton will be used wherever possible.
+    * If `"ck"` is set then CK will be used wherever possible.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK
+      globally.
+
+    Note: When a library is preferred other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's library selection is incorrect
+    for your application's inputs.
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _ROCmFABackends:
+            raise RuntimeError(
+                f"Unknown input value. Choose from: {_ROCmFABackends_str}."
+            )
+        torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend])
+    elif isinstance(backend, torch._C._ROCmFABackend):
+        torch._C._set_rocm_fa_preferred_backend(backend)
+    else:
+        raise ValueError(f"Unknown input value. Choose from: {_ROCmFABackends_str}.")
+
+    return torch._C._get_rocm_fa_preferred_backend()
+
+
+# Set the __module__ attribute
+SDPAParams.__module__ = "torch.backends.cuda"
+SDPAParams.__name__ = "SDPAParams"
+
+
+def flash_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether flash scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_flash_sdp_enabled()
+
+
+def enable_flash_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables flash scaled dot product attention.
+    """
+    torch._C._set_sdp_use_flash(enabled)
+
+
+def mem_efficient_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether memory efficient scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_mem_efficient_sdp_enabled()
+
+
+def enable_mem_efficient_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables memory efficient scaled dot product attention.
+    """
+    torch._C._set_sdp_use_mem_efficient(enabled)
+
+
+def math_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether math scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_math_sdp_enabled()
+
+
+def enable_math_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables math scaled dot product attention.
+    """
+    torch._C._set_sdp_use_math(enabled)
+
+
+def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables fp16/bf16 reduction in math scaled dot product attention.
+    """
+    torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)
+
+
+def fp16_bf16_reduction_math_sdp_allowed():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_math_sdp_allow_fp16_bf16_reduction()
+
+
+def is_flash_attention_available() -> bool:
+    r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.
+
+    Returns:
+        True if FlashAttention is built and available; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._is_flash_attention_available()
+
+
+def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if FlashAttention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: Whether to logging.warn debug information as to why FlashAttention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if FlashAttention can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_flash_attention(params, debug)
+
+
+def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if efficient_attention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: Whether to logging.warn with information as to why efficient_attention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if efficient_attention can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_mem_efficient_attention(params, debug)
+
+
+def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: Whether to logging.warn with information as to why cuDNN attention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if cuDNN can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_cudnn_attention(params, debug)
+
+
+def cudnn_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether cuDNN scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_cudnn_sdp_enabled()
+
+
+def enable_cudnn_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables cuDNN scaled dot product attention.
+    """
+    torch._C._set_sdp_use_cudnn(enabled)
+
+
+@contextlib.contextmanager
+@deprecated(
+    (
+        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
+        "In the future, this context manager will be removed. "
+        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
+        "with updated signature."
+    ),
+    category=FutureWarning,
+)
+def sdp_kernel(
+    enable_flash: bool = True,
+    enable_math: bool = True,
+    enable_mem_efficient: bool = True,
+    enable_cudnn: bool = True,
+):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention.
+    Upon exiting the context manager, the previous state of the flags will be restored.
+    """
+    from torch.nn.attention import sdpa_kernel
+
+    backend_list = []
+    if enable_flash:
+        backend_list.append(SDPBackend.FLASH_ATTENTION)
+    if enable_mem_efficient:
+        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
+    if enable_math:
+        backend_list.append(SDPBackend.MATH)
+    if enable_cudnn:
+        backend_list.append(SDPBackend.CUDNN_ATTENTION)
+
+    with sdpa_kernel(backend_list) as context:
+        try:
+            yield context
+        finally:
+            pass
+
+
+cufft_plan_cache = cuFFTPlanCacheManager()
+matmul = cuBLASModule()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af1d57b00eb509262599ab4fb26b97af33e41eb0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd6ec297c7a8a21c407e12112ba961b76624a6f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py
@@ -0,0 +1,248 @@
+# mypy: allow-untyped-defs
+import os
+import sys
+import warnings
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+from torch.backends import (
+    __allow_nonbracketed_mutation,
+    _FP32Precision,
+    _get_fp32_precision_getter,
+    _set_fp32_precision_setter,
+    ContextProp,
+    PropModule,
+)
+
+
+try:
+    from torch._C import _cudnn
+except ImportError:
+    _cudnn = None  # type: ignore[assignment]
+
+# Write:
+#
+#   torch.backends.cudnn.enabled = False
+#
+# to globally disable CuDNN/MIOpen
+
+__cudnn_version: Optional[int] = None
+
+if _cudnn is not None:
+
+    def _init():
+        global __cudnn_version
+        if __cudnn_version is None:
+            # pyrefly: ignore [missing-attribute]
+            __cudnn_version = _cudnn.getVersionInt()
+            # pyrefly: ignore [missing-attribute]
+            runtime_version = _cudnn.getRuntimeVersion()
+            # pyrefly: ignore [missing-attribute]
+            compile_version = _cudnn.getCompileVersion()
+            runtime_major, runtime_minor, _ = runtime_version
+            compile_major, compile_minor, _ = compile_version
+            # Different major versions are always incompatible
+            # Starting with cuDNN 7, minor versions are backwards-compatible
+            # Not sure about MIOpen (ROCm), so always do a strict check
+            if runtime_major != compile_major:
+                cudnn_compatible = False
+            # pyrefly: ignore [missing-attribute]
+            elif runtime_major < 7 or not _cudnn.is_cuda:
+                cudnn_compatible = runtime_minor == compile_minor
+            else:
+                cudnn_compatible = runtime_minor >= compile_minor
+            if not cudnn_compatible:
+                if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1":
+                    return True
+                base_error_msg = (
+                    f"cuDNN version incompatibility: "
+                    f"PyTorch was compiled  against {compile_version} "
+                    f"but found runtime version {runtime_version}. "
+                    f"PyTorch already comes bundled with cuDNN. "
+                    f"One option to resolving this error is to ensure PyTorch "
+                    f"can find the bundled cuDNN. "
+                )
+
+                if "LD_LIBRARY_PATH" in os.environ:
+                    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
+                    if any(
+                        substring in ld_library_path for substring in ["cuda", "cudnn"]
+                    ):
+                        raise RuntimeError(
+                            f"{base_error_msg}"
+                            f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. "
+                            f"Please either remove it from the path or install cudnn {compile_version}"
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"{base_error_msg}"
+                            f"one possibility is that there is a "
+                            f"conflicting cuDNN in LD_LIBRARY_PATH."
+                        )
+                else:
+                    raise RuntimeError(base_error_msg)
+            # Check if cuDNN version is compatible with available CUDA devices
+            if torch.cuda.is_available() and not torch.version.hip:
+                min_cc = min(
+                    [
+                        torch.cuda.get_device_capability(i)
+                        for i in range(torch.cuda.device_count())
+                    ]
+                )
+                if __cudnn_version >= 91100 and min_cc < (7, 5):
+                    raise RuntimeError(
+                        f"cuDNN version {__cudnn_version} is not compatible with devices with SM < 7.5. "
+                        f"Please install a version of PyTorch with a compatible cuDNN version. "
+                        f"https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix"
+                    )
+
+        return True
+
+else:
+
+    def _init():
+        return False
+
+
+def version():
+    """Return the version of cuDNN."""
+    if not _init():
+        return None
+    return __cudnn_version
+
+
+CUDNN_TENSOR_DTYPES = {
+    torch.half,
+    torch.float,
+    torch.double,
+}
+
+
+def is_available():
+    r"""Return a bool indicating if CUDNN is currently available."""
+    return torch._C._has_cudnn
+
+
+def is_acceptable(tensor):
+    if not torch._C._get_cudnn_enabled():
+        return False
+    if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES:
+        return False
+    if not is_available():
+        warnings.warn(
+            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
+            "PyTorch making sure the library is visible to the build system.",
+            stacklevel=2,
+        )
+        return False
+    if not _init():
+        warnings.warn(
+            "cuDNN/MIOpen library not found. Check your {libpath}".format(
+                libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get(
+                    sys.platform, "LD_LIBRARY_PATH"
+                )
+            ),
+            stacklevel=2,
+        )
+        return False
+    return True
+
+
+def set_flags(
+    _enabled=None,
+    _benchmark=None,
+    _benchmark_limit=None,
+    _deterministic=None,
+    _allow_tf32=None,
+    _fp32_precision="none",
+):
+    orig_flags = (
+        torch._C._get_cudnn_enabled(),
+        torch._C._get_cudnn_benchmark(),
+        None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
+        torch._C._get_cudnn_deterministic(),
+        torch._C._get_cudnn_allow_tf32(),
+        torch._C._get_fp32_precision_getter("cuda", "all"),
+    )
+    if _enabled is not None:
+        torch._C._set_cudnn_enabled(_enabled)
+    if _benchmark is not None:
+        torch._C._set_cudnn_benchmark(_benchmark)
+    if _benchmark_limit is not None and is_available():
+        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
+    if _deterministic is not None:
+        torch._C._set_cudnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_cudnn_allow_tf32(_allow_tf32)
+    if _fp32_precision is not None:
+        torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision)
+    return orig_flags
+
+
+@contextmanager
+def flags(
+    enabled=False,
+    benchmark=False,
+    benchmark_limit=10,
+    deterministic=False,
+    allow_tf32=True,
+    fp32_precision="none",
+):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(
+            enabled,
+            benchmark,
+            benchmark_limit,
+            deterministic,
+            allow_tf32,
+            fp32_precision,
+        )
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends..enabled = True
+
+
+class CudnnModule(PropModule):
+    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
+    deterministic = ContextProp(
+        torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic
+    )
+    benchmark = ContextProp(
+        torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark
+    )
+    benchmark_limit = None
+    if is_available():
+        benchmark_limit = ContextProp(
+            torch._C._cuda_get_cudnn_benchmark_limit,
+            torch._C._cuda_set_cudnn_benchmark_limit,
+        )
+    allow_tf32 = ContextProp(
+        torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
+    )
+    conv = _FP32Precision("cuda", "conv")
+    rnn = _FP32Precision("cuda", "rnn")
+    fp32_precision = ContextProp(
+        _get_fp32_precision_getter("cuda", "all"),
+        _set_fp32_precision_setter("cuda", "all"),
+    )
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)
+
+# Add type annotation for the replaced module
+enabled: bool
+deterministic: bool
+benchmark: bool
+allow_tf32: bool
+benchmark_limit: int
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6fa22684efbe6e769a0cd88bd115f509cf74744
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a78297561efb07d81f95ebcb31073078244c56e0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc9ca80aa6fd10efc41910d38ba33d00852729c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py
@@ -0,0 +1,69 @@
+# mypy: allow-untyped-defs
+import torch.cuda
+
+
+try:
+    from torch._C import _cudnn
+except ImportError:
+    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
+    # so it's safe to not emit any checks here.
+    _cudnn = None  # type: ignore[assignment]
+
+
+def get_cudnn_mode(mode):
+    if mode == "RNN_RELU":
+        # pyrefly: ignore [missing-attribute]
+        return int(_cudnn.RNNMode.rnn_relu)
+    elif mode == "RNN_TANH":
+        # pyrefly: ignore [missing-attribute]
+        return int(_cudnn.RNNMode.rnn_tanh)
+    elif mode == "LSTM":
+        # pyrefly: ignore [missing-attribute]
+        return int(_cudnn.RNNMode.lstm)
+    elif mode == "GRU":
+        # pyrefly: ignore [missing-attribute]
+        return int(_cudnn.RNNMode.gru)
+    else:
+        raise Exception(f"Unknown mode: {mode}")  # noqa: TRY002
+
+
+# NB: We don't actually need this class anymore (in fact, we could serialize the
+# dropout state for even better reproducibility), but it is kept for backwards
+# compatibility for old models.
+class Unserializable:
+    def __init__(self, inner):
+        self.inner = inner
+
+    def get(self):
+        return self.inner
+
+    def __getstate__(self):
+        # Note: can't return {}, because python2 won't call __setstate__
+        # if the value evaluates to False
+        return ""
+
+    def __setstate__(self, state):
+        self.inner = None
+
+
+def init_dropout_state(dropout, train, dropout_seed, dropout_state):
+    dropout_desc_name = "desc_" + str(torch.cuda.current_device())
+    dropout_p = dropout if train else 0
+    if (dropout_desc_name not in dropout_state) or (
+        dropout_state[dropout_desc_name].get() is None
+    ):
+        if dropout_p == 0:
+            dropout_state[dropout_desc_name] = Unserializable(None)
+        else:
+            dropout_state[dropout_desc_name] = Unserializable(
+                torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
+                    dropout_p,
+                    train,
+                    dropout_seed,
+                    # pyrefly: ignore [unexpected-keyword]
+                    self_ty=torch.uint8,
+                    device=torch.device("cuda"),
+                )
+            )
+    dropout_ts = dropout_state[dropout_desc_name].get()
+    return dropout_ts
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e9b9df2acf144e00a193ee312f728ec30327f8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import torch
+
+
+__all__ = [
+    "version",
+    "is_available",
+    "get_max_alg_id",
+]
+
+try:
+    from torch._C import _cusparselt
+except ImportError:
+    _cusparselt = None  # type: ignore[assignment]
+
+__cusparselt_version: Optional[int] = None
+__MAX_ALG_ID: Optional[int] = None
+
+if _cusparselt is not None:
+
+    def _init() -> bool:
+        global __cusparselt_version
+        global __MAX_ALG_ID
+        if __cusparselt_version is None:
+            # pyrefly: ignore [missing-attribute]
+            __cusparselt_version = _cusparselt.getVersionInt()
+            if __cusparselt_version == 400:
+                __MAX_ALG_ID = 4
+            elif __cusparselt_version == 502:
+                __MAX_ALG_ID = 5
+            elif __cusparselt_version == 602:
+                __MAX_ALG_ID = 37
+        return True
+
+else:
+
+    def _init() -> bool:
+        return False
+
+
+def version() -> Optional[int]:
+    """Return the version of cuSPARSELt"""
+    if not _init():
+        return None
+    return __cusparselt_version
+
+
+def is_available() -> bool:
+    r"""Return a bool indicating if cuSPARSELt is currently available."""
+    return torch._C._has_cusparselt
+
+
+def get_max_alg_id() -> Optional[int]:
+    if not _init():
+        return None
+    return __MAX_ALG_ID
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a28422656e308c41e6b43c4f112f6ee52c4e97a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a681b77ef58ce1f390232b82c4a9843d5559ca3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py
@@ -0,0 +1,7 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with KleidiAI support."""
+    return torch._C._has_kleidiai
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac52a2482c9d9bdd44c8072d6b4b73213c3eb04e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mha/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mha/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1dd2ebd688805bdf3359cb56b64d0854cf258c4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mha/__init__.py
@@ -0,0 +1,25 @@
+# Config options to enable/disable C++ kernel for nn.functional.MHA
+# and nn.TransformerEncoder
+import torch
+
+
+_is_fastpath_enabled: bool = True
+
+
+def get_fastpath_enabled() -> bool:
+    """Returns whether fast path for TransformerEncoder and MultiHeadAttention
+    is enabled, or ``True`` if jit is scripting.
+
+    .. note::
+        The fastpath might not be run even if ``get_fastpath_enabled`` returns
+        ``True`` unless all conditions on inputs are met.
+    """
+    if not torch.jit.is_scripting():
+        return _is_fastpath_enabled
+    return True
+
+
+def set_fastpath_enabled(value: bool) -> None:
+    """Sets whether fast path is enabled"""
+    global _is_fastpath_enabled
+    _is_fastpath_enabled = value
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4d743f7acebea5fc1ab4b9cad16fd5c39e809f5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/miopen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/miopen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b270b658e31a91dfb37380abec383009dfc5bfa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/miopen/__init__.py
@@ -0,0 +1,50 @@
+# mypy: allow-untyped-defs
+import sys
+from contextlib import contextmanager
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+def set_flags(
+    _immediate=None,
+):
+    orig_flags = (torch._C._get_miopen_immediate(),)
+    if _immediate is not None:
+        torch._C._set_miopen_immediate(_immediate)
+    return orig_flags
+
+
+@contextmanager
+def flags(
+    immediate=False,
+):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(
+            immediate,
+        )
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends..immediate = True
+
+
+class MiopenModule(PropModule):
+    immediate = ContextProp(
+        torch._C._get_miopen_immediate, torch._C._set_miopen_immediate
+    )
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__)
+
+# Add type annotation for the replaced module
+immediate: bool
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/miopen/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/miopen/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52680143f53e08dbaaf7a521f6e6efb982169fa7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/miopen/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae16922761afeafa53766757641bcc532b4d5ef4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkl/__init__.py
@@ -0,0 +1,58 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with MKL support."""
+    return torch._C.has_mkl
+
+
+VERBOSE_OFF = 0
+VERBOSE_ON = 1
+
+
+class verbose:
+    """
+    On-demand oneMKL verbosing functionality.
+
+    To make it easier to debug performance issues, oneMKL can dump verbose
+    messages containing execution information like duration while executing
+    the kernel. The verbosing functionality can be invoked via an environment
+    variable named `MKL_VERBOSE`. However, this methodology dumps messages in
+    all steps. Those are a large amount of verbose messages. Moreover, for
+    investigating the performance issues, generally taking verbose messages
+    for one single iteration is enough. This on-demand verbosing functionality
+    makes it possible to control scope for verbose message dumping. In the
+    following example, verbose messages will be dumped out for the second
+    inference only.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import torch
+
+        model(data)
+        with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
+            model(data)
+
+    Args:
+        level: Verbose level
+            - ``VERBOSE_OFF``: Disable verbosing
+            - ``VERBOSE_ON``:  Enable verbosing
+    """
+
+    def __init__(self, enable):
+        self.enable = enable
+
+    def __enter__(self):
+        if self.enable == VERBOSE_OFF:
+            return
+        st = torch._C._verbose.mkl_set_verbose(self.enable)
+        assert st, (
+            "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
+        return False
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74c104e38dc710c4b106888e26ddd11af017182d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..58e6b2c595e9853942b0a3a58a6e5ab2627d3608
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py
@@ -0,0 +1,137 @@
+# mypy: allow-untyped-defs
+import sys
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import torch
+from torch.backends import (
+    __allow_nonbracketed_mutation,
+    _FP32Precision,
+    _get_fp32_precision_getter,
+    _set_fp32_precision_setter,
+    ContextProp,
+    PropModule,
+)
+
+
+def is_available():
+    r"""Return whether PyTorch is built with MKL-DNN support."""
+    return torch._C._has_mkldnn
+
+
+def is_acl_available():
+    r"""Return whether PyTorch is built with MKL-DNN + ACL support."""
+    # pyrefly: ignore [missing-attribute]
+    return torch._C._has_mkldnn_acl
+
+
+VERBOSE_OFF = 0
+VERBOSE_ON = 1
+VERBOSE_ON_CREATION = 2
+
+
+class verbose:
+    """
+    On-demand oneDNN (former MKL-DNN) verbosing functionality.
+
+    To make it easier to debug performance issues, oneDNN can dump verbose
+    messages containing information like kernel size, input data size and
+    execution duration while executing the kernel. The verbosing functionality
+    can be invoked via an environment variable named `DNNL_VERBOSE`. However,
+    this methodology dumps messages in all steps. Those are a large amount of
+    verbose messages. Moreover, for investigating the performance issues,
+    generally taking verbose messages for one single iteration is enough.
+    This on-demand verbosing functionality makes it possible to control scope
+    for verbose message dumping. In the following example, verbose messages
+    will be dumped out for the second inference only.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import torch
+
+        model(data)
+        with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
+            model(data)
+
+    Args:
+        level: Verbose level
+            - ``VERBOSE_OFF``: Disable verbosing
+            - ``VERBOSE_ON``:  Enable verbosing
+            - ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation
+    """
+
+    def __init__(self, level):
+        self.level = level
+
+    def __enter__(self):
+        if self.level == VERBOSE_OFF:
+            return
+        st = torch._C._verbose.mkldnn_set_verbose(self.level)
+        assert st, (
+            "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope."
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF)
+        return False
+
+
+def set_flags(
+    _enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none"
+):
+    orig_flags = (
+        torch._C._get_mkldnn_enabled(),
+        torch._C._get_mkldnn_deterministic(),
+        torch._C._get_onednn_allow_tf32(),
+        torch._C._get_fp32_precision_getter("mkldnn", "all"),
+    )
+    if _enabled is not None:
+        torch._C._set_mkldnn_enabled(_enabled)
+    if _deterministic is not None:
+        torch._C._set_mkldnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_onednn_allow_tf32(_allow_tf32)
+    if _fp32_precision is not None:
+        torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision)
+    try:
+        yield
+    finally:
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+class MkldnnModule(PropModule):
+    def is_available(self):
+        return is_available()
+
+    enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
+    deterministic = ContextProp(
+        torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic
+    )
+    allow_tf32 = ContextProp(
+        torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32
+    )
+    matmul = _FP32Precision("mkldnn", "matmul")
+    conv = _FP32Precision("mkldnn", "conv")
+    rnn = _FP32Precision("mkldnn", "rnn")
+    fp32_precision = ContextProp(
+        _get_fp32_precision_getter("mkldnn", "all"),
+        _set_fp32_precision_setter("generic", "all"),
+    )
+
+
+if TYPE_CHECKING:
+    enabled: ContextProp
+    deterministic: ContextProp
+    allow_tf32: ContextProp
+
+sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfd63c38dd8424ac2d1fb66d1660833d41c3da18
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mps/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c3c507428cfff85a02e1d9939b4951d7e8b84bf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mps/__init__.py
@@ -0,0 +1,78 @@
+from functools import lru_cache as _lru_cache
+from typing import Optional, TYPE_CHECKING
+
+import torch
+from torch.library import Library as _Library
+
+
+__all__ = [
+    "get_core_count",
+    "get_name",
+    "is_built",
+    "is_available",
+    "is_macos13_or_newer",
+    "is_macos_or_newer",
+]
+
+
+def is_built() -> bool:
+    r"""Return whether PyTorch is built with MPS support.
+
+    Note that this doesn't necessarily mean MPS is available; just that
+    if this PyTorch binary were run a machine with working MPS drivers
+    and devices, we would be able to use it.
+    """
+    return torch._C._has_mps
+
+
+@_lru_cache
+def is_available() -> bool:
+    r"""Return a bool indicating if MPS is currently available."""
+    return torch._C._mps_is_available()
+
+
+@_lru_cache
+def is_macos_or_newer(major: int, minor: int) -> bool:
+    r"""Return a bool indicating whether MPS is running on given MacOS or newer."""
+    return torch._C._mps_is_on_macos_or_newer(major, minor)
+
+
+@_lru_cache
+def is_macos13_or_newer(minor: int = 0) -> bool:
+    r"""Return a bool indicating whether MPS is running on MacOS 13 or newer."""
+    return torch._C._mps_is_on_macos_or_newer(13, minor)
+
+
+@_lru_cache
+def get_name() -> str:
+    r"""Return Metal device name"""
+    return torch._C._mps_get_name()
+
+
+@_lru_cache
+def get_core_count() -> int:
+    r"""Return GPU core count.
+
+    According to the documentation, one core is comprised of 16 Execution Units.
+    One execution Unit has 8 ALUs.
+    And one ALU can run 24 threads, i.e. one core is capable of executing 3072 threads concurrently.
+    """
+    return torch._C._mps_get_core_count()
+
+
+_lib: Optional[_Library] = None
+
+
+def _init() -> None:
+    r"""Register prims as implementation of var_mean and group_norm."""
+    global _lib
+
+    if _lib is not None or not is_built():
+        return
+
+    from torch._decomp.decompositions import native_group_norm_backward
+    from torch._refs import native_group_norm
+
+    _lib = _Library("aten", "IMPL")  # noqa: TOR901
+    _lib.impl("native_group_norm", native_group_norm, "MPS")
+    _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfaad57f1c9125ccd1138de55215167e1af94509
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d8a72f3cda9b0da16702c0d7c6fe92ae8f3f153
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py
@@ -0,0 +1,32 @@
+# mypy: allow-untyped-defs
+from contextlib import contextmanager
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+__all__ = ["is_available", "flags", "set_flags"]
+
+
+def is_available():
+    r"""Return whether PyTorch is built with NNPACK support."""
+    return torch._nnpack_available()
+
+
+def set_flags(_enabled):
+    r"""Set if nnpack is enabled globally"""
+    orig_flags = (torch._C._get_nnpack_enabled(),)
+    torch._C._set_nnpack_enabled(_enabled)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=False):
+    r"""Context manager for setting if nnpack is enabled globally"""
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled)
+    try:
+        yield
+    finally:
+        with __allow_nonbracketed_mutation():
+            set_flags(orig_flags[0])
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dbd06623781b0b9bbe63b2f4169bc8b46170471
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/openmp/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/openmp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff8d46cd4ac2d9ff49942542d99ac2afbb85896
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/openmp/__init__.py
@@ -0,0 +1,7 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with OpenMP support."""
+    return torch._C.has_openmp
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb425e2fbe83417be418b9ecd8a7c3bd727ac16b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..264be78aa9a1c24a4624a87782e2b2c5afd29c05
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py
@@ -0,0 +1,117 @@
+# mypy: allow-untyped-defs
+import sys
+import warnings
+from contextlib import contextmanager
+from functools import lru_cache as _lru_cache
+from typing import Any
+
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+try:
+    import opt_einsum as _opt_einsum  # type: ignore[import]
+except ImportError:
+    _opt_einsum = None
+
+
+@_lru_cache
+def is_available() -> bool:
+    r"""Return a bool indicating if opt_einsum is currently available.
+
+    You must install opt-einsum in order for torch to automatically optimize einsum. To
+    make opt-einsum available, you can install it along with torch: ``pip install torch[opt-einsum]``
+    or by itself: ``pip install opt-einsum``. If the package is installed, torch will import
+    it automatically and use it accordingly. Use this function to check whether opt-einsum
+    was installed and properly imported by torch.
+    """
+    return _opt_einsum is not None
+
+
+def get_opt_einsum() -> Any:
+    r"""Return the opt_einsum package if opt_einsum is currently available, else None."""
+    return _opt_einsum
+
+
+def _set_enabled(_enabled: bool) -> None:
+    if not is_available() and _enabled:
+        raise ValueError(
+            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
+            "the benefits of calculating an optimal path for einsum. torch.einsum will "
+            "fall back to contracting from left to right. To enable this optimal path "
+            "calculation, please install opt-einsum."
+        )
+    global enabled
+    enabled = _enabled
+
+
+def _get_enabled() -> bool:
+    return enabled
+
+
+def _set_strategy(_strategy: str) -> None:
+    if not is_available():
+        raise ValueError(
+            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
+            "torch.einsum will bypass path calculation and simply contract from left to right. "
+            "Please install opt_einsum or unset `strategy`."
+        )
+    if not enabled:
+        raise ValueError(
+            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
+            "torch.einsum will bypass path calculation and simply contract from left to right. "
+            "Please set `enabled` to `True` as well or unset `strategy`."
+        )
+    if _strategy not in ["auto", "greedy", "optimal"]:
+        raise ValueError(
+            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
+        )
+    global strategy
+    strategy = _strategy
+
+
+def _get_strategy() -> str:
+    # pyrefly: ignore [bad-return]
+    return strategy
+
+
+def set_flags(_enabled=None, _strategy=None):
+    orig_flags = (enabled, None if not is_available() else strategy)
+    if _enabled is not None:
+        _set_enabled(_enabled)
+    if _strategy is not None:
+        _set_strategy(_strategy)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=None, strategy=None):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled, strategy)
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.opt_einsum.enabled = True
+
+
+class OptEinsumModule(PropModule):
+    global enabled
+    enabled = ContextProp(_get_enabled, _set_enabled)
+    global strategy
+    strategy = None
+    if is_available():
+        strategy = ContextProp(_get_strategy, _set_strategy)
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
+
+enabled = bool(is_available())
+strategy = "auto" if is_available() else None
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3526e3073d67cb637373123d64af6a442b80e14
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/quantized/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/quantized/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..caabfdf243783f2161a201c6a6ec9bd6eca83b18
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/quantized/__init__.py
@@ -0,0 +1,65 @@
+# mypy: allow-untyped-defs
+import sys
+import types
+
+import torch
+
+
+# This function should correspond to the enums present in c10/core/QEngine.h
+def _get_qengine_id(qengine: str) -> int:
+    if qengine == "none" or qengine == "" or qengine is None:
+        ret = 0
+    elif qengine == "fbgemm":
+        ret = 1
+    elif qengine == "qnnpack":
+        ret = 2
+    elif qengine == "onednn":
+        ret = 3
+    elif qengine == "x86":
+        ret = 4
+    else:
+        ret = -1
+        raise RuntimeError(f"{qengine} is not a valid value for quantized engine")
+    return ret
+
+
+# This function should correspond to the enums present in c10/core/QEngine.h
+def _get_qengine_str(qengine: int) -> str:
+    all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"}
+    return all_engines.get(qengine, "*undefined")
+
+
+class _QEngineProp:
+    def __get__(self, obj, objtype) -> str:
+        return _get_qengine_str(torch._C._get_qengine())
+
+    def __set__(self, obj, val: str) -> None:
+        torch._C._set_qengine(_get_qengine_id(val))
+
+
+class _SupportedQEnginesProp:
+    def __get__(self, obj, objtype) -> list[str]:
+        qengines = torch._C._supported_qengines()
+        return [_get_qengine_str(qe) for qe in qengines]
+
+    def __set__(self, obj, val) -> None:
+        raise RuntimeError("Assignment not supported")
+
+
+class QuantizedEngine(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
+    engine = _QEngineProp()
+    supported_engines = _SupportedQEnginesProp()
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
+engine: str
+supported_engines: list[str]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b25a57fb16db6ee5eac43070a9db9f2c9313118e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8202c56920e74796f1c3bb4f9f7410586cf56a2f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3002a6508ff8bc4b7a08231004e63b17fc929eb5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b6bdf78991dcc140d9fedf2be2ea3ba6dedf74
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py
@@ -0,0 +1,947 @@
+# mypy: allow-untyped-defs
+"""
+This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations.
+
+Single instance inference, multi-instance inference are enabled.
+
+Note: term "instance" here doesn't refer to a cloud instance. This script is executed as a single process. It invokes
+multiple "instances" which are formed from multiple threads for each. "instance" is kind of group of threads in this
+context.
+
+Illustrated as below:
+
+::
+
+    +-----------------------------+----------------------+-------+
+    |            process          |        thread        | core  |
+    +=============================+======================+=======+
+    | torch.backends.xeon.run_cpu | instance 0: thread 0 |   0   |
+    |                             |             thread 1 |   1   |
+    |                             +----------------------+-------+
+    |                             | instance 1: thread 0 |   2   |
+    |                             |             thread 1 |   3   |
+    |                             +----------------------+-------+
+    |                             | ...                  |  ...  |
+    |                             +----------------------+-------+
+    |                             | instance N: thread 0 |   M   |
+    |                             |             thread 1 |  M+1  |
+    +-----------------------------+----------------------+-------+
+
+To get the peak performance on Intel(R) Xeon(R) Scalable Processors, the script optimizes the configuration of thread and memory
+management. For thread management, the script configures thread affinity and the preload of Intel OMP library.
+For memory management, it configures NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc).
+
+Environment variables that will be set by this script:
+
++------------------+-------------------------------------------------------------------------------------------------+
+| Environ Variable |                                             Value                                               |
++==================+=================================================================================================+
+|    LD_PRELOAD    | Depending on knobs you set, /libiomp5.so, /libjemalloc.so, /libtcmalloc.so might |
+|                  | be appended to LD_PRELOAD.                                                                      |
++------------------+-------------------------------------------------------------------------------------------------+
+|   KMP_AFFINITY   | If libiomp5.so is preloaded, KMP_AFFINITY could be set to "granularity=fine,compact,1,0".       |
++------------------+-------------------------------------------------------------------------------------------------+
+|   KMP_BLOCKTIME  | If libiomp5.so is preloaded, KMP_BLOCKTIME is set to "1".                                       |
++------------------+-------------------------------------------------------------------------------------------------+
+|  OMP_NUM_THREADS | value of ncores_per_instance                                                                    |
++------------------+-------------------------------------------------------------------------------------------------+
+|    MALLOC_CONF   | If libjemalloc.so is preloaded, MALLOC_CONF will be set to                                      |
+|                  | "oversize_threshold:1,background_thread:true,metadata_thp:auto".                                |
++------------------+-------------------------------------------------------------------------------------------------+
+
+*Note*: This script respects environment variables set preliminarily. I.e. If you set the environment variables
+mentioned above before running the script, the script will not overwrite the values in the script.
+
+How to use this module:
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Single instance inference
+-------------------------
+
+1. Run single-instance inference on a single node with all CPU nodes.
+
+::
+
+   python -m torch.backends.xeon.run_cpu --throughput-mode script.py args
+
+2. Run single-instance inference on a single CPU node.
+
+::
+
+   python -m torch.backends.xeon.run_cpu --node-id 1 script.py args
+
+Multi-instance inference
+------------------------
+
+1. Multi-instance
+   By default this tool runs one process per node. If you want to set the instance numbers and core per instance,
+   --ninstances and  --ncores-per-instance should be set.
+
+::
+
+   python -m torch.backends.xeon.run_cpu -- python_script args
+
+   eg: on an Intel(R) Xeon(R) Scalable Processor with 14 instance, 4 cores per instance
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 14 --ncores-per-instance 4 python_script args
+
+2. Run single-instance inference among multiple instances.
+   By default, runs all ninstances. If you want to independently run a single instance among ninstances, specify rank.
+
+   eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 0-27)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 0 python_script args
+
+   eg: run 1st instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 28-55)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 1 python_script args
+
+   eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance, 2 cores per instance,
+   first four cores (i.e., numactl -C 0-1)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --core-list "0, 1, 2, 3" --ninstances 2 --ncores-per-instance 2
+   --rank 0 python_script args
+
+3. To look up what optional arguments this module offers:
+
+::
+
+    python -m torch.backends.xeon.run_cpu --help
+
+Memory allocator
+----------------
+
+"--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allocator.
+
+"""
+
+import glob
+import logging
+import os
+import platform
+import re
+import subprocess
+import sys
+from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER
+from os.path import expanduser
+
+from torch.distributed.elastic.multiprocessing import (
+    DefaultLogsSpecs as _DefaultLogsSpecs,
+    start_processes,
+    Std,
+)
+
+
+format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+logging.basicConfig(level=logging.INFO, format=format_str)
+logger = logging.getLogger(__name__)
+
+
+class _CPUinfo:
+    """Get CPU information, such as cores list and NUMA information."""
+
+    def __init__(self, test_input=""):
+        self.cpuinfo = []
+        if platform.system() in ["Windows", "Darwin"]:
+            raise RuntimeError(f"{platform.system()} is not supported!!!")
+        elif platform.system() == "Linux":
+            # Sample output of: `lscpu --parse=CPU,Core,Socket,Node`
+            #
+            # # The following is the parsable format, which can be fed to other
+            # # programs. Each different item in every column has an unique ID
+            # # starting from zero.
+            # # CPU,Core,Socket,Node
+            # 0,0,0,0
+            # 1,1,0,0
+            # ...
+            if test_input == "":
+                lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"]
+                lscpu_info = subprocess.check_output(
+                    lscpu_cmd, universal_newlines=True
+                ).split("\n")
+            else:
+                lscpu_info = test_input.split("\n")
+
+            # Get information about  cpu, core, socket and node
+            for line in lscpu_info:
+                pattern = r"^([\d]+,[\d]+,[\d]+,[\d]?)"
+                regex_out = re.search(pattern, line)
+                if regex_out:
+                    self.cpuinfo.append(regex_out.group(1).strip().split(","))
+
+            # physical cores := core column in lscpu output
+            #  logical cores :=  cPU column in lscpu output
+            self.node_nums = int(max(line[3] for line in self.cpuinfo)) + 1
+            self.node_physical_cores: list[list[int]] = []  # node_id is index
+            self.node_logical_cores: list[list[int]] = []  # node_id is index
+            self.physical_core_node_map = {}  # physical core to numa node id
+            self.logical_core_node_map = {}  # logical core to numa node id
+
+            for node_id in range(self.node_nums):
+                cur_node_physical_core = []
+                cur_node_logical_core = []
+                for cpuinfo in self.cpuinfo:
+                    nid = cpuinfo[3] if cpuinfo[3] != "" else "0"
+                    if node_id == int(nid):
+                        if int(cpuinfo[1]) not in cur_node_physical_core:
+                            cur_node_physical_core.append(int(cpuinfo[1]))
+                            self.physical_core_node_map[int(cpuinfo[1])] = int(node_id)
+                        cur_node_logical_core.append(int(cpuinfo[0]))
+                        self.logical_core_node_map[int(cpuinfo[0])] = int(node_id)
+                self.node_physical_cores.append(cur_node_physical_core)
+                self.node_logical_cores.append(cur_node_logical_core)
+
+    def _physical_core_nums(self):
+        return len(self.node_physical_cores) * len(self.node_physical_cores[0])
+
+    def _logical_core_nums(self):
+        return len(self.node_logical_cores) * len(self.node_logical_cores[0])
+
+    def get_node_physical_cores(self, node_id):
+        if node_id < 0 or node_id > self.node_nums - 1:
+            raise ValueError(
+                f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}"
+            )
+        return self.node_physical_cores[node_id]
+
+    def get_node_logical_cores(self, node_id):
+        if node_id < 0 or node_id > self.node_nums - 1:
+            raise ValueError(
+                f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}"
+            )
+        return self.node_logical_cores[node_id]
+
+    def get_all_physical_cores(self):
+        all_cores = []
+        for cores in self.node_physical_cores:
+            all_cores.extend(cores)
+        return all_cores
+
+    def get_all_logical_cores(self):
+        all_cores = []
+        for cores in self.node_logical_cores:
+            all_cores.extend(cores)
+        return all_cores
+
+    def numa_aware_check(self, core_list):
+        """
+        Check whether all cores in core_list are in the same NUMA node.
+
+        Cross NUMA will reduce performance.
+        We strongly advice to not use cores on different nodes.
+        """
+        cores_numa_map = self.logical_core_node_map
+        numa_ids = []
+        for core in core_list:
+            numa_id = cores_numa_map[core]
+            if numa_id not in numa_ids:
+                numa_ids.append(numa_id)
+        if len(numa_ids) > 1:
+            logger.warning(
+                "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \
+this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\
+instance. Alternatively, please use --skip-cross-node-cores knob.",
+                str(core_list),
+                str(numa_ids),
+            )
+        if len(numa_ids) == 0:
+            raise RuntimeError(
+                "invalid number of NUMA nodes; please make sure numa_ids >= 1"
+            )
+        return numa_ids
+
+
+class _Launcher:
+    r"""Class for launcher."""
+
+    msg_lib_notfound = (
+        f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \
+or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \
+{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set."
+    )
+
+    def __init__(self) -> None:
+        self.cpuinfo = _CPUinfo()
+
+    def add_lib_preload(self, lib_type):
+        """Enable TCMalloc/JeMalloc/intel OpenMP."""
+        library_paths = []
+        if "CONDA_PREFIX" in os.environ:
+            library_paths.append(f"{os.environ['CONDA_PREFIX']}/lib")
+        if "VIRTUAL_ENV" in os.environ:
+            library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib")
+
+        library_paths += [
+            f"{expanduser('~')}/.local/lib",
+            "/usr/local/lib",
+            "/usr/local/lib64",
+            "/usr/lib",
+            "/usr/lib64",
+        ]
+
+        lib_find = False
+        lib_set = False
+        for item in os.getenv("LD_PRELOAD", "").split(":"):
+            if item.endswith(f"lib{lib_type}.so"):
+                lib_set = True
+                break
+        if not lib_set:
+            for lib_path in library_paths:
+                library_file = os.path.join(lib_path, f"lib{lib_type}.so")
+                matches = glob.glob(library_file)
+                if len(matches) > 0:
+                    ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")]
+                    os.environ["LD_PRELOAD"] = os.pathsep.join(
+                        [p.strip(os.pathsep) for p in ld_preloads if p]
+                    )
+                    lib_find = True
+                    break
+        return lib_set or lib_find
+
+    def is_numactl_available(self):
+        numactl_available = False
+        try:
+            cmd = ["numactl", "-C", "0", "-m", "0", "hostname"]
+            r = subprocess.run(
+                cmd,
+                env=os.environ,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+                check=False,
+            )
+            if r.returncode == 0:
+                numactl_available = True
+        except Exception:
+            pass
+        return numactl_available
+
+    def set_memory_allocator(
+        self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False
+    ):
+        """
+        Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc.
+
+        By default, PTMalloc will be used for PyTorch, but TCMalloc and JeMalloc can get better
+        memory reuse and reduce page fault to improve performance.
+        """
+        if enable_tcmalloc and enable_jemalloc:
+            raise RuntimeError(
+                "Unable to enable TCMalloc and JEMalloc at the same time."
+            )
+
+        if enable_tcmalloc:
+            find_tc = self.add_lib_preload(lib_type="tcmalloc")
+            if not find_tc:
+                msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}'
+                logger.warning(msg.format("TCmalloc", "tcmalloc"))  # noqa: G001
+            else:
+                logger.info("Use TCMalloc memory allocator")
+
+        elif enable_jemalloc:
+            find_je = self.add_lib_preload(lib_type="jemalloc")
+            if not find_je:
+                msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}'
+                logger.warning(msg.format("Jemalloc", "jemalloc"))  # noqa: G001
+            else:
+                logger.info("Use JeMalloc memory allocator")
+                self.set_env(
+                    "MALLOC_CONF",
+                    "oversize_threshold:1,background_thread:true,metadata_thp:auto",
+                )
+
+        elif use_default_allocator:
+            pass
+
+        else:
+            find_tc = self.add_lib_preload(lib_type="tcmalloc")
+            if find_tc:
+                logger.info("Use TCMalloc memory allocator")
+                return
+            find_je = self.add_lib_preload(lib_type="jemalloc")
+            if find_je:
+                logger.info("Use JeMalloc memory allocator")
+                return
+            logger.warning(
+                """Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib
+                            or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or
+                           %s/.local/lib/ so the LD_PRELOAD environment variable will not be set.
+                           This may drop the performance""",
+                expanduser("~"),
+            )
+
+    def log_env_var(self, env_var_name=""):
+        if env_var_name in os.environ:
+            logger.info("%s=%s", env_var_name, os.environ[env_var_name])
+
+    def set_env(self, env_name, env_value):
+        if not env_value:
+            logger.warning("%s is None", env_name)
+        if env_name not in os.environ:
+            os.environ[env_name] = env_value
+        elif os.environ[env_name] != env_value:
+            logger.warning(
+                "Overriding value with the one set in environment variable: %s. \
+Value applied: %s. Value ignored: %s",
+                env_name,
+                os.environ[env_name],
+                env_value,
+            )
+        self.log_env_var(env_name)
+
+    # set_kmp_affinity is used to control whether to set KMP_AFFINITY or not.
+    # In scenario that use all cores on all nodes, including logical cores, setting KMP_AFFINITY disables logical cores.
+    # In this case, KMP_AFFINITY should not be set.
+    def set_multi_thread_and_allocator(
+        self,
+        ncores_per_instance,
+        disable_iomp=False,
+        set_kmp_affinity=True,
+        enable_tcmalloc=True,
+        enable_jemalloc=False,
+        use_default_allocator=False,
+    ):
+        """
+        Set multi-thread configuration and enable Intel openMP and TCMalloc/JeMalloc.
+
+        By default, GNU openMP and PTMalloc are used in PyTorch. but Intel openMP and TCMalloc/JeMalloc are better alternatives
+        to get performance benefit.
+        """
+        self.set_memory_allocator(
+            enable_tcmalloc, enable_jemalloc, use_default_allocator
+        )
+        self.set_env("OMP_NUM_THREADS", str(ncores_per_instance))
+        if not disable_iomp:
+            find_iomp = self.add_lib_preload(lib_type="iomp5")
+            if not find_iomp:
+                msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}'
+                logger.warning(msg.format("iomp", "iomp5"))  # noqa: G001
+            else:
+                logger.info("Using Intel OpenMP")
+                if set_kmp_affinity:
+                    self.set_env("KMP_AFFINITY", "granularity=fine,compact,1,0")
+                self.set_env("KMP_BLOCKTIME", "1")
+        self.log_env_var("LD_PRELOAD")
+
+    r"""
+     Launcher for single instance and multi-instance
+     """
+
+    def launch(self, args):
+        cores = []
+        set_kmp_affinity = True
+        enable_taskset = False
+        if args.core_list:  # user specify what cores will be used by params
+            cores = [int(x) for x in args.core_list.split(",")]
+            if args.ncores_per_instance == -1:
+                raise RuntimeError(
+                    'please specify the "--ncores-per-instance" if you have pass the --core-list params'
+                )
+            elif (
+                args.ninstances > 1
+                and args.ncores_per_instance * args.ninstances < len(cores)
+            ):
+                logger.warning(
+                    "only first %s cores will be used, \
+but you specify %s cores in core_list",
+                    args.ncores_per_instance * args.ninstances,
+                    len(cores),
+                )
+            else:
+                args.ninstances = len(cores) // args.ncores_per_instance
+
+        else:
+            if args.use_logical_core:
+                if args.node_id != -1:
+                    cores = self.cpuinfo.get_node_logical_cores(args.node_id)
+                else:
+                    cores = self.cpuinfo.get_all_logical_cores()
+                    # When using all cores on all nodes, including logical cores,
+                    # setting KMP_AFFINITY disables logical cores. Thus, KMP_AFFINITY should not be set.
+                    set_kmp_affinity = False
+            else:
+                if args.node_id != -1:
+                    cores = self.cpuinfo.get_node_physical_cores(args.node_id)
+                else:
+                    cores = self.cpuinfo.get_all_physical_cores()
+            if (
+                not args.multi_instance
+                and args.ninstances == -1
+                and args.ncores_per_instance == -1
+            ):
+                args.ninstances = 1
+                args.ncores_per_instance = len(cores)
+            elif (
+                args.multi_instance
+                and args.ninstances == -1
+                and args.ncores_per_instance == -1
+            ):
+                args.throughput_mode = True
+            elif args.ncores_per_instance == -1 and args.ninstances != -1:
+                if args.ninstances > len(cores):
+                    raise RuntimeError(
+                        f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \
+please make sure ninstances <= total_cores)"
+                    )
+                else:
+                    args.ncores_per_instance = len(cores) // args.ninstances
+            elif args.ncores_per_instance != -1 and args.ninstances == -1:
+                if not args.skip_cross_node_cores:
+                    args.ninstances = len(cores) // args.ncores_per_instance
+                else:
+                    ncore_per_node = len(self.cpuinfo.node_physical_cores[0])
+                    num_leftover_cores = ncore_per_node % args.ncores_per_instance
+                    if args.ncores_per_instance > ncore_per_node:
+                        # too many ncores_per_instance to skip cross-node cores
+                        logger.warning(
+                            "there are %s core(s) per socket, but you specify %s ncores_per_instance and \
+skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \
+socket",
+                            ncore_per_node,
+                            args.ncores_per_instance,
+                        )
+                        sys.exit(-1)
+                    elif num_leftover_cores == 0:
+                        # aren't any cross-node cores
+                        logger.info(
+                            "--skip-cross-node-cores is set, but there are no cross-node cores."
+                        )
+                        args.ninstances = len(cores) // args.ncores_per_instance
+                    else:
+                        # skip cross-node cores
+                        if args.ninstances != -1:
+                            logger.warning(
+                                "--skip-cross-node-cores is exclusive to --ninstances. --ninstances \
+won't take effect even if it is set explicitly."
+                            )
+
+                        i = 1
+                        leftover_cores = set()
+                        while ncore_per_node * i <= len(cores):
+                            leftover_cores.update(
+                                cores[
+                                    ncore_per_node * i
+                                    - num_leftover_cores : ncore_per_node * i
+                                ]
+                            )
+                            i += 1
+                        cores = list(set(cores) - leftover_cores)
+                        assert len(cores) % args.ncores_per_instance == 0
+                        args.ninstances = len(cores) // args.ncores_per_instance
+            else:
+                if args.ninstances * args.ncores_per_instance > len(cores):
+                    raise RuntimeError(
+                        "Please make sure ninstances * ncores_per_instance <= total_cores"
+                    )
+            if args.latency_mode:
+                logger.warning(
+                    "--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \
+--use-logical-core. They won't take effect even they are set explicitly."
+                )
+                args.ncores_per_instance = 4
+                cores = self.cpuinfo.get_all_physical_cores()
+                args.ninstances = len(cores) // args.ncores_per_instance
+
+            if args.throughput_mode:
+                logger.warning(
+                    "--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \
+--use-logical-core. They won't take effect even they are set explicitly."
+                )
+                args.ninstances = self.cpuinfo.node_nums
+                cores = self.cpuinfo.get_all_physical_cores()
+                args.ncores_per_instance = len(cores) // args.ninstances
+
+        if args.ninstances > 1 and args.rank != -1:
+            logger.info(
+                "assigning %s cores for instance %s",
+                args.ncores_per_instance,
+                args.rank,
+            )
+
+        if not args.disable_numactl:
+            numactl_available = self.is_numactl_available()
+            if not numactl_available:
+                if not args.disable_taskset:
+                    logger.warning(
+                        "Core binding with numactl is not available. Disabling numactl and using taskset instead. \
+                    This may affect performance in multi-socket system; please use numactl if memory binding is needed."
+                    )
+                    args.disable_numactl = True
+                    enable_taskset = True
+                else:
+                    logger.warning(
+                        "Core binding with numactl is not available, and --disable_taskset is set. \
+                    Please unset --disable_taskset to use taskset instead of numactl."
+                    )
+                    sys.exit(-1)
+
+        if not args.disable_taskset:
+            enable_taskset = True
+
+        self.set_multi_thread_and_allocator(
+            args.ncores_per_instance,
+            args.disable_iomp,
+            set_kmp_affinity,
+            args.enable_tcmalloc,
+            args.enable_jemalloc,
+            args.use_default_allocator,
+        )
+        entrypoint = ""
+        launch_args = {}
+        launch_envs: dict[int, dict] = {}
+        launch_tee = {}
+        # check whether is launched from torchrun with --nproc-per-node 
+        local_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        for i in range(args.ninstances):
+            cmd = []
+            cur_process_cores = ""
+            if not args.disable_numactl or enable_taskset:
+                if not args.disable_numactl:
+                    cmd = ["numactl"]
+                elif enable_taskset:
+                    cmd = ["taskset"]
+                cores = sorted(cores)
+                if (
+                    args.rank == -1
+                ):  # sequentially assign ncores_per_instance to ninstances
+                    core_list = cores[
+                        i * args.ncores_per_instance : (i + 1)
+                        * args.ncores_per_instance
+                    ]
+                else:  # assign ncores_per_instance from rank
+                    core_list = cores[
+                        args.rank * args.ncores_per_instance : (args.rank + 1)
+                        * args.ncores_per_instance
+                    ]
+
+                core_ranges: list[dict] = []
+                if local_size > 1:
+                    total_num_cores = len(core_list)
+                    cores_per_rank = total_num_cores // local_size
+                    assert cores_per_rank >= 1, (
+                        "At least one core needs to be assigned to each rank"
+                    )
+                    core_list = core_list[
+                        cores_per_rank * local_rank : cores_per_rank * (local_rank + 1)
+                    ]
+                for core in core_list:
+                    if len(core_ranges) == 0:
+                        range_elem = {"start": core, "end": core}
+                        core_ranges.append(range_elem)
+                    else:
+                        if core - core_ranges[-1]["end"] == 1:
+                            core_ranges[-1]["end"] = core
+                        else:
+                            range_elem = {"start": core, "end": core}
+                            core_ranges.append(range_elem)
+                for r in core_ranges:
+                    cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']},"
+                cur_process_cores = cur_process_cores[:-1]
+                if not args.disable_numactl:
+                    numa_params = f"-C {cur_process_cores} "
+                    numa_ids = ",".join(
+                        [
+                            str(numa_id)
+                            for numa_id in self.cpuinfo.numa_aware_check(core_list)
+                        ]
+                    )
+                    numa_params += f"-m {numa_ids}"
+                    cmd.extend(numa_params.split())
+                elif enable_taskset:
+                    taskset_params = f"-c {cur_process_cores} "
+                    cmd.extend(taskset_params.split())
+            with_python = not args.no_python
+            if with_python:
+                cmd.append(sys.executable)
+                cmd.append("-u")
+            if args.module:
+                cmd.append("-m")
+            cmd.append(args.program)
+            cmd.extend(args.program_args)
+            cmd_s = " ".join(cmd)
+            logger.info(cmd_s)
+            if entrypoint == "":
+                entrypoint = cmd[0]
+            del cmd[0]
+            launch_args[i] = tuple(cmd)
+            launch_envs[i] = {}
+            launch_tee[i] = Std.ALL
+
+            if args.rank != -1:  # launches single instance, rank, only
+                break
+
+        ctx = start_processes(
+            name=args.log_file_prefix,
+            entrypoint=entrypoint,
+            args=launch_args,
+            envs=launch_envs,
+            logs_specs=_DefaultLogsSpecs(log_dir=args.log_path, tee=launch_tee),
+        )
+        ctx.wait()
+
+
+def _add_memory_allocator_params(parser):
+    group = parser.add_argument_group("Memory Allocator Parameters")
+    # allocator control
+    group.add_argument(
+        "--enable-tcmalloc",
+        "--enable_tcmalloc",
+        action="store_true",
+        default=False,
+        help="Enable tcmalloc allocator",
+    )
+    group.add_argument(
+        "--enable-jemalloc",
+        "--enable_jemalloc",
+        action="store_true",
+        default=False,
+        help="Enable jemalloc allocator",
+    )
+    group.add_argument(
+        "--use-default-allocator",
+        "--use_default_allocator",
+        action="store_true",
+        default=False,
+        help="Use default memory allocator",
+    )
+
+
+def _add_multi_instance_params(parser):
+    group = parser.add_argument_group("Multi-instance Parameters")
+    # multi-instance control
+    group.add_argument(
+        "--ncores-per-instance",
+        "--ncores_per_instance",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="Cores per instance",
+    )
+    group.add_argument(
+        "--ninstances",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="For multi-instance, you should give the cores number you used for per instance.",
+    )
+    group.add_argument(
+        "--skip-cross-node-cores",
+        "--skip_cross_node_cores",
+        action="store_true",
+        default=False,
+        help="If specified --ncores-per-instance, skips cross-node cores.",
+    )
+    group.add_argument(
+        "--rank",
+        metavar="\b",
+        default="-1",
+        type=int,
+        help="Specify instance index to assign ncores_per_instance for rank; \
+otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \
+https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md",
+    )
+    group.add_argument(
+        "--latency-mode",
+        "--latency_mode",
+        action="store_true",
+        default=False,
+        help="By default 4 core per instance and use all physical cores",
+    )
+    group.add_argument(
+        "--throughput-mode",
+        "--throughput_mode",
+        action="store_true",
+        default=False,
+        help="By default one instance per node and use all physical cores",
+    )
+    group.add_argument(
+        "--node-id",
+        "--node_id",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="node id for multi-instance, by default all nodes will be used",
+    )
+    group.add_argument(
+        "--use-logical-core",
+        "--use_logical_core",
+        action="store_true",
+        default=False,
+        help="Whether only use physical cores",
+    )
+    group.add_argument(
+        "--disable-numactl",
+        "--disable_numactl",
+        action="store_true",
+        default=False,
+        help="Disable numactl",
+    )
+    group.add_argument(
+        "--disable-taskset",
+        "--disable_taskset",
+        action="store_true",
+        default=False,
+        help="Disable taskset",
+    )
+    group.add_argument(
+        "--core-list",
+        "--core_list",
+        metavar="\b",
+        default=None,
+        type=str,
+        help='Specify the core list as "core_id, core_id, ....", otherwise, all the cores will be used.',
+    )
+    group.add_argument(
+        "--log-path",
+        "--log_path",
+        metavar="\b",
+        default="",
+        type=str,
+        help="The log file directory. Default path is "
+        ", which means disable logging to files.",
+    )
+    group.add_argument(
+        "--log-file-prefix",
+        "--log_file_prefix",
+        metavar="\b",
+        default="run",
+        type=str,
+        help="log file prefix",
+    )
+
+
+def _add_kmp_iomp_params(parser):
+    group = parser.add_argument_group("IOMP Parameters")
+    group.add_argument(
+        "--disable-iomp",
+        "--disable_iomp",
+        action="store_true",
+        default=False,
+        help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD",
+    )
+
+
+def create_args(parser=None):
+    """
+    Parse the command line options.
+
+    @retval ArgumentParser
+    """
+    # pyrefly: ignore [missing-attribute]
+    parser.add_argument(
+        "--multi-instance",
+        "--multi_instance",
+        action="store_true",
+        default=False,
+        help="Enable multi-instance, by default one instance per node",
+    )
+
+    # pyrefly: ignore [missing-attribute]
+    parser.add_argument(
+        "-m",
+        "--module",
+        default=False,
+        action="store_true",
+        help="Changes each process to interpret the launch script "
+        "as a python module, executing with the same behavior as"
+        '"python -m".',
+    )
+
+    # pyrefly: ignore [missing-attribute]
+    parser.add_argument(
+        "--no-python",
+        "--no_python",
+        default=False,
+        action="store_true",
+        help='Do not prepend the --program script with "python" - just exec '
+        "it directly. Useful when the script is not a Python script.",
+    )
+
+    _add_memory_allocator_params(parser)
+    _add_kmp_iomp_params(parser)
+
+    _add_multi_instance_params(parser)
+    # positional
+    # pyrefly: ignore [missing-attribute]
+    parser.add_argument(
+        "program",
+        type=str,
+        help="The full path to the program/script to be launched. "
+        "followed by all the arguments for the script",
+    )
+
+    # rest from the training program
+    # pyrefly: ignore [missing-attribute]
+    parser.add_argument("program_args", nargs=REMAINDER)
+
+
+def main(args):
+    env_before = set(os.environ.keys())
+    if platform.system() in ["Windows", "Darwin"]:
+        raise RuntimeError(f"{platform.system()} is not supported!!!")
+
+    if args.log_path:
+        os.makedirs(args.log_path, exist_ok=True)
+    else:
+        args.log_path = os.devnull
+
+    if args.latency_mode and args.throughput_mode:
+        raise RuntimeError(
+            "Either args.latency_mode or args.throughput_mode should be set"
+        )
+
+    if not args.no_python and not args.program.endswith(".py"):
+        raise RuntimeError(
+            'For non Python script, you should use "--no-python" parameter.'
+        )
+
+    # Verify LD_PRELOAD
+    if "LD_PRELOAD" in os.environ:
+        lst_valid = []
+        tmp_ldpreload = os.environ["LD_PRELOAD"]
+        for item in tmp_ldpreload.split(":"):
+            matches = glob.glob(item)
+            if len(matches) > 0:
+                lst_valid.append(item)
+            else:
+                logger.warning("%s doesn't exist. Removing it from LD_PRELOAD.", item)
+        if len(lst_valid) > 0:
+            os.environ["LD_PRELOAD"] = ":".join(lst_valid)
+        else:
+            os.environ["LD_PRELOAD"] = ""
+
+    launcher = _Launcher()
+    launcher.launch(args)
+    for x in sorted(set(os.environ.keys()) - env_before):
+        logger.debug("%s=%s", x, os.environ[x])
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(
+        description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable "
+        "Processors with optimal configurations. Single instance inference, "
+        "multi-instance inference are enable. To get the peak performance on Intel(R) "
+        "Xeon(R) Scalable Processors, the script optimizes the configuration "
+        "of thread and memory management. For thread management, the script configures thread "
+        "affinity and the preload of Intel OMP library. For memory management, it configures "
+        "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) "
+        "\n################################# Basic usage ############################# \n"
+        "\n 1. single instance\n"
+        "\n   >>> python -m torch.backends.xeon.run_cpu python_script args \n"
+        "\n2. multi-instance \n"
+        "\n   >>> python -m torch.backends.xeon.run_cpu --ninstances xxx "
+        "--ncores-per-instance xx python_script args\n"
+        "\n############################################################################# \n",
+        formatter_class=RawTextHelpFormatter,
+    )
+    create_args(parser)
+    args = parser.parse_args()
+    main(args)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e69876927d01878a9d1cb836d72fd14adf95e9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py
@@ -0,0 +1,29 @@
+# mypy: allow-untyped-defs
+import sys
+import types
+
+import torch
+
+
+class _XNNPACKEnabled:
+    def __get__(self, obj, objtype):
+        return torch._C._is_xnnpack_enabled()
+
+    def __set__(self, obj, val):
+        raise RuntimeError("Assignment not supported")
+
+
+class XNNPACKEngine(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
+    enabled = _XNNPACKEnabled()
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = XNNPACKEngine(sys.modules[__name__], __name__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e30a86b7695e939d6e1534248142b5d41b99767
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/csrc/inductor/aoti_runtime/model.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/csrc/inductor/aoti_runtime/model.h
new file mode 100644
index 0000000000000000000000000000000000000000..253c5e917e76bdc8a2adc669404fc8d5c40b6b27
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/csrc/inductor/aoti_runtime/model.h
@@ -0,0 +1,62 @@
+#pragma once
+
+// WARNING: Be careful when adding new includes here. This header will be used
+// in model.so, and should not refer to any aten/c10 headers except the stable
+// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
+// applies to other files under torch/csrc/inductor/aoti_runtime/.
+#include 
+
+namespace torch::aot_inductor {
+
+class AOTInductorModel : public AOTInductorModelBase {
+ public:
+  AOTInductorModel(
+      std::shared_ptr constants_map,
+      std::shared_ptr> constants_array,
+      const std::string& device_str,
+      std::optional cubin_dir);
+
+  std::unordered_map const_run_impl(
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor,
+      bool initialization = false);
+
+  void _const_run_impl(
+      std::vector& output_handles,
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor);
+
+  void run_impl(
+      AtenTensorHandle*
+          input_handles, // array of input AtenTensorHandle; handles
+                         // are stolen; the array itself is borrowed
+      AtenTensorHandle*
+          output_handles, // array for writing output AtenTensorHandle; handles
+                          // will be stolen by the caller; the array itself is
+                          // borrowed
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor);
+
+  template 
+  Outputs run_impl_minimal_arrayref_interface(
+      const Inputs& inputs,
+      DeviceStreamType stream,
+      AOTIProxyExecutorHandle proxy_executor);
+
+  static std::unique_ptr Create(
+      std::shared_ptr constants_map,
+      std::shared_ptr> constants_array,
+      const std::string& device_str,
+      std::optional cubin_dir) {
+    return std::make_unique(
+        std::move(constants_map),
+        std::move(constants_array),
+        device_str,
+        std::move(cubin_dir));
+  }
+
+ private:
+  std::unique_ptr kernels_;
+};
+
+} // namespace torch::aot_inductor
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..228506dfc59480b00490149506384322f1cce4dc
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5d6ee6e15007b038f76e0c9538984df6093f4e2
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1fe830dc88af75d1d87b1743df87a2267df2942
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d80b1def27347601e415cd23b5e204238efa9b5b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee980a9c2d3820f26481fe2683ba7d2f1bb10134
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..787027f70c650f49a72f12779a794d756796b8d6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42fb042eb029519183bdb27e962da949a9c8ea3d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..490889b481c7d200faedf1439f378cc4fe9b1bc8
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3cb85c477e0a682dcf5dd784966161c7717bb06
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdbce9598ebdf20e262957acb48673965e1d4848
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a89eef6f35bc7ef3858c3d45789ea70370952e0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9967ad64c1aff050f50e7568fd1f560959822aae
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..887b9bfb63e00666005912e6df9e51a36c314a21
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88a0b5eb3a55fc5385785b238bd750b125542e4a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98db8ffe31ab44869a0209a28458932ad0c7085b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03ace74be988e62ebc47c722a6ea0d4833049e2f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b125154d6a69977dcff1171fff225c53631f9e91
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..044b1b6ebfa927cfc13755a5a4c6941facb2bf39
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be0c3ecc5ad9f634dea475db18f19395f215f616
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1501a476ceef2e32e6add8361249d78e0ea05200
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4329491bc5ece52b0309fba10da1cc72f903a791
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f79c267f4ab825d9f64614d3cdcd16015d458b3
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58797e750bd907894653b2e0ccde214475185895
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04863d9227c9b602c52466ed0a37312d99925eb5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81f11362a0c6c1b6a0d996bfa280a9e201b63755
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2349c53149a046577a7673e9fe58c531ba9844b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b51725b1c11fdc5b5061ede9562b1b412eaadc0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f8903ea0e8463f9afe5f49497164cbcbf7cdedd
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94ebb7d9ca88becb6b1c4e5b309e976379f6c14e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18ebf7b259edb6813fceb66fe208588678156b6e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c7f837ff778a017e17f077f7dc26bc1e94398be
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db25feb499634e0ba32356414b87ad7a6c779b80
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be8e4042a4622ff55660f3bfd759ccc32e4a9c51
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..445bb68bbd97241ea42e1c561e00eee94e66ca86
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d1318ffca5eb8fd181727c3af60dcf9cfe60178
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06c875902ecc074903c12231294b3f728bbfca40
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8db80830837670c6b23d411b1fceb02d7f8de7a7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7eed693834c053ce13e5eb53ff6567e0d9c19b3e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..378e9fe9cc5b4f74b48998effc3c2f5db7c35db4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77d5fdaa16edc33d7cb6d1db798669783b97dc33
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f010545c7e006ba39d592541359588d8c7d4f71e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..399e90f10fe8afb9ef45541911ccd41de5008ba9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e191d7e9fd13bf7d17014373d9e13434f341107b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed8bb415b6a78b5534016276a23eff780b6544f1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9bb90ab52a3037676146008e0a588b00493c70c0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94a8774fa956cbaffb3dbeb841c697fa3a0e6cab
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35d72a9514eb73af6413373a215b03ae4400e9e1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1126ca903fd761c20b242021bda10716bc94fade
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b1ee17e52f2714061482e77d73417c7992d6b4a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae2dc837e336f6055b6cd4db4629228ee59be39f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9bcc83e6b2d2d1d2c3668cc8c7b880f41c6efb0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..413f733e901a6aed1508022557c2dc58ce34f743
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b81f4dfb2c996dd3cfe32aead715281b655aac7d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5cde0bb400fbe1fb20fd37cb691da97da6dd7cb
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34dd8196d1f722f4b228198a97093651e405e172
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1490535f636327eaa6c00a5f22d7a4f047a110b4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/annotate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/annotate.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08aa0d7e6a8d3c3661ed0489ac6d778ccc4d8ae0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/annotate.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fab32fba7eee560fb3b5ec0d413db410d09a720
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..795aada96dafb1011c53a50e6b877ac0c6abae11
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95a3a75e68208cfc0d2c91648d72ded34709fba2
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d8694d9945ad1bd423ac5ddc86d00d3a28483a7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f5a7061932ef92374833dc05d31de6db041f5ea
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e996364658421fed1ca0dbe321517d2be5b6679
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f8cd7c7666247af9a363157d13618a22baf7de5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f0b2bf4031d86990a6d8d7b300be574d6e8f1d1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/tensor_type.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/tensor_type.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68e24714951845d2f0d012f522e53809963e7105
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/tensor_type.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d295d38c50ec0c9a19f30aee2db8ed51b315d9a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bcd6fe4577eef157a3abba98df7560af4cd30b12
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49b048a88753c60b2e4735844bb349ffc82aacf7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac3c266d9c96dd3a08bab3ade103d85323a66fe0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bed0edf152a5c2678a6895fd7813751c0b4c76ef
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..398d60e479ea22f128ff17b65bfe20e12ee2caee
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53569eaff633818a1b840f9b72915cf9a9290f34
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b3c327cc88007913e05ac3e0452891d16fbd236
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/debug.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/debug.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..907df00b4b8d83c59b9d40332c519d52b1a40e20
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/debug.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6118f2dc0a17d2647dd0fdc70ea10159ddff2d8d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbb3c597e807c1ae6eeb2077566314001a3fee53
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a712e5b78f6a1d6d6d148e6723e719f7601a0f22
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34abd014495679f9e506b1dbc9305110bcadeabf
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4be97057cf50ab5943751b6098e564455818d0a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5353ac3feca93dc529d1cd90f36cfb858c221611
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9ef6867b836d648a36273aece89cde885b01c81
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a2e6dc85ea80162c19576d8f8886b22cad4e386
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac112abf9e955188c04370e3f1b4f57bf61e40b5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc584100313adaeb743c3e41e48d0ac40aa8eba8
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee3296a0af13b3cd1d19259b1ab69e2fa3112ffe
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c8fd23c1e2950ac4a400e426378d1b3a8b28d41
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/validator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/validator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5dc638c48bd8a61d4a8e5e2b9e6b206681d08c6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/validator.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c742431857c33af22dbc1ad73b5bdfcf6124b9c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py
@@ -0,0 +1,27 @@
+import torch.fx
+
+
+class BackwardState:
+    """
+    BackwardState is used to pass Python hooks from the forwards pass
+    into the backwards pass in Dynamo+Compiled Autograd.
+
+    It is created by TorchDynamo and has special handling there.
+    Dynamo will pass an empty BackwardState to the forwards, then populate
+    members on it (via setattr) only after the forwards graph is finished.
+    Later on, in CompileAutograd we will inline and add the needed guards
+    on the BackwardState.
+
+    BackwardState is identified and has special handling in AOTAutograd.
+    During AOTAutograd:
+        1) BackwardState is an input to the forwards graph
+        2) It must only be used in the backwards
+        3) It will be empty in the forwards
+        4) In the forwards we add a wrapper to save it
+        5) In the backwards it becomes an input
+        6) There can only be one per graph
+
+    BackwardState requires CompiledAutograd.
+    """
+
+    proxy: torch.fx.Proxy
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_config.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a537978db3834d0bbb425bbd4214a8b17163db18
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_config.py
@@ -0,0 +1,112 @@
+import os
+import sys
+from typing import Optional
+
+from torch.utils._config_module import Config, install_config_module
+
+
+# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
+no_data_dependent_graph_break = (
+    os.environ.get("TORCHDYNAMO_NO_DATA_DEPENDENT_GRAPH_BREAK", "0") == "1"
+)
+# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations.
+translation_validation = (
+    os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
+)
+# Timeout (in milliseconds) for z3 finding a solution.
+# [@compile_ignored: debug]
+translation_validation_timeout = int(
+    os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
+)
+# Disables bisection for translation validation.
+#
+# Translation validation bisection is enabled by default, if translation validation
+# is also enabled. This should help finding guard simplification issues. However,
+# since validation uses Z3 for bisecting, it might take a lot of time.
+#
+# Set this configuration option so as to avoid bisecting.
+# [@compile_ignored: debug]
+translation_validation_no_bisect = (
+    os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
+)
+# Checks whether replaying ShapeEnv events on a freshly constructed one yields
+# the a ShapeEnv with the same state. This should be used only in testing.
+check_shape_env_recorded_events = False
+
+# TODO: Perhaps consider allowing unions for the configs below (so you can hit
+# multiple reps at the same time)
+
+# Give extended debug information if the string representation of a guard
+# matches this.  For example, set this to "Ne(s0, 10)" and whenever we issue
+# this guard, we will generate full Python and C++ backtrace
+# [@compile_ignored: debug]
+extended_debug_guard_added = os.environ.get(
+    "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
+)
+
+# Give extended debug information when a particular symbol is allocated.  For
+# example, set this to "u2" and whenever we create this symbol, we will
+# generate full Python and C++ backtrace
+# [@compile_ignored: debug]
+extended_debug_create_symbol = os.environ.get(
+    "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
+)
+
+# Give extended debug information (C++ backtrace) for all extended debug
+# settings as well as errors.  The C++ backtrace is slow and very spammy so we
+# don't include it by default even when you're requesting extended debug.
+# [@compile_ignored: debug]
+extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
+
+# Give extended debug information (line of code) when a torch function
+# is called during export.  This is useful for showing progress and detecting
+# where export might be stuck. Currently only works for strict=False.
+# [@compile_ignored: debug]
+extended_debug_current_loc = (
+    os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1"
+)
+
+# [@compile_ignored: debug] Show a warning for every specialization
+print_specializations = False
+
+# wraps (un)equalities with 'Not' class after recording the correct expression
+# in the FX graph. This should incorrectly construct the divisible and replacement
+# lists, and incorrectly issue guards.
+inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
+
+# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
+validate_shape_env_version_key = False
+
+# If we produce more than this many guards on a symbol, force the symbol to
+# get specialized and bail out if this many guards mention this particular
+# symbol.  This may be slightly more aggressive than the true number of guards
+# issued (as we test if we've hit the limit on-the-fly, whereas we may
+# do further simplifications at final guard issuance time that make guards
+# irrelevant.)
+symbol_guard_limit_before_specialize: Optional[int] = None
+
+# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
+use_duck_shape = True
+
+# Controls the registration of torch.nonzero() on the meta device.
+# When True, nonzero returns a tensor with shape (self.numel(), self.dim())
+# assuming all elements are none-zero.
+# Default is False to prevent unintended registration. Set to True to enable.
+meta_nonzero_assume_all_nonzero = False
+
+# Applies size-oblivious reasoning to backed symbols. This allocates a [0, inf] range for backed size symbols,
+# and relies on size-oblivious semantics to avoid 0/1 specialization guards by marking them size-like.
+# Currently an experimental option for export.
+backed_size_oblivious = False
+
+# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
+skip_dtype_check_in_meta_registrations = False
+
+# Experimental: If True, graph module will register fx metadata during recompile()
+enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
+    default=False,
+    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
+)
+
+
+install_config_module(sys.modules[__name__])
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3b40bda324c8fd6ad171d14ddb17f52508cb23a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py
@@ -0,0 +1,78 @@
+from typing import *  # noqa: F403
+
+
+# Python version of c10/core/ConstantSymNodeImpl.cpp
+# This needs to exist because the Python version of nested int is not compatible
+# with the C++ version of constant symnode.
+class ConstantIntNode:
+    def __init__(self, val: int):
+        self.val = val
+
+    def is_constant(self) -> bool:
+        return True
+
+    def maybe_as_int(self) -> int:
+        return self.val
+
+    def is_int(self) -> bool:
+        return True
+
+    def is_float(self) -> bool:
+        return False
+
+    def is_bool(self) -> bool:
+        return False
+
+    def is_nested_int(self) -> bool:
+        return False
+
+    def clone(self) -> "ConstantIntNode":
+        return self
+
+    def _str(self) -> str:
+        return str(self.val)
+
+    def __str__(self) -> str:
+        return self._str()
+
+    def __repr__(self) -> str:
+        return self._str()
+
+    def _graph_repr(self) -> str:
+        return self._str()
+
+    def add(self, other: Any) -> Any:
+        return other.add(self)
+
+    def sub(self, other: Any) -> Any:
+        return other.neg().add(self.val)
+
+    def mul(self, other: Any) -> Any:
+        return other.mul(self)
+
+    def eq(self, other: Any) -> Any:
+        return other.eq(self)
+
+    def ne(self, other: Any) -> Any:
+        return other.ne(self)
+
+    def gt(self, other: Any) -> Any:
+        return other.lt(self)
+
+    def lt(self, other: Any) -> Any:
+        return other.gt(self)
+
+    def le(self, other: Any) -> Any:
+        return other.ge(self)
+
+    def ge(self, other: Any) -> Any:
+        return other.le(self)
+
+    def is_symbolic(self) -> bool:
+        return False
+
+    def constant_int(self) -> int:
+        return self.val
+
+    def guard_int(self, file: str, line: int) -> int:
+        return self.val
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f30779ecc28df106658ca80fb103ee3735e5a1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py
@@ -0,0 +1,119 @@
+import re
+from collections.abc import Callable
+from typing import Any, Union
+
+import torch
+from torch.utils._pytree import tree_flatten_with_path, tree_map
+
+
+KeyPath = tuple[Any, ...]
+NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]]
+
+__all__ = [
+    "normalize_source_name",
+    "module_to_nested_dict",
+    "track_dynamism_across_examples",
+    "clone_and_convert_to_meta",
+]
+
+
+def normalize_source_name(name: str) -> str:
+    # Match attribute access like .x and replace with ['x']
+    return re.sub(r"\.([a-zA-Z_][a-zA-Z0-9_]*)", r"['\1']", name)
+
+
+def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
+    """Recursively converts an nn.Module into a nested dictionary with explicit 'parameters' and 'modules' keys."""
+    self_dict: dict[str, Any] = {}
+
+    self_dict["_parameters"] = {}
+    self_dict["_modules"] = {}
+
+    for attr_name in dir(module):
+        try:
+            if not attr_name.startswith("_") and not callable(
+                getattr(module, attr_name)
+            ):
+                attr_value = getattr(module, attr_name)
+                if (
+                    not isinstance(attr_value, torch.nn.Module)
+                    and isinstance(attr_value, (int, float, torch.Tensor))
+                    and type(attr_value) is not bool
+                ):
+                    self_dict[attr_name] = attr_value
+        except NotImplementedError:
+            # Skip attributes that raise NotImplementedError since they won't
+            # contain any dynamism anyways.
+            continue
+
+    for name, param in module.named_parameters(recurse=False):
+        self_dict["_parameters"][name] = param
+    for name, buffer in module.named_buffers(recurse=False):
+        self_dict["_parameters"][name] = buffer
+
+    for name, submodule in module.named_children():
+        self_dict["_modules"][name] = module_to_nested_dict(submodule)
+
+    return self_dict
+
+
+def track_dynamism_across_examples(
+    example_inputs: list[Any],
+) -> dict[Any, Any]:
+    """
+    This function analyzes a list of example inputs to determine the dynamism of their shapes.
+    It tracks whether the dimensions of tensors or non-tensor values change across
+    different examples. The function returns a dictionary where each key represents
+    a path to a value in the input examples, and the corresponding value is a tuple
+    indicating which dimensions are dynamic (i.e., change across examples). This
+    helps in understanding how the structure of data varies across different instances.
+    """
+    tracking: dict[KeyPath, tuple[list[set[Any]], bool]] = {}
+
+    for ex in example_inputs:
+        if "self" in ex and isinstance(ex["self"], torch.nn.Module):
+            ex["self"] = module_to_nested_dict(ex["self"])
+        leaves_with_paths, _ = tree_flatten_with_path(ex)
+        for key_path, value in leaves_with_paths:
+            if not isinstance(value, (int, float, torch.Tensor)):
+                continue
+            if isinstance(value, torch.Tensor):
+                shape: tuple[int | float, ...] = tuple(value.shape)
+                is_tensor = True
+            else:
+                shape = (value,)
+                is_tensor = False
+            if key_path not in tracking:
+                tracking[key_path] = ([set() for _ in range(len(shape))], is_tensor)
+            else:
+                dim_sets, flag = tracking[key_path]
+                if flag != is_tensor:
+                    pass
+                while len(dim_sets) < len(shape):
+                    dim_sets.append(set())
+            for i, dim in enumerate(shape):
+                tracking[key_path][0][i].add(dim)
+
+    output: dict[Any, Any] = {}
+    for key_path, (dim_sets, _is_tensor) in tracking.items():
+        final_dyn = tuple(len(s) > 1 for s in dim_sets)
+        key_str = "L" + "".join(f"{str(k)}" for k in key_path)
+        key = key_path[0].key  # type: ignore[attr-defined]
+        if key not in output:
+            output[key] = {}
+        output[key][key_str] = final_dyn
+    return output
+
+
+def clone_and_convert_to_meta(example_input: Any) -> Any:
+    """
+    This function takes a list of example inputs and for each tensor, clones it and converts it to device=meta.
+    For non-tensor values, it keeps the reference. It uses pytree to handle nested structures recursively.
+    """
+
+    def transform_fn(value: Any) -> Any:
+        if isinstance(value, torch.Tensor):
+            return value.clone().to(device="meta")
+        return value
+
+    return tree_map(transform_fn, example_input)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cfd41b039e9ec6f9c456fd0240b18902756dc55
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py
@@ -0,0 +1,1085 @@
+# mypy: allow-untyped-defs
+import operator
+from collections import deque
+from typing import NamedTuple
+
+import torch
+from torch.fx.experimental.partitioner_utils import (
+    Device,
+    get_extra_size_of,
+    get_latency_of_partitioned_graph,
+    get_partition_to_latency_mapping,
+    NodeLatency,
+    Partition,
+    PartitionerConfig,
+    PartitionMode,
+)
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import map_arg, Node
+from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
+from torch.fx.passes.split_module import split_module
+
+
+class DAGNode:
+    """DAGNode class maintains useful information for a partition (submodule),
+    and its input submodules and output submodules.
+    """
+
+    def __init__(
+        self,
+        submodule_node: Node,
+        input_nodes: list[Node],
+        output_nodes: list[Node],
+        logical_device_ids: list[int],
+        size_bytes: int,
+    ) -> None:
+        self.submodule_node: Node = submodule_node
+        self.input_nodes: list[Node] = input_nodes
+        self.output_nodes: list[Node] = output_nodes
+        self.logical_device_ids: list[int] = logical_device_ids
+        self.size_bytes = size_bytes
+
+    def __str__(self) -> str:
+        return str(self.submodule_node)
+
+
+class DAG:
+    """DAG class contains all the DAG nodes"""
+
+    def __init__(self) -> None:
+        self.nodes: list[DAGNode] = []
+
+    def create_node(
+        self,
+        submodule_node: Node,
+        input_nodes: list[Node],
+        output_nodes: list[Node],
+        logical_devices: list[int],
+        size_bytes: int,
+    ) -> None:
+        node = DAGNode(
+            submodule_node, input_nodes, output_nodes, logical_devices, size_bytes
+        )
+        self.nodes.append(node)
+
+
+class PartitionResult(NamedTuple):
+    """NameTuple used for returning DAG and a new fx module"""
+
+    dag: DAG
+    module_with_submodules: GraphModule
+
+
+"""Followings are some helper functions for partition manipulation"""
+
+
+def reset_partition_device(partitions):
+    for partition in partitions:
+        partition.logical_device_ids = []
+
+
+def combine_two_partitions(
+    partition_0: Partition, partition_1: Partition, partitions: list[Partition]
+) -> None:
+    """Given a list of partitions and its two partitions,
+    combine these two partitions into a new one appending to the partitions
+    and remove the previous two partitions from the list of partitions
+    """
+    partition = Partition(len(partitions))
+    partition.nodes = partition_0.nodes.union(partition_1.nodes)
+    partition.recalculate_mem_size()
+    partitions.append(partition)
+    partitions.remove(partition_0)
+    partitions.remove(partition_1)
+    reorganize_partitions(partitions)
+    return
+
+
+def set_parents_and_children(partitions: list[Partition]) -> None:
+    """Given a list of partitions, mark parents and children for each partition"""
+    # Go through all nodes in a partition.
+    # If a node's user is in other partition,
+    # then the other partition is this partition's children.
+    # This partition is the other partition's parent
+    for partition in partitions:
+        partition.children = set()
+        partition.parents = set()
+    for partition in partitions:
+        for node in partition.nodes:
+            # For each node in the current partition, find its users
+            users = node.users
+            for n in users:
+                # Find which the partition the user node belongs to.
+                # Note that if the node itself is also belongs to that partition,
+                # that partition is not the child of the current partition
+                for p in partitions:
+                    if p != partition and n in p.nodes and node not in p.nodes:
+                        partition.children.add(p)
+                        p.parents.add(partition)
+    return
+
+
+def reorganize_partitions(partitions: list[Partition]) -> None:
+    """Given a list of partitions, reorganize partition id,
+    its parents and its children for each partition
+    """
+    # Rearrange partition ids
+    for i, partition in enumerate(partitions):
+        partition.partition_id = i
+    set_parents_and_children(partitions)
+    return
+
+
+def get_bfs_level_partition(partitions: list[Partition]) -> None:
+    """Given a list of partitions,
+    mark the bfs level for each partition
+    """
+    current_level: set[Partition] = set()
+    visited: set[Partition] = set()
+    for partition in partitions:
+        # If a partition has no parent, it should be in root level
+        if len(partition.parents) == 0:
+            current_level.add(partition)
+    next_level: set[Partition] = set()
+    level = 0
+    # bfs
+    while current_level:
+        partition = current_level.pop()
+        partition.bfs_level = level
+        visited.add(partition)
+        children = partition.children
+        for child in children:
+            if child not in next_level:
+                next_level.add(child)
+        if not current_level:
+            current_level = next_level.copy()
+            next_level = set()
+            level += 1
+    return
+
+
+def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]:
+    """Given a list of partitions,return node to partition mapping"""
+    node_to_partition: dict[Node, int] = {}
+    for partition in partitions:
+        for node in partition.nodes:
+            node_to_partition[node] = partition.partition_id
+    return node_to_partition
+
+
+def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]:
+    """Get a mapping from device logical ID to Device object."""
+    logical_id_to_device: dict[int, Device] = {}
+    for d in devices:
+        logical_id_to_device[d.logical_id] = d
+    return logical_id_to_device
+
+
+def get_device_partition_stats(
+    partitions: list[Partition], devices: list[Device]
+) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]:
+    """Given a list of partitions and a list of devices, returns:
+    1. A mapping from device to partitions on it;
+    2. A mapping from device to its remaining memory size;
+    3. A list of partitions that do not have a device.
+    """
+    # logical id to device
+    logical_id_to_device = get_logical_id_to_device(devices)
+    # Track partitions on device
+    device_to_partitions: dict[Device, list[Partition]] = {}
+    # Track device's left mem size
+    device_to_left_mem_bytes: dict[Device, int] = {}
+    for d in devices:
+        device_to_partitions[d] = []
+        device_to_left_mem_bytes[d] = d.available_mem_bytes
+
+    # Deal with the partitions that already have a device
+    # and also collect all partitions without a device (no_device_partitions)
+    no_device_partitions = []
+    for partition in partitions:
+        if partition.logical_device_ids != []:
+            for logical_id in partition.logical_device_ids:
+                device = logical_id_to_device[logical_id]
+                device_to_partitions[device].append(partition)
+                device_to_left_mem_bytes[device] -= partition.used_mem_bytes
+        else:
+            no_device_partitions.append(partition)
+
+    return (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    )
+
+
+def get_device_to_partitions_mapping(
+    partitions: list[Partition], devices: list[Device]
+):
+    """Given a list of partitions and a list of devices,
+    map each partition into a device.
+    """
+
+    def calculate_extra_mem_bytes_needed_for(
+        partition: Partition, partitions: list[Partition]
+    ):
+        all_nodes: set[Node] = set()
+        for p in partitions:
+            all_nodes = all_nodes.union(p.nodes)
+        if len(all_nodes) == 0:
+            return partition.used_mem_bytes
+        all_nodes = all_nodes.union(partition.nodes)
+        extra_size_needed = 0
+        for node in partition.nodes:
+            extra_size_needed += get_extra_size_of(node, all_nodes)
+        return extra_size_needed
+
+    def find_device_for(partition: Partition):
+        """Given a partition, find a logical device for the partition
+        The algorithm is to put the partition on the device
+        that has just enough mem left for that partition.
+        device_to_left_mem_bytes is a dictionary between device and its left mem size
+        sorted by its left mem size
+        """
+        for d in device_to_left_mem_bytes:
+            extra_size_needed = calculate_extra_mem_bytes_needed_for(
+                partition, device_to_partitions[d]
+            )
+            if extra_size_needed < device_to_left_mem_bytes[d]:
+                device_to_partitions[d].append(partition)
+                partition.logical_device_ids.append(d.logical_id)
+                device_to_left_mem_bytes[d] -= extra_size_needed
+                return True
+        return False
+
+    (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    ) = get_device_partition_stats(partitions, devices)
+
+    # Find devices for all the partitions without a device
+    found_device = True
+    for partition in no_device_partitions:
+        device_to_left_mem_bytes = dict(
+            sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))
+        )
+        found_device = find_device_for(partition)
+        if not found_device:
+            break
+    return found_device
+
+
+def check_dependency(partition):
+    """Given a partition,check if there is a circular dependency on
+    this partition using bfs
+    """
+    visited: set[Partition] = {partition}
+    queue: deque[Partition] = deque([partition])
+    while queue:
+        p = queue.popleft()
+        for child in p.children:
+            if child == partition:
+                return True
+            else:
+                if child not in visited:
+                    visited.add(child)
+                    queue.append(child)
+    return False
+
+
+class Partitioner:
+    """A fx module may not fit into one device.
+    Partitioner class helps partition one fx module into submodules (partitions),
+    so that the submodules can be executed crossing different accelerators.
+    The main function of this class is self.partition_graph.
+    It partitions the fx module based on the scheme specified in partition_config
+    A DAG structure is returned
+    along with a new fx module with submodule nodes.
+    """
+
+    def __init__(self) -> None:
+        self.partitions: list[Partition] = []
+        self.node_to_partition: dict[Node, int] = {}
+        self.devices: list[Device] = []
+
+    def partition_graph(
+        self,
+        fx_module: GraphModule,
+        torch_module: torch.nn.Module,
+        partitioner_config: PartitionerConfig,
+    ) -> PartitionResult:
+        """Given the fx module, torch module and partitioner_config,
+        find the partitions, do the partitions,
+        and then return a DAG and a new fx module with submodule nodes (partitions)
+        """
+        self.graph_module = fx_module
+        self.torch_module = torch_module
+        self.devices = partitioner_config.devices
+        if len(self.devices) == 0:
+            raise RuntimeError("No devices")
+        # Tag the size in bytes to all nodes in the graph_module.
+        get_size_of_all_nodes(self.graph_module)
+        # Check if there are op nodes in the fx module
+        nodes = self.graph_module.graph.nodes
+        if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes):
+            raise RuntimeError("No Partition since no operations in the module")
+        # Calculate total size of the fx module
+        total_size_of_graph = 0
+        for node in nodes:
+            if node.op == "output":
+                break
+            total_size_of_graph += node.size_bytes.total_size
+        # Find the device with the max mem size
+        device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes)
+        # AOT based partition
+        if partitioner_config.mode == PartitionMode.aot_based:
+            self.aot_based_partition(
+                partitioner_config.node_to_partition_mapping,
+                partitioner_config.partition_to_logical_device_mapping,
+            )
+        # Single partition if the whole module can be fit into one device
+        elif total_size_of_graph <= device_with_max_mem.available_mem_bytes:
+            self.find_single_partition(
+                total_size_of_graph, logical_device_id=device_with_max_mem.logical_id
+            )
+        elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices):
+            raise RuntimeError("Devices have no enough memory for the module")
+        else:
+            # Sparse nn based partition
+            if partitioner_config.mode == PartitionMode.sparse_nn:
+                available_mem_bytes = self.devices[0].available_mem_bytes
+                if not all(
+                    device.available_mem_bytes == available_mem_bytes
+                    for device in self.devices
+                ):
+                    raise RuntimeError("All devices must have same memory size!")
+                # sparse_nn_partition only support same memory size
+                # TODO: add different size support for sparse_nn_partition
+                self.sparse_nn_partition(available_mem_bytes)
+            # Cost aware partition
+            elif partitioner_config.mode == PartitionMode.cost_aware:
+                self.cost_aware_partition(
+                    partitioner_config.transfer_rate_bytes_per_sec,
+                    partitioner_config.node_to_latency_mapping,
+                )
+            # KL based partition
+            elif partitioner_config.mode == PartitionMode.kl_based:
+                self.kl_based_partition(
+                    partitioner_config.transfer_rate_bytes_per_sec,
+                    partitioner_config.node_to_latency_mapping,
+                )
+            else:
+                self.size_based_partition()
+
+        # Saturate host if possible.
+        if partitioner_config.saturate_host:
+            self.saturate_host()
+
+        # Partition the graph module based on the partition assignment.
+        module_with_submodules = self.do_partition()
+
+        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
+        # and how partitions are connected.
+        dag = self.dump_dag(module_with_submodules)
+        ret = PartitionResult(dag, module_with_submodules)
+        return ret
+
+    def find_single_partition(
+        self, total_size_of_graph, logical_device_id: int = 0
+    ) -> None:
+        """Fit the whole fx module into one device"""
+        partition_0 = self.create_partition()
+        for node in self.graph_module.graph.nodes:
+            if node.op == "output":
+                # Skip the output node, but there can
+                # be nodes after the output in certain cases.
+                continue
+            partition_0.nodes.add(node)
+        partition_0.used_mem_bytes = total_size_of_graph
+        partition_0.logical_device_ids = [logical_device_id]
+        # Get the node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        return
+
+    def size_based_partition(self) -> None:
+        """This method is to partition the fx module based on memory size.
+        It uses greedy approach. The result may not be the best.
+        The basic idea is:
+        Step 1:
+        Find a device which has enough memory to fit the current node, create a empty partition
+        with the size of that device.
+        Then keep adding the following nodes into the partition until the partition is full.
+        Step 2:
+        Repeat Step 1 until no device left
+        Step 3:
+        If some nodes are left, create a partition for each left node (single node partition).
+        and then try to map those partitions into logical devices with enough mem left.
+        """
+
+        def find_device_based_on_size(node) -> Device:
+            """Given a node, this function is to find a logical device
+            that could fit the node.
+            """
+            mem_size_needed = get_extra_size_of(node, set())
+            device = Device("", -1, -1)
+            for d in self.devices:
+                if (
+                    d not in occupied_devices
+                    and d.available_mem_bytes >= mem_size_needed
+                ):
+                    device = d
+                    break
+            if device.available_mem_bytes < 0:
+                raise RuntimeError(str(node) + "is too large to fit any device")
+            occupied_devices.append(device)
+            return device
+
+        # Track partition and its left mem size
+        partition_to_left_mem_bytes: dict[Partition, int] = {}
+        # Track all the devices that have been used
+        occupied_devices: list[Device] = []
+        partition = self.create_partition()
+        for node in self.graph_module.graph.nodes:
+            if node.op in {"call_module", "call_method", "call_function"}:
+                # Check if there are devices left
+                if len(self.partitions) <= len(self.devices):
+                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
+                    # Check if the current partition is the very first partition
+                    if partition.used_mem_bytes == 0:
+                        # Find a device to fit the first node, return available mem size
+                        device = find_device_based_on_size(node)
+                        occupied_devices.append(device)
+                        # Update partition and its left mem size
+                        partition_to_left_mem_bytes[partition] = (
+                            device.available_mem_bytes
+                        )
+                        # Update available mem for the current partition
+                        partition.logical_device_ids.append(device.logical_id)
+                    else:
+                        # The current partition is not the first partition
+                        # Check if the current node can fit into current partition
+                        if (
+                            partition_to_left_mem_bytes[partition]
+                            < total_size_of_input_nodes
+                        ):
+                            # Check if no device is left
+                            if len(self.partitions) == len(self.devices):
+                                # No device is left
+                                # Create the first single node partition for the current node
+                                self.create_single_node_partition(node)
+                                continue
+                            # Some devices are still left
+                            # Create a new partition with a mem size that is enough for the current node
+                            device = find_device_based_on_size(node)
+                            partition = self.create_partition()
+                            total_size_of_input_nodes = get_extra_size_of(
+                                node, partition.nodes
+                            )
+                            partition_to_left_mem_bytes[partition] = (
+                                device.available_mem_bytes
+                            )
+                            partition.logical_device_ids.append(device.logical_id)
+                    partition.add_node(node)
+                    partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
+                # Create single node partitions if no device is left
+                else:
+                    self.create_single_node_partition(node)
+        reorganize_partitions(self.partitions)
+        # Get the node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        # Mapping all partitions into device
+        found_partition_to_device_mapping = get_device_to_partitions_mapping(
+            self.partitions, self.devices
+        )
+        if not found_partition_to_device_mapping:
+            raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")
+        return
+
+    def saturate_host(self) -> None:
+        """Saturate host by assigning replicates to unused devices with enough memory.
+        It uses a greedy approach to find a next available set of devices to place all split
+        partitions: For each used device, it searches for an idle device with minimal memory
+        size that can hold all the partition located on that device; If the search is successful
+        for all used devices, it then assigns the new devices' logical ID to the corresponding
+        partition.
+        """
+        (
+            device_to_partitions,
+            device_to_left_mem_bytes,
+            no_device_partitions,
+        ) = get_device_partition_stats(self.partitions, self.devices)
+
+        assert len(no_device_partitions) == 0, (
+            f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
+        )
+
+        # Devices that hold partitions
+        used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
+        # Track replicates of the assigned devices
+        replicated_device_to_used_device: dict[Device, Device] = {}
+
+        while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
+            self.devices
+        ):
+            # Success flag for this round
+            success = True
+            # Devices that have not been assigned
+            idle_devices = [
+                d
+                for d in self.devices
+                if d not in used_devices and d not in replicated_device_to_used_device
+            ]
+            # Temporary mapping from replicated device to original device
+            temp_replicate_mapping = {}
+
+            # Find a new device to replicate all partitions on an used device
+            for used_device in used_devices:
+                # Idle devices that have enough memory
+                available_devices = [
+                    d
+                    for d in idle_devices
+                    if d.available_mem_bytes
+                    >= used_device.available_mem_bytes
+                    - device_to_left_mem_bytes[used_device]
+                ]
+                if len(available_devices) == 0:
+                    success = False
+                    break
+                new_device = min(available_devices, key=lambda d: d.available_mem_bytes)
+                idle_devices.remove(new_device)
+                temp_replicate_mapping[new_device] = used_device
+
+            if not success:
+                break
+            replicated_device_to_used_device.update(temp_replicate_mapping)
+
+        # Update logical device IDs assigned to the partitions
+        for (
+            replicate_device,
+            original_device,
+        ) in replicated_device_to_used_device.items():
+            logical_id = replicate_device.logical_id
+            for partition in device_to_partitions[original_device]:
+                partition.logical_device_ids.append(logical_id)
+        for p in self.partitions:
+            print(p.logical_device_ids)
+
+    def do_partition(self) -> GraphModule:
+        """Return a new fx module with submodule nodes (partitions)."""
+        module_with_submodules = split_module(
+            self.graph_module,
+            self.torch_module,
+            lambda node: self.node_to_partition[node],
+        )
+        return module_with_submodules
+
+    def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
+        """Return the dag structure and the new fx module with submodules."""
+        dag = DAG()
+        for node in module_with_submodules.graph.nodes:
+            if node.op == "output":
+                break
+            if node.op in {"placeholder", "get_attr"}:
+                continue
+            if node.target is operator.__getitem__:
+                continue
+            input_nodes: dict[Node, None] = {}
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
+            # When a node has two or more output nodes,
+            # it outputs its result to 'getitem' nodes.
+            # Those 'getitem' nodes are the output node for this node.
+            # Otherwise, the output node is this node itself.
+            if len(node.users) > 1:
+                output_nodes = list(node.users)
+            else:
+                output_nodes = [node]
+            partition_id = int(node.name.rsplit("_", 1)[-1])
+            device_ids = self.partitions[partition_id].logical_device_ids
+            size_bytes = self.partitions[partition_id].used_mem_bytes
+            dag.create_node(
+                node, list(input_nodes), output_nodes, device_ids, size_bytes
+            )
+        return dag
+
+    def create_partition(self) -> Partition:
+        """Create a partition and append it to self.partitions."""
+        partition_id = len(self.partitions)
+        partition = Partition(partition_id)
+        self.partitions.append(partition)
+        return partition
+
+    def create_single_node_partition(self, node):
+        """Create a partition for a single node"""
+        partition = self.create_partition()
+        partition.add_node(node)
+        return
+
+    def sparse_nn_partition(self, available_mem_bytes: int) -> None:
+        """This method partition a sparse nn module.
+        It is size based partition but different from size_based_partition,
+        it only works when all the devices have same memory size (available_mem_bytes).
+        In the future, devices with different mem sizes will be supported like size_based_partition.
+        It first traverse all the nodes and do the partitions based on the same memory size.
+        If the current partition has no enough memory left for a new op node
+        (call_module, call_method, call_function), a new partition is created.
+        When crossing the boundary between non-embedding nodes and embedding nodes,
+        a new partition is created regardlessly.
+        For example, if the current node is a non-embedding node but the next node is an
+        embedding node, a new partition is created for the next node.
+        After the partition, the partitions are combined as much as possible.
+        The rule is that a non-embedding partition only
+        combines with another non-embedding one.
+        So as the embedding partitions.
+        """
+
+        def combine_partitions_based_on_size(
+            partitions: list[Partition], available_mem_bytes: int
+        ) -> None:
+            """Combining small partitions together to keep as less partitions as possible.
+            Here is an example of the algorithm to do this:
+            Assume some partitions, we first sort them based on partition used memory size.
+            [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)]
+            The available memory is 10.
+            step 1: self.find_partition_to_combine_based_on_size()
+            First, mark bfs level for each partition
+            Second, look the smallest partition, partition_4: 10 - 1 = 9
+            It means any partition has a used memory equal or less than 9 could combine this partition
+            We go from the largest and selection partition_0.
+            Check the bfs level for two partitions, if the level difference is less than 2,
+            it can be combined.
+            step 2: repeat step 1 until no partitions can be combined
+            """
+            find_combination = True
+            while find_combination:
+                # Sort partitions based on memory size
+                sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes)
+                # Mark bfs level
+                get_bfs_level_partition(self.partitions)
+                find_combination, partitions = find_partition_to_combine_based_on_size(
+                    sorted_partitions,
+                    available_mem_bytes,
+                    # pyrefly: ignore [bad-argument-type]
+                    partitions,
+                )
+            return
+
+        def calculate_mem_bytes_needed(p1, p2):
+            """Given two partitions, calculate how many mem bytes
+            are needed if two partitions are combined
+            """
+            nodes = p1.nodes.union(p2.nodes)
+            mem_bytes_needed = 0
+            for node in nodes:
+                mem_bytes_needed += get_extra_size_of(node, nodes)
+            return mem_bytes_needed
+
+        def find_partition_to_combine_based_on_size(
+            sorted_partitions: list[Partition],
+            available_mem_bytes: int,
+            partitions: list[Partition],
+        ) -> tuple[bool, list[Partition]]:
+            """step 1 in combine_partition_based_on_size()"""
+            find_combination = False
+            smallest_partition = sorted_partitions.pop(0)
+            for p in sorted_partitions[::-1]:
+                if abs(smallest_partition.bfs_level - p.bfs_level) <= 1:
+                    # Calculate how many bytes needed if combined
+                    mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition)
+                    if mem_bytes_needed <= available_mem_bytes:
+                        combine_two_partitions(p, smallest_partition, self.partitions)
+                        partitions.remove(smallest_partition)
+                        partitions.remove(p)
+                        partitions.append(self.partitions[-1])
+                        find_combination = True
+                        break
+            return find_combination, partitions
+
+        def reset_partition_in_sparse_nn(partition, new_partition=True):
+            """If crossing the boundary between non-embedding nodes and
+            embedding nodes, create a new partition
+            """
+            if in_embedding_region:
+                embedding_partitions.append(partition)
+            else:
+                non_embedding_partitions.append(partition)
+            if new_partition:
+                partition = self.create_partition()
+                # pyrefly: ignore [missing-attribute]
+                partition.left_mem_bytes = available_mem_bytes
+                return partition
+            return None
+
+        def is_embedding_node(node: Node) -> bool:
+            """Check if a node is an embedding node"""
+            if node.op == "call_module":
+                submodule = self.graph_module
+                for atom in str(node.target).split("."):
+                    if not hasattr(submodule, atom):
+                        raise RuntimeError(
+                            f"Module {submodule} has no attribute {atom}"
+                        )
+                    submodule = getattr(submodule, atom)
+                    if "Embedding" in str(submodule):
+                        return True
+            return False
+
+        # Track embedding partitions and non-embedding partitions separately
+        embedding_partitions: list[Partition] = []
+        non_embedding_partitions: list[Partition] = []
+        # A Flag to check the boundary
+        in_embedding_region: bool = False
+        partition = self.create_partition()
+        for node in self.graph_module.graph.nodes:
+            if node.op in {"call_module", "call_method", "call_function"}:
+                # Check if crossing the boundary between embedding nodes and non embedding nodes
+                if is_embedding_node(node) != in_embedding_region:
+                    # Crossing the boundary
+                    # Check if the current partition is an empty partition
+                    if partition.used_mem_bytes != 0:
+                        # The current partition isn't an empty partition. Create a new one.
+                        partition = reset_partition_in_sparse_nn(partition)
+                    in_embedding_region = not in_embedding_region
+                total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
+                if (
+                    total_size_of_input_nodes + partition.used_mem_bytes
+                    > available_mem_bytes
+                ):
+                    partition = reset_partition_in_sparse_nn(partition)
+                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
+                    if total_size_of_input_nodes > available_mem_bytes:
+                        raise RuntimeError(
+                            node.target + "is too large to fit into a device"
+                        )
+                partition.add_node(node)
+        reset_partition_in_sparse_nn(partition, new_partition=False)
+        # Set parents and children for partitions
+        set_parents_and_children(self.partitions)
+        # Combining non-embedding partitions
+        combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes)
+        # Combining embedding partitions
+        combine_partitions_based_on_size(embedding_partitions, available_mem_bytes)
+        total_size_of_non_embedding_partitions = 0
+        for partition in non_embedding_partitions:
+            total_size_of_non_embedding_partitions += partition.used_mem_bytes
+        # Check if devices are enough for all partitions
+        if len(embedding_partitions) > len(self.devices):
+            msg = (
+                "Need "
+                + str(len(embedding_partitions))
+                + " devices, but only "
+                + str(len(self.devices))
+                + " provided"
+            )
+            raise RuntimeError(msg)
+        occupied_devices = []
+        for i, partition in enumerate(embedding_partitions):
+            # Check if all non-embedding partitions can fit into embedding partition devices
+            if (
+                total_size_of_non_embedding_partitions + partition.used_mem_bytes
+                > available_mem_bytes
+            ):
+                raise RuntimeError(
+                    "partition_"
+                    + str(partition.partition_id)
+                    + "(embedding partition) and non embedding partitions can not fit into one device"
+                )
+            else:
+                # Add logical device to the partition
+                partition.logical_device_ids = [self.devices[i].logical_id]
+                occupied_devices.append(self.devices[i].logical_id)
+        # Add logical devices to the non_embedding_partitions
+        for partition in non_embedding_partitions:
+            partition.logical_device_ids = occupied_devices
+        # Get the node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        return
+
+    def cost_aware_partition(
+        self,
+        transfer_rate_bytes_per_sec: float,
+        node_to_latency_mapping: dict[Node, NodeLatency],
+    ) -> None:
+        """This method is to partition the fx module based on the cost.
+        The cost is the total latency of running the whole fx module.
+        In partitioner_utils.py, the cost model is built.
+        The cost aware partition algorithm is:
+        #1. At every beginning, each node is a partition.
+            Then we map all the partitions to the devices
+            and calculate the cost
+        #2. Then try to pre-combine any two of the partitions if the two
+            partitions can be combined.
+            (the bfs level is less than 2 or two partitions are connected and
+            can find partition to device mapping)
+            See if any partition pair could reduce the current cost.
+            Choose the pair that shows the minimum cost and then combine them
+        #3. Repeat #2 until the cost cannot be reduced.
+        """
+
+        def try_combining_partitions(p0_index, p1_index, partitions) -> float:
+            """Given two partitions and a list of partitions, combine these two partitions
+            and see what is the cost of the modified partition list
+            """
+            p0 = partitions[p0_index]
+            p1 = partitions[p1_index]
+            """If two partitions' bfs level are less than 2 or two partitions are connected to each other,
+               then they can be combined
+            """
+            if (
+                (abs(p0.bfs_level - p1.bfs_level) <= 1)
+                or (p0 in p1.parents)
+                or p0 in (p1.children)
+            ):
+                combine_two_partitions(p0, p1, partitions)
+                # Check if a circular dependency exists after combining
+                if check_dependency(partitions[-1]):
+                    return float("inf")
+                # Check if the modified partition list can be mapped to devices after combination
+                reset_partition_device(partitions)
+                found_deivce = get_device_to_partitions_mapping(
+                    partitions, self.devices
+                )
+                if not found_deivce:
+                    return float("inf")
+                # Calculate the new cost
+                partition_to_latency_mapping = get_partition_to_latency_mapping(
+                    partitions, node_to_latency_mapping
+                )
+                cost = get_latency_of_partitioned_graph(
+                    partitions,
+                    partition_to_latency_mapping,
+                    transfer_rate_bytes_per_sec,
+                )
+                return cost
+            # If two partition can not be combined, the cost is inf
+            return float("inf")
+
+        def search_combination(
+            transfer_rate_bytes_per_sec, node_to_latency_mapping
+        ) -> bool:
+            """Given transfer rate between partitions and each node's latency,
+            find two partitions to combine so the cost of the partitions can
+            be reduced.
+            The algorithm is :
+            1. Go through all the partition pairs and see
+            if any pair of partitions can be combined.
+            2. Calculate the cost after the combination.
+            3. Select the minimum cost and combine its corresponding partition pair.
+            """
+            partition_to_latency_mapping = get_partition_to_latency_mapping(
+                self.partitions, node_to_latency_mapping
+            )
+            cost = get_latency_of_partitioned_graph(
+                self.partitions,
+                partition_to_latency_mapping,
+                transfer_rate_bytes_per_sec,
+            )
+            if len(self.partitions) == 1:
+                return False
+            partition_pair: list[int] = []
+            for i in range(len(self.partitions) - 1):
+                for j in range(i + 1, len(self.partitions)):
+                    # Try to combine the partition pair
+                    # and see the new cost after combination
+                    new_cost = try_combining_partitions(i, j, self.partitions[:])
+                    if new_cost <= cost:
+                        partition_pair = [i, j]
+                        cost = new_cost
+                    reorganize_partitions(self.partitions)
+            # If a partition pair is found, combine them
+            if len(partition_pair) != 0:
+                p0 = self.partitions[partition_pair[0]]
+                p1 = self.partitions[partition_pair[1]]
+                combine_two_partitions(p0, p1, self.partitions)
+            get_bfs_level_partition(self.partitions)
+            reset_partition_device(self.partitions)
+            get_device_to_partitions_mapping(self.partitions, self.devices)
+            return len(partition_pair) != 0
+
+        for node in self.graph_module.graph.nodes:
+            if node.op not in {"placeholder", "get_attr", "output"}:
+                self.create_single_node_partition(node)
+        # Set up parent partitions and children partitions for each partition
+        set_parents_and_children(self.partitions)
+        # Get bfs level for each partition
+        get_bfs_level_partition(self.partitions)
+        find_combination = True
+        while find_combination:
+            # Search for a pair partition to generate the minimum new cost,
+            # then combine them
+            find_combination = search_combination(
+                transfer_rate_bytes_per_sec, node_to_latency_mapping
+            )
+        # Make sure all partitions are set up correctly
+        reorganize_partitions(self.partitions)
+        # Set up node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        return
+
+    def kl_based_partition(
+        self,
+        transfer_rate_bytes_per_sec: float,
+        node_to_latency_mapping: dict[Node, NodeLatency],
+    ) -> None:
+        """This function is a cost aware partition based
+        on Kernighan-Lin algorithm.
+        First, the graph is partitioned using size_based_partition.
+        Then, each node is swapped with any other node in a different
+        partition, and at the same time, the cost is estimated after
+        the swapping.
+        For example, we have nodes n0, n1, n2, n3 and n4.
+        Using size_based_partition, n0 and n1 are in Partition p0.
+        n2, n3 and n4 in Partition p1. The current cost is estimated.
+        We first tried using n0 to swap with n2 from the other partition.
+        Then we see that swapping n0 and n2 shows a lower cost
+        than the current cost and it is the minimum among other pairs like
+        (n0, None)(This means moving n0 to Partition without swapping other nodes),
+        (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost
+        as the current cost.
+        Then We repeat this process for all the other nodes until all swapping pairs
+        are tried.
+        """
+
+        def swap_nodes(n0, n1, p0, p1):
+            # Either n0 or n1 could be None
+            # That means we simply move the node
+            # to another partition
+            if n0 is not None:
+                p0.remove_node(n0)
+                p1.add_node(n0)
+            if n1 is not None:
+                p0.add_node(n1)
+                p1.remove_node(n1)
+
+        def try_swap_nodes(
+            n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
+        ):
+            cost = float("inf")
+            swap_nodes(n0, n1, p0, p1)
+            # Reorganize partitions after swapping
+            reorganize_partitions(self.partitions)
+            # Check if there is a circular dependency after swapping
+            if (not check_dependency(p0)) and (not check_dependency(p1)):
+                reset_partition_device(self.partitions)
+                partition_to_latency_mapping = get_partition_to_latency_mapping(
+                    self.partitions, node_to_latency_mapping
+                )
+                # Check if all partitions can be mapped to logical devices after swapping
+                found_device = get_device_to_partitions_mapping(
+                    self.partitions, self.devices
+                )
+                if not found_device:
+                    cost = float("inf")
+                else:
+                    cost = get_latency_of_partitioned_graph(
+                        self.partitions,
+                        partition_to_latency_mapping,
+                        transfer_rate_bytes_per_sec,
+                    )
+            # Swap back and reset all partitions back to original
+            swap_nodes(n1, n0, p0, p1)
+            reorganize_partitions(self.partitions)
+            reset_partition_device(self.partitions)
+            get_device_to_partitions_mapping(self.partitions, self.devices)
+            return cost
+
+        def swap_node_to_partition(
+            node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
+        ):
+            """This function helps to swap one node from partition p0
+            with all the nodes in another partition p1
+            """
+            p1_nodes = list(p1.nodes) + [None]
+            min_cost = float("inf")
+            node_pair: list[Node] = []
+            for n1 in p1_nodes:
+                # Ignore the node if it is not a op node
+                if n1 is not None and n1.op in {"placeholder", "get_attr"}:
+                    continue
+                # Try swapping node in p0 with n1 in p1
+                cost = try_swap_nodes(
+                    node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
+                )
+                if cost < min_cost:
+                    # pyrefly: ignore [bad-assignment]
+                    node_pair = [node, n1]
+                    min_cost = cost
+            return cost, node_pair  # type: ignore[possibly-undefined]
+
+        # First use size_base_partition
+        self.size_based_partition()
+        partition_to_latency_mapping = get_partition_to_latency_mapping(
+            self.partitions, node_to_latency_mapping
+        )
+        # Calculate the cost of the partitions
+        cost = get_latency_of_partitioned_graph(
+            self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
+        )
+        # Keep tracking the node pair that shows the better cost
+        node_pair: list[Node] = []
+        # Keep tracking the partition pair of node pair
+        partition_pair: list[Partition] = []
+        # Collect all the op nodes from the graph
+        op_nodes = [
+            n
+            for n in self.graph_module.graph.nodes
+            if n.op not in {"placeholder", "get_attr", "output"}
+        ]
+        for node in op_nodes:
+            # Find which partition the current node belongs
+            p0_index = self.node_to_partition[node]
+            p0 = self.partitions[p0_index]
+            # Go through all the other partitions to swap
+            # with other nodes from those partitions
+            for p1_index, _ in enumerate(self.partitions):
+                if p0_index != p1_index:
+                    p1 = self.partitions[p1_index]
+                    new_cost, new_node_pair = swap_node_to_partition(
+                        node,
+                        p0,
+                        p1,
+                        node_to_latency_mapping,
+                        transfer_rate_bytes_per_sec,
+                    )
+                    # Update the cost
+                    # Track the swapped node pair and their partitions
+                    if new_cost < cost:
+                        cost = new_cost
+                        node_pair = new_node_pair
+                        partition_pair = [p0, p1]
+            # Do the swapping after trying all the nodes from a partition
+            if len(node_pair) != 0:
+                swap_nodes(
+                    node_pair[0], node_pair[1], partition_pair[0], partition_pair[1]
+                )
+                reorganize_partitions(self.partitions)
+                get_device_to_partitions_mapping(self.partitions, self.devices)
+        reorganize_partitions(self.partitions)
+        # Mapping the device to the partition
+        get_device_to_partitions_mapping(self.partitions, self.devices)
+        return
+
+    def aot_based_partition(
+        self, node_to_partition_mapping, partition_to_logical_device_mapping
+    ):
+        """This function helps to rebuild the partitions given the nodes and its
+        corresponding partition id
+        """
+        partition_id_to_partition_mapping: dict[int, Partition] = {}
+        self.node_to_partition = node_to_partition_mapping
+        for node in self.node_to_partition:
+            partition_id = self.node_to_partition[node]
+            # If the requested partition has not been created, create the partition
+            if partition_id not in partition_id_to_partition_mapping:
+                partition = Partition(partition_id)
+                self.partitions.append(partition)
+                partition_id_to_partition_mapping[partition_id] = partition
+                partition.logical_device_ids = partition_to_logical_device_mapping[
+                    partition_id
+                ]
+            else:
+                partition = partition_id_to_partition_mapping[
+                    self.node_to_partition[node]
+                ]
+            # Add the current node into the partition
+            partition.add_node(node)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py
new file mode 100644
index 0000000000000000000000000000000000000000..f494f11593410467623b680a7587e50a614be5a7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py
@@ -0,0 +1,354 @@
+# mypy: allow-untyped-defs
+import re
+from collections.abc import Callable
+from typing import Optional, Union
+
+import torch.fx
+from torch.fx.node import map_arg
+from torch.fx.passes.split_module import split_module
+
+
+__all__ = [
+    "FoldedGraphModule",
+    "get_unique_attr_name_in_module",
+    "split_const_subgraphs",
+]
+
+
+class FoldedGraphModule(torch.fx.GraphModule):
+    """
+    FoldedGraphModule is a GraphModule which also contains another
+    `const_subgraph_module` representing a subgraph which has all const attr
+    inputs and which can be run once before running the main standard
+    `graph`. The `const_output_names` are the ordered list names of attrs which
+    represent what each respective output from the const_subgraph should be set
+    on which attrs.
+    """
+
+    def __init__(
+        self,
+        root: torch.nn.Module,
+        graph: torch.fx.Graph,
+        const_subgraph: Optional[torch.fx.Graph] = None,
+        fx_const_folded_attrs_name: Optional[str] = None,
+        device_for_folded_attrs: str = "cuda",
+    ):
+        super().__init__(root, graph)
+        self.const_subgraph_module = (
+            None
+            if const_subgraph is None
+            else torch.fx.GraphModule(root, const_subgraph)
+        )
+        self.has_folding_been_run = False
+        self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
+        self.device_for_folded_attrs = device_for_folded_attrs
+
+    def __call__(self, *args, **kwargs):
+        if not self.has_folding_been_run:
+            self.run_folding()
+        return super().__call__(*args)
+
+    def run_folding(self):
+        # If there's no const subgraph module or attr output names to use, return
+        # early as there is no const folding to perform.
+        if (
+            self.const_subgraph_module is None
+            or self.fx_const_folded_attrs_name is None
+        ):
+            return
+
+        assert not self.has_folding_been_run
+        self.has_folding_been_run = True
+
+        # Actually run const folding subgraph. Note that single attr const fold
+        # subgraphs output a single Tensor while multiple outputs are returned as
+        # Tuple[Tensor,].
+        folded_attrs = self.const_subgraph_module()
+
+        def _create_param(i):
+            return torch.nn.Parameter(
+                i.detach().clone()
+                if not isinstance(i, int)
+                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
+                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
+            )
+
+        params = (
+            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
+            if isinstance(folded_attrs, tuple)
+            else _create_param(folded_attrs)
+        )
+        setattr(self, self.fx_const_folded_attrs_name, params)
+
+
+def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
+    """
+    Given `gm` and some graph module which is called with target name `inline_mod_name`,
+    this helper will inline all of the nodes from that called graph module into `gm`.
+    """
+    # Fetch the inner graph module that we want to inline inside `gm`.
+    inline_mod = dict(gm.named_modules())[inline_mod_name]
+    assert isinstance(inline_mod, torch.fx.GraphModule)
+    call_mod_node_to_replace = None
+    for node in gm.graph.nodes:
+        if node.op == "call_module" and node.target == inline_mod_name:
+            call_mod_node_to_replace = node
+            break
+    assert call_mod_node_to_replace is not None
+
+    # Now actually do the swap. Note that we have to keep track of new nodes that are
+    # copied into `gm` -- we do this via replacement_mapping.
+    call_mod_args = call_mod_node_to_replace.args
+    call_mod_kwargs = call_mod_node_to_replace.kwargs
+
+    replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
+    ph_count = 0
+
+    def replacement_fn(node):
+        new_node = replacement_mapping[node]
+        new_node.meta = node.meta.copy()
+        return new_node
+
+    for inline_node in inline_mod.graph.nodes:
+        if inline_node.op == "placeholder":
+            replacement_mapping[inline_node] = (
+                call_mod_kwargs[inline_node.name]
+                if inline_node.name in call_mod_kwargs
+                else call_mod_args[ph_count]
+            )
+
+            ph_count += 1
+            continue
+
+        if inline_node.op == "output":
+            outputs = inline_node.args[0]
+            output_replacements = map_arg(outputs, replacement_fn)
+            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
+            continue
+
+        with gm.graph.inserting_before(call_mod_node_to_replace):
+            new_node = gm.graph.node_copy(inline_node, replacement_fn)
+        replacement_mapping[inline_node] = new_node
+
+    # Explicitly remove the module that was just inlined,
+    # this module may contain impure ops so cannot be dead code eliminated,
+    # this module is unneeded as it's just inlined back to main graph.
+    gm.graph.erase_node(call_mod_node_to_replace)
+    gm.graph.eliminate_dead_code()
+
+
+def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
+    """
+    Make sure the name is unique (in a module) and can represents an attr.
+    """
+    # Delete all characters that are illegal in a Python identifier.
+    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
+    if name[0].isdigit():
+        name = f"_{name}"
+    # Now make sure it is in fact unique to the module by incrementing suffix value.
+    while hasattr(mod_traced, name):
+        match = re.match(r"(.*)_(\d+)$", name)
+        if match is None:
+            name = name + "_1"
+        else:
+            base, num = match.group(1, 2)
+            name = f"{base}_{int(num) + 1}"
+
+    return name
+
+
+def split_const_subgraphs(
+    module: Union[torch.nn.Module, torch.fx.GraphModule],
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+    device_for_folded_attrs: str = "cpu",
+) -> FoldedGraphModule:
+    """
+    Looks through `module` for any nodes that have all constant attribute inputs
+    and separates them out into their own constant subgraph, and returns a
+    FoldedGraphModule which runs that constant subgraph on the first run to set
+    attributes on the module prior to running the non-constant portion of the
+    graph.
+    """
+
+    import sympy
+
+    if not isinstance(module, torch.fx.GraphModule):
+        mod_traced = torch.fx.symbolic_trace(module)
+    else:
+        mod_traced = module
+
+    def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
+        """
+        Return True if a GraphModule type subgraph contains any impure op, else False.
+        """
+        assert isinstance(module, torch.fx.GraphModule), (
+            "caller should only pass GraphModule to subgraph_has_impure_ops check"
+        )
+        for node in module.graph.nodes:
+            if node.op == "call_function" and node.is_impure():
+                return True
+            if (
+                # pyrefly: ignore [invalid-argument]
+                node.op == "call_module"
+                # pyrefly: ignore [not-callable]
+                and (submodule := module.get_submodule(node.target))
+                and isinstance(submodule, torch.fx.GraphModule)
+            ):
+                return _subgraph_has_impure_ops(submodule)
+        return False
+
+    # Build up a list of const_nodes, defined as nodes that are themselves
+    # get_attrs, or have all get_attr or other constant node inputs.
+    const_nodes: set[torch.fx.Node] = set()
+    found_const_folding = False
+    for node in mod_traced.graph.nodes:
+        # Skip over placeholders/outputs because they can't be const folded and
+        # we don't want to add tags to them.
+        if node.op in {"placeholder", "output"}:
+            continue
+
+        # If the node itself is constant, or all of its inputs are constant,
+        # then tag it as constant.
+        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
+            const_nodes
+        ):
+            continue
+
+        # If provided skip folding function says to skip, then skip.
+        if skip_folding_node_fn and skip_folding_node_fn(node):
+            continue
+
+        # Skip folding side-effectful functions
+        if node.is_impure():
+            continue
+
+        # Skip folding nodes that have symbolic fill_value
+        if isinstance(node.kwargs.get("fill_value", None), sympy.Expr):
+            continue
+
+        # Skip folding submodules that have impure ops
+        if (
+            # pyrefly: ignore [invalid-argument]
+            node.op == "call_module"
+            # pyrefly: ignore [not-callable]
+            and (target_mod := mod_traced.get_submodule(node.target))
+            and isinstance(target_mod, torch.fx.GraphModule)
+            and _subgraph_has_impure_ops(target_mod)
+        ):
+            continue
+
+        # Must be a constant foldable node at this point.
+        const_nodes.add(node)
+        if node.op != "get_attr":
+            found_const_folding = True
+
+    # If we did not find any const folding then return early without a const fold subgraph.
+    if not found_const_folding:
+        return FoldedGraphModule(mod_traced, mod_traced.graph)
+
+    # Partition the module into two: submod_0 for constant folding subgraph, and
+    # submod_1 for the rest.
+    def mod_partition(node: torch.fx.Node):
+        return 0 if node in const_nodes else 1
+
+    split = split_module(mod_traced, module, mod_partition)
+
+    const_mod_name, non_const_mod_name = "submod_0", "submod_1"
+    # Safely get submod_1 in case there are no non-const nodes
+    const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)
+
+    # The module that a call_module node refers to gets copied to submodules during split.
+    # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
+    # attach inlined modules to `split` as it's the owning module now.
+    for node in non_const_gm.graph.nodes if non_const_gm else []:
+        if node.op == "call_module":
+            setattr(split, node.target, getattr(non_const_gm, node.target))
+    for node in const_gm.graph.nodes:
+        if node.op == "call_module":
+            setattr(split, node.target, getattr(const_gm, node.target))
+
+    # split_module currently does not use get_attrs for attrs. Instead it passes
+    # them in as args from the parent module, which used get_attrs. Here we set
+    # them as get_attrs inside const_gm, allowing for running folding without
+    # somehow a priori knowing the attrs that should be passed as args. We can
+    # unconditionally do this for all placeholders because we know all
+    # placeholders to const_gm must be constants accessible via get_attr.
+    call_const_gm_args = None
+    for node in split.graph.nodes:
+        if node.op == "call_module":
+            if node.target == const_mod_name:
+                call_const_gm_args = node.args
+                break
+    assert call_const_gm_args is not None
+
+    # Here we do the actual replacement of placeholders to get_attrs. Note that here we
+    # set the const_gm.graph into a new root_const_gm with split as the root module,
+    # because we are fetching attributes directly from the root module, instead of
+    # fetching them from const_gm. Example: The const_gm must have some format like:
+    # graph():
+    #    %inp : [num_users=1] = placeholder[target=const_inp]
+    #    %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
+    #    return add
+    # We replace that with the following, which does not have any placeholders:
+    # graph():
+    #    %inp_1 : [num_users=1] = get_attr[target=const_inp]
+    #    %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
+    #    return add
+    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
+
+    # The order of placeholders in the const_gm graph should match the order of
+    # args in the outer module, so we can simply use an index for the
+    # placeholder mapping
+    ph_idx = 0
+    for node in root_const_gm.graph.nodes:
+        if node.op == "output":
+            multiple_outputs = isinstance(node.args[0], tuple)
+            continue
+        if node.op != "placeholder":
+            continue
+        assert ph_idx < len(call_const_gm_args)
+        in_node = call_const_gm_args[ph_idx]
+        ph_idx += 1
+        assert in_node.op == "get_attr"
+        with root_const_gm.graph.inserting_before(node):
+            new_node = root_const_gm.graph.get_attr(in_node.target)
+        new_node.meta = node.meta.copy()
+        node.replace_all_uses_with(new_node)
+        root_const_gm.graph.erase_node(node)
+    assert "multiple_outputs" in locals()
+
+    # Now find the call to const_gm inside split, and replace it with a getattr to the
+    # folded tensor(s) that result from constant folding. Note that we don't need to
+    # worry about whether this is one or more tensors because the original graph
+    # correctly uses getitem to extract individual tensors if there are multiple folded.
+    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
+        mod_traced, "_FX_CONST_FOLDED_ATTRS"
+    )
+    setattr(
+        split,
+        fx_const_folded_attrs_name,
+        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
+    )
+    for node in split.graph.nodes:
+        if node.op == "call_module" and node.target == const_mod_name:
+            with node.graph.inserting_before(node):
+                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
+            folded_attrs.meta = node.meta.copy()
+            node.replace_all_uses_with(folded_attrs)
+            break
+
+    # Finally, inline the non-constant submod (if it exists) into the split submod.
+    # This is so that the original caller who may have passed in a graph module will
+    # get back out a graph module whose graph is traced to the same granularity.
+    if hasattr(split, non_const_mod_name):
+        _inline_module(split, non_const_mod_name)
+
+    split.graph.eliminate_dead_code()
+
+    return FoldedGraphModule(
+        split,
+        split.graph,
+        root_const_gm.graph,
+        fx_const_folded_attrs_name,
+        device_for_folded_attrs,
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/debug.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..b87dee9db9c73f0b4ea1a0a27682a167e125a71d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/debug.py
@@ -0,0 +1,33 @@
+from collections.abc import Sequence
+
+import torch.fx as fx
+
+
+__all__ = ["set_trace"]
+
+
+def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
+    """
+    Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
+    `gm` gets run.
+
+    Args:
+        gm: graph module to insert breakpoint. It is then recompiled for it to
+            take effect.
+
+    Returns:
+        the `gm` with breakpoint inserted.
+    """
+
+    def insert_pdb(body: Sequence[str]) -> list[str]:
+        return ["import pdb; pdb.set_trace()\n", *body]
+
+    with gm.graph.on_generate_code(
+        make_transformer=lambda cur_transform: (
+            # new code transformer to register
+            lambda body: (insert_pdb(cur_transform(body) if cur_transform else body))
+        )
+    ):
+        gm.recompile()
+
+    return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py
new file mode 100644
index 0000000000000000000000000000000000000000..58a62aee314607320bb5f7eb922192888fa172a5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py
@@ -0,0 +1,1011 @@
+# mypy: allow-untyped-defs
+import itertools
+import operator
+from collections.abc import Callable
+from functools import reduce
+from typing import TypeVar
+from typing_extensions import ParamSpec
+
+import sympy
+
+import torch
+from torch.fx.experimental.refinement_types import Equality
+from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
+from torch.fx.node import Node, Target
+from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.modules.conv import Conv2d
+
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+_INFERENCE_RULES: dict[Target, Callable] = {}
+_REFINEMENT_RULES: dict[Target, Callable] = {}
+_RULES: dict[Target, Callable] = {}
+
+__all__ = [
+    "GraphTypeChecker",
+    "Refine",
+    "adaptiveavgpool2d_check",
+    "adaptiveavgpool2d_inference_rule",
+    "add_inference_rule",
+    "all_eq",
+    "bn2d_inference_rule",
+    "broadcast_types",
+    "calculate_out_dimension",
+    "conv2d_inference_rule",
+    "conv_refinement_rule",
+    "conv_rule",
+    "element_wise_eq",
+    "expand_to_tensor_dim",
+    "first_two_eq",
+    "flatten_check",
+    "flatten_inference_rule",
+    "flatten_refinement_rule",
+    "get_attr_inference_rule",
+    "get_greatest_upper_bound",
+    "get_parameter",
+    "linear_check",
+    "linear_inference_rule",
+    "linear_refinement_rule",
+    "maxpool2d_check",
+    "maxpool2d_inference_rule",
+    "register_algebraic_expressions_inference_rule",
+    "register_inference_rule",
+    "register_refinement_rule",
+    "relu_inference_rule",
+    "reshape_inference_rule",
+    "transpose_inference_rule",
+]
+
+
+def expand_to_tensor_dim(t, n):
+    """
+    Expand a type to the desired tensor dimension if possible
+    Raise an error otherwise.
+    - t is the given type
+    - n is a number of dimensions to expand to
+    """
+    if t == Dyn:
+        dims = [Dyn] * n
+        return TensorType(tuple(dims))
+    elif isinstance(t, TensorType):
+        if len(t.__args__) != n:
+            raise TypeError(
+                f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}"
+            )
+        return t
+    else:
+        raise TypeError(f"Cannot match the type {t}")
+
+
+def broadcast_types(t1, t2):
+    """
+    Applies broadcasting to both given types such that they
+    become consistent with each other and returns two new
+    resulting types
+    """
+
+    # if either type is Dyn, do nothing since the types are already consistent
+    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
+        return t1, t2
+
+    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
+        s1 = len(t1.__args__)
+        s2 = len(t2.__args__)
+
+        new_t1 = list(t1.__args__)
+        new_t2 = list(t2.__args__)
+
+        # We make the types the same length which is the first requirement
+        # for consistency
+        if s1 > s2:
+            for _ in range(s1 - s2):
+                new_t2.insert(0, 1)
+
+        elif s2 > s1:
+            for _ in range(s2 - s1):
+                new_t1.insert(0, 1)
+
+        # we replace occurrences of "1" with each tensor with
+        # the corresponding type from the other tensor
+        for i, (x, y) in enumerate(zip(new_t1, new_t2)):
+            if x == 1:
+                new_t1[i] = y
+            elif y == 1:
+                new_t2[i] = x
+
+        # at this point our tensors should be consistent
+        # and we can apply the element-wise operation and find the right dimension
+        # for the output of the operation
+        (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
+        return (t1, t2)
+    else:
+        raise TypeError(f"Cannot broadcast types {t1} and {t2}")
+
+
+def register_inference_rule(
+    call_target: Target,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    def register(fn: Callable[_P, _T]) -> Callable[_P, _T]:
+        if call_target in _INFERENCE_RULES:
+            raise RuntimeError(f"Inference rule already registered for {call_target}!")
+        _INFERENCE_RULES[call_target] = fn
+        return fn
+
+    return register
+
+
+def register_refinement_rule(
+    call_target: Target,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    def register(fn: Callable[_P, _T]) -> Callable[_P, _T]:
+        if call_target in _REFINEMENT_RULES:
+            raise RuntimeError(f"Refinement rule already registered for {call_target}!")
+        _REFINEMENT_RULES[call_target] = fn
+        return fn
+
+    return register
+
+
+def register_algebraic_expressions_inference_rule(
+    call_target: Target,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    def register(fn: Callable[_P, _T]) -> Callable[_P, _T]:
+        if call_target in _RULES:
+            raise RuntimeError(f"Rule already registered for {call_target}!")
+        _RULES[call_target] = fn
+        return fn
+
+    return register
+
+
+@register_inference_rule(torch.add)
+@register_inference_rule(operator.add)
+def add_inference_rule(n: Node):
+    """
+    Apply the addition inference rule. This includes:
+    - scalar addition
+    - broadcasting semantics
+
+    Note that we always return the least precise type between
+    the operands (after applying broadcasting) to be the final type of the operation
+
+    Note that we do not modify the operand types themselves after applying broadcasting
+    to them. We only use them to calculate the final type
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+    t1 = n.args[0].type
+    t2 = n.args[1].type
+
+    # handle scalar addition
+    if t1 is int and isinstance(t2, TensorType):
+        n.type = t2
+        return n.type
+
+    # handle scalar addition
+    elif t2 is int and isinstance(t1, TensorType):
+        n.type = t1
+        return n.type
+
+    # we bring the new types to the point where
+    # we can check for consistency
+    # any inconsistency would not have been caused
+    # by broadcasting at this point
+    (new_t1, new_t2) = broadcast_types(t1, t2)
+
+    if new_t1 != t1 or new_t2 != t2:
+        n.meta["broadcast"] = True
+        n.meta[str(n.args[0])] = new_t1
+        n.meta[str(n.args[1])] = new_t2
+
+    else:
+        n.meta["broadcast"] = False
+
+    new_t1 = t1 if not n.meta["broadcast"] else new_t1
+    new_t2 = t2 if not n.meta["broadcast"] else new_t2
+
+    # we check for consistency between the new types
+    if is_consistent(new_t1, new_t2):
+        # we return the less precise type because
+        # broadcasting may have happened
+        # for operands with shape [1,2,Dyn] and [1,2,1]
+        # we have to assign the node [1,2,Dyn]
+        if is_more_precise(new_t1, new_t2):
+            n.type = new_t2
+        else:
+            n.type = new_t1
+        return n.type
+    else:
+        raise TypeError(
+            f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}."
+            f" Types should match "
+        )
+
+
+@register_inference_rule(getattr)
+def get_attr_inference_rule(n: Node, traced):
+    """
+    The current getattr rule only handles the shape attribute
+    Can be extended to other attributes
+    The most representitive type we have is "Dyn" but the system
+    can be extended with more types, such as a type to represent shapes
+    """
+    attr_name = n.args[1]
+
+    if attr_name == "shape":
+        n.type = Dyn
+    else:
+        raise TypeError("Not yet implemented")
+
+    # TODO. We leave it like this till we add a type to represent tensor sizes
+    return n.type
+
+
+@register_inference_rule(torch.transpose)
+def transpose_inference_rule(n: Node):
+    """
+    We check that dimensions for the transpose operations
+    are within range of the tensor type of the node
+    """
+    if n.target is torch.transpose:
+        assert isinstance(n.args[0], Node)
+        t = n.args[0].type
+
+        assert isinstance(n.args[1], int)
+        assert isinstance(n.args[2], int)
+        dim1, dim2 = n.args[1], n.args[2]
+
+        if t == Dyn:
+            n.type = Dyn
+            return n.type
+
+        elif isinstance(t, TensorType):
+            if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
+                new_type = list(t.__args__)
+                new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
+                final = TensorType(new_type)
+                n.type = get_greatest_upper_bound(n.type, final)
+                return n.type
+            else:
+                raise TypeError(
+                    f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}"
+                )
+        else:
+            raise TypeError(
+                f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}"
+            )
+
+
+@register_inference_rule(torch.reshape)
+def reshape_inference_rule(n: Node):
+    """
+    Without dynamism, the rule checks that the
+    product of the elements of the argument tensor
+    type is equal to the product of the elements
+    of the required shape. We gradualize this rule
+    by adding a case to handle fully dynamic input
+    as well as input where some of the tensor dimensions
+    are unknown. In this case we check for divisibility
+    """
+    assert isinstance(n.args[0], Node)
+    t1 = n.args[0].type
+
+    assert isinstance(n.args[1], list)
+    t2 = n.args[1]
+    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])
+
+    # if we do not know the original tensor dimension,
+    # we return the required dimension
+    if t1 == Dyn:
+        n.type = t2_type
+        return t2_type
+
+    # if any of the dimensions are unknown,
+    # we check for divisibility
+    elif isinstance(t1, TensorType):
+        assert isinstance(t1, TensorType)
+        a = [e if e != Dyn else 1 for e in t1.__args__]
+        p1 = reduce(operator.mul, a)
+        p2 = reduce(operator.mul, t2)
+        if p1 % p2 == 0 or p2 % p1 == 0:
+            n.type = t2_type
+            return t2_type
+        else:
+            raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}")
+    else:
+        raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}")
+
+
+@register_inference_rule(BatchNorm2d)
+def bn2d_inference_rule(n: Node, module_instance):
+    """
+    Given a BatchNorm2D instance and a node check the following conditions:
+    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, x_3, x_4)
+    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
+    - t is consistent with t'
+    - x_2 is consistent with the module's num_features
+    - x_2' is consistent with the module's num_features
+    output type: the more precise type of t and t'
+    """
+    assert isinstance(n.args[0], Node)
+    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
+    arg_type = n.args[0].type
+    n.type = expand_to_tensor_dim(n.type, 4)
+
+    # we check the conditions on the incoming argument
+    # and any existing annotation
+    # we also check for consistency between both annotations
+    if (
+        is_consistent(arg_type.__args__[1], module_instance.num_features)
+        and is_consistent(n.type.__args__[1], module_instance.num_features)
+        and is_consistent(arg_type, n.type)
+    ):
+        # we choose the more precise type
+        # to be the node type
+        # so if an incoming argument has more type information
+        # we set this node's type to be the argument type
+        n.type = get_greatest_upper_bound(arg_type, n.type)
+        return n.type
+    else:
+        raise TypeError(
+            f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}"
+        )
+
+
+def calculate_out_dimension(d_in, module_instance, index):
+    """
+    For calculating h_in and w_out according to the conv2D documentation
+    """
+    padding = (
+        (module_instance.padding, module_instance.padding)
+        if isinstance(module_instance.padding, int)
+        else module_instance.padding
+    )
+    kernel_size = (
+        (module_instance.kernel_size, module_instance.kernel_size)
+        if isinstance(module_instance.kernel_size, int)
+        else module_instance.kernel_size
+    )
+    stride = (
+        (module_instance.stride, module_instance.stride)
+        if isinstance(module_instance.stride, int)
+        else module_instance.stride
+    )
+    dilation = (
+        (module_instance.dilation, module_instance.dilation)
+        if isinstance(module_instance.dilation, int)
+        else module_instance.dilation
+    )
+
+    DIMENSION_TYPES = (int, sympy.Symbol)
+
+    if d_in == Dyn:
+        return Dyn
+
+    elif isinstance(d_in, DIMENSION_TYPES):
+        n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1
+
+        return (n // stride[0]) + 1
+
+    else:
+        raise TypeError(
+            f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}"
+        )
+
+
+def get_greatest_upper_bound(type1, type2):
+    """
+    Get the most precise type that's consistent with the given types
+    """
+    if type1 == Dyn:
+        return type2
+    elif type2 == Dyn:
+        return type1
+    elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
+        if not is_consistent(type1, type2):
+            raise TypeError(f"Inconsistent types {type1}, {type2}")
+        gub = [
+            t1 if is_more_precise(t1, t2) else t2
+            for (t1, t2) in zip(type1.__args__, type2.__args__)
+        ]
+        return TensorType(tuple(gub))
+
+
+@register_inference_rule(Conv2d)
+def conv2d_inference_rule(n: Node, module_instance):
+    """
+    Given a Conv2D instance and a node check the following conditions:
+    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, H, W)
+    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
+    - x_2 is consistent with the module's in_channels
+    - let o = (x_1, out_channels, H_out, W_out)
+    then the output is the greatest upper bound of o and the existing node type t'.
+    """
+    assert isinstance(n.args[0], Node)
+    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
+    arg_type = n.args[0].type
+    curr_node_type = expand_to_tensor_dim(n.type, 4)
+
+    if is_consistent(arg_type.__args__[1], module_instance.in_channels):
+        w_in = arg_type.__args__[3]
+        h_in = arg_type.__args__[2]
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+        new_type = TensorType(
+            (arg_type.__args__[0], module_instance.out_channels, h_out, w_out)
+        )
+        gub = get_greatest_upper_bound(new_type, curr_node_type)
+        n.type = gub
+        return n.type
+    else:
+        raise TypeError(
+            f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}"
+        )
+
+
+@register_inference_rule(torch.nn.ReLU)
+def relu_inference_rule(n: Node, module_instance):
+    """
+    Input and output shapes should be equal.
+    """
+    assert isinstance(n.args[0], Node)
+
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+
+    if isinstance(n.args[0].type, TensorType):
+        n.type = get_greatest_upper_bound(n.args[0].type, n.type)
+    return n.type
+
+
+def maxpool2d_check(typ, module_instance):
+    """
+    Applies the maxpool2d shape information to the input
+    this affects the last two dimensions
+    """
+    new_type_list = list(typ.__args__)
+    if len(new_type_list) == 4 or len(new_type_list) == 3:
+        w_in = new_type_list[-1]
+        h_in = new_type_list[-2]
+
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+
+        new_type_list[-1] = w_out
+        new_type_list[-2] = h_out
+        return TensorType(tuple(new_type_list))
+
+    else:
+        raise TypeError(f"Wrong size {typ} for {module_instance}")
+
+
+@register_inference_rule(torch.nn.MaxPool2d)
+def maxpool2d_inference_rule(n: Node, module_instance):
+    """
+    Given a MaxPool2D instance and a node check the following conditions:
+    - Input size matches size 3 or 4
+    - Current node type is consistent with the output type we will calculate
+    - Input size matches output size and the last two dimensions of the output
+      are w_out and h_out. The remaining dimensions are the same as the input
+    - Our final result is the greatest upper bound of the output we calculate
+      and the current node type.
+    """
+    assert isinstance(n.args[0], Node)
+
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+    if isinstance(n.args[0].type, TensorType):
+        output = maxpool2d_check(n.args[0].type, module_instance)
+        n.type = get_greatest_upper_bound(output, n.type)
+    return n.type
+
+
+def linear_check(tensor_type, module_instance):
+    """
+    Checks that an input tensor type satisfies the conditions for linear operation
+    and returns the output type based on in and out features given by module_instance
+    """
+    if len(tensor_type.__args__) >= 2:
+        if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
+            new_type_args = list(tensor_type.__args__)
+            new_type_args[-1] = module_instance.out_features
+            return TensorType(tuple(new_type_args))
+        else:
+            raise TypeError(
+                f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}"
+            )
+    else:
+        raise TypeError(f"Type {tensor_type} must have rank 2 or more.")
+
+
+@register_inference_rule(torch.nn.Linear)
+def linear_inference_rule(n: Node, module_instance):
+    """
+    Applies the shape information to the input then gets the greatest upper bound
+    of the resulting type and the existing type
+    """
+    assert isinstance(n.args[0], Node)
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+    if isinstance(n.args[0].type, TensorType):
+        output_type = linear_check(n.args[0].type, module_instance)
+        n.type = get_greatest_upper_bound(output_type, n.type)
+    return n.type
+
+
+def adaptiveavgpool2d_check(tensor_type, module_instance):
+    output_size = module_instance.output_size
+    if isinstance(output_size, int):
+        output_size = [output_size, output_size]
+    elif isinstance(output_size, tuple):
+        output_size = list(output_size)
+        if output_size[0] is None:
+            output_size[0] = output_size[1]
+        if output_size[1] is None:
+            output_size[1] = output_size[0]
+
+    new_type_list = list(tensor_type.__args__)
+
+    if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3:
+        new_type_list[-1] = output_size[1]
+        new_type_list[-2] = output_size[0]
+
+        return TensorType(tuple(new_type_list))
+
+    else:
+        raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}")
+
+
+@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
+def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
+    """
+    The input and output sizes should be the same except for the last
+    two dimensions taken from the input, which represent width and height
+    """
+    assert isinstance(n.args[0], Node)
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+    if isinstance(n.args[0].type, TensorType):
+        output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance)
+        n.type = get_greatest_upper_bound(n.type, output_type)
+    return n.type
+
+
+def flatten_check(tensor_type, start_dim, end_dim):
+    l = len(tensor_type.__args__)
+
+    start_dim = l if start_dim == -1 else abs(start_dim)
+    end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
+
+    if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
+        my_args = list(tensor_type.__args__)
+        lhs = my_args[0:start_dim]
+        rhs = my_args[end_dim:]
+        mid = my_args[start_dim:end_dim]
+        if Dyn in mid:
+            mid = [Dyn]
+        else:
+            mid = [reduce(operator.mul, my_args[start_dim:end_dim])]
+        new_type_list = lhs + mid + rhs
+        return TensorType(tuple(new_type_list))
+    else:
+        raise TypeError(
+            f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}"
+        )
+
+
+@register_inference_rule(torch.flatten)
+def flatten_inference_rule(n: Node):
+    """
+    Applies the flatten shape information to the input then gets the
+    greatest upper bound of the resulting type and the existing type
+    """
+    assert isinstance(n.args[0], Node)
+
+    # set the default start and end dims
+    start_dim = 1
+    end_dim = -1
+
+    if len(n.args) > 1:
+        assert isinstance(n.args[1], int)
+        start_dim = n.args[1]
+
+    if len(n.args) > 2:
+        assert isinstance(n.args[2], int)
+        end_dim = n.args[2]
+
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+
+    if isinstance(n.args[0].type, TensorType):
+        output_type = flatten_check(n.args[0].type, start_dim, end_dim)
+        n.type = get_greatest_upper_bound(output_type, n.type)
+
+    return n.type
+
+
+class GraphTypeChecker:
+    def __init__(self, env, traced):
+        self.env = env
+        self.traced = traced
+
+    def type_check(self):
+        """
+        A gradual type checker for graphs
+        Effect: every node's field type will be
+        populated with a type after type-checking is done
+        """
+        graph = self.traced.graph
+
+        # type check every node with gradual type rules
+        # if any node does not type check return false
+        for n in graph.nodes:
+            self.type_check_node(n)
+        return True
+
+    def type_check_node(self, n: Node):
+        """
+        Type check a given fx node.
+        Current operations:
+        - Reshape
+        - Transpose
+        - Add
+        - Relu
+        - conv2d
+        - batchnorm2d
+        - flatten
+        - maxpool2d
+        - adaptiveavgpool2d
+        - linear
+        """
+        if n.type is None:
+            n.type = Dyn
+
+        if n.op == "placeholder":
+            return n.type
+
+        elif n.op == "get_attr":
+            t = get_parameter(self.traced, n.target)  # type: ignore[arg-type]
+            if isinstance(t.data, torch.Tensor):
+                n.type = TensorType(t.data.shape)
+            return n.type
+
+        elif n.op == "call_function":
+            if n.target is getattr:
+                assert getattr in _INFERENCE_RULES
+                return _INFERENCE_RULES[n.target](n, self.traced)
+
+            elif n.target in _INFERENCE_RULES:
+                return _INFERENCE_RULES[n.target](n)
+            else:
+                raise RuntimeError(
+                    f"No inference rule registered for target {n.target}!"
+                )
+
+        elif n.op == "call_module":
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _INFERENCE_RULES:
+                return _INFERENCE_RULES[type(module_instance)](n, module_instance)
+            else:
+                raise RuntimeError(
+                    f"No inference rule registered for class {type(module_instance)}!"
+                )
+
+        elif n.op == "output":
+
+            def get_node_type(a):
+                return a.type
+
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
+
+        else:
+            raise NotImplementedError(f"Method {n.op} not yet implemented")
+
+
+@register_refinement_rule(Conv2d)
+def conv_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
+        return res
+
+
+@register_refinement_rule(torch.nn.Linear)
+def linear_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
+    return res
+
+
+@register_refinement_rule(BatchNorm2d)
+@register_refinement_rule(torch.nn.ReLU)
+def all_eq(n: Node):
+    """
+    For operations where the input shape is equal to the output shape
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        args1 = arg_type.__args__
+        args2 = n.type.__args__
+        res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
+    return res
+
+
+@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
+@register_refinement_rule(torch.nn.MaxPool2d)
+def first_two_eq(n: Node):
+    """
+    For operations where the first two dimensions of the input and output shape
+    are equal
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        args1 = arg_type.__args__
+        args2 = n.type.__args__
+        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
+    return res
+
+
+@register_refinement_rule(torch.add)
+@register_refinement_rule(operator.add)
+def element_wise_eq(n: Node):
+    """
+    For element-wise operations and handles broadcasting.
+    Note that after applying broadcasting to the arguments
+    we are able to determine if certain dimensions have not been broadcast
+    if they are symbolicallu equal.
+
+    in this case, we can establish equality between those dimensions and the
+    corresponding output dimensions.
+
+    Note that it takes two iterations for this result. One iteration to establish
+    equality between certain dimensions of the operands (requiring the whole solver
+    including unification) and another iteration to establish equality between the operands
+    and the resulting type, requiring another round of constraint generation and unificaiton.
+    """
+    res = []
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        arg_type1 = n.args[0].type
+        arg_type2 = n.args[1].type
+        if (
+            isinstance(arg_type1, TensorType)
+            and isinstance(arg_type2, TensorType)
+            and isinstance(n.type, TensorType)
+        ):
+            args1, args2 = broadcast_types(arg_type1, arg_type2)
+            # by this point, we know that args1 and args2 are the same size.
+            a1 = args1.__args__
+            a2 = args2.__args__
+            a3 = n.type.__args__
+
+            # we would be here in the second iteration where we establish equality
+            # between operand type dimensions and the resulting type dimensions
+            r = []
+            for x, y, z in zip(a1, a2, a3):
+                if x == y:
+                    r.append(Equality(x, z))
+            res = r
+    return res
+
+
+@register_refinement_rule(torch.flatten)
+def flatten_refinement_rule(n: Node):
+    """
+    Generates equality constraints between the dimensions of the input and output
+    that will not be involved in the flatten operation
+    """
+    assert isinstance(n.args[0], Node)
+
+    eq_const = []
+
+    start_dim = 1
+    end_dim = -1
+
+    if len(n.args) > 1:
+        assert isinstance(n.args[1], int)
+        start_dim = n.args[1]
+
+    if len(n.args) > 2:
+        assert isinstance(n.args[2], int)
+        end_dim = n.args[2]
+
+    if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType):
+        l = len(n.type.__args__)
+        arg_type = n.args[0].type
+        start_dim = l if start_dim == -1 else start_dim
+        end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
+
+        for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]):
+            eq_const.append(Equality(t1, t2))
+
+        for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]):
+            eq_const.append(Equality(t1, t2))
+    return eq_const
+
+
+@register_algebraic_expressions_inference_rule(Conv2d)
+def conv_rule(n: Node, module_instance):
+    """
+    Represents the output in terms of an algrbraic expression w.r.t
+    the input when possible
+    """
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        w_in = arg_type.__args__[3]
+        h_in = arg_type.__args__[2]
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+        new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
+        n.type = new_type
+        return new_type
+
+
+class Refine:
+    """
+    Symbolic shape inference.
+    Generates constraints over type variables.
+    Currently all constraints are equality constraints.
+    """
+
+    def __init__(self, traced):
+        self.constraints = []
+        self.traced = traced
+        self.symbol_iter = itertools.count(start=0, step=1)
+
+    def refine(self):
+        """
+        Generates constraints for
+        every node in the graph based on
+        the operation.
+        """
+        graph = self.traced.graph
+        for n in graph.nodes:
+            self.refine_node(n)
+        return True
+
+    def symbolic_relations(self):
+        """
+        Infers algebraic relations
+        """
+        graph = self.traced.graph
+        for n in graph.nodes:
+            self.infer_symbolic_relations(n)
+        return True
+
+    def replace_dyn_with_fresh_var(self, typ):
+        """
+        Replace all unknown types with fresh type variables.
+        """
+        if typ == Dyn:
+            new_symbol = Var(next(self.symbol_iter))
+            return new_symbol
+        elif isinstance(typ, TensorType):
+            new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
+            return TensorType(tuple(new_args))
+        elif isinstance(typ, list):
+            return [self.replace_dyn_with_fresh_var(t) for t in typ]
+        elif isinstance(typ, tuple):
+            return (self.replace_dyn_with_fresh_var(t) for t in typ)
+        else:
+            return typ
+
+    def convert_to_sympy_symbols(self, typ):
+        """
+        Replace all unknown types with fresh type variables.
+        """
+        if isinstance(typ, Var):
+            return sympy.symbols(str(typ))
+        elif isinstance(typ, TensorType):
+            new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
+            return TensorType(tuple(new_args))
+        elif isinstance(typ, list):
+            return [self.convert_to_sympy_symbols(t) for t in typ]
+        elif isinstance(typ, tuple):
+            return (self.convert_to_sympy_symbols(t) for t in typ)
+        else:
+            return typ
+
+    def refine_node(self, n: Node):
+        """
+        Returns a list of equality constraints for
+        call_module and call_function nodes.
+        Models the relation between input and output dimensions
+        using constraints in case they are both tensors.
+        All operations used in resnet50 are defined.
+        """
+        if n.type is None:
+            n.type = Dyn
+
+        n.type = self.replace_dyn_with_fresh_var(n.type)
+
+        if n.op == "call_function":
+            if n.target in _REFINEMENT_RULES:
+                self.constraints += _REFINEMENT_RULES[n.target](n)
+
+        if n.op == "call_module":
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _REFINEMENT_RULES:
+                self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
+
+        if n.op == "output":
+
+            def get_node_type(a):
+                return a.type
+
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
+
+    def infer_symbolic_relations(self, n: Node):
+        n.type = self.convert_to_sympy_symbols(n.type)
+        if n.op == "call_function":
+            if n.target in _RULES:
+                return _RULES[n.target](n)
+
+        if n.op == "call_module":
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _RULES:
+                return _RULES[type(module_instance)](n, module_instance)
+
+        if n.op == "output":
+
+            def get_node_type(a):
+                return a.type
+
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
+
+
+def get_parameter(traced, target: str):
+    """
+    Returns the parameter given by ``target`` if it exists,
+    otherwise throws an error.
+
+    See the docstring for ``get_submodule`` for a more detailed
+    explanation of this method's functionality as well as how to
+    correctly specify ``target``.
+
+    Args:
+        target: The fully-qualified string name of the Parameter
+            to look for. (See ``get_submodule`` for how to specify a
+            fully-qualified string.)
+
+    Returns:
+        torch.nn.Parameter: The Parameter referenced by ``target``
+
+    Raises:
+        AttributeError: If the target string references an invalid
+            path or resolves to something that is not an
+            ``nn.Parameter``
+    """
+    module_path, _, param_name = target.rpartition(".")
+
+    mod: torch.nn.Module = traced.get_submodule(module_path)
+
+    if not hasattr(mod, param_name):
+        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
+
+    param: torch.nn.Parameter = getattr(mod, param_name)
+
+    return param
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd239d78842dd8ba3cbfbf2d03e259a19427489b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py
@@ -0,0 +1,178 @@
+# mypy: allow-untyped-defs
+import itertools
+import operator
+
+import torch
+from torch.fx._symbolic_trace import symbolic_trace
+from torch.fx.node import Node
+from torch.fx.passes.tools_common import legalize_graph
+
+
+def split_result_tensors(
+    result: torch.Tensor, inputs: list[torch.Tensor]
+) -> tuple[torch.Tensor, ...]:
+    """
+    A free function for use in the merge_matmul graph transformation below that
+    splits the output from a merged matmul into the individual results for each
+    input tensor.
+
+    Arguments:
+        result: The merged matmul result tensor.
+        inputs: The list of inputs that were merged into one for the matmul.
+
+    Returns:
+        List of matmul results for each input tensor.
+    """
+    # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we
+    # need an int even when tracing
+    if isinstance(result, torch.fx.Proxy):
+        splits = [0] * len(inputs)
+    else:
+        splits = [x.shape[0] for x in inputs]
+
+    # pyrefly: ignore [bad-argument-type]
+    return torch.split(result, splits)
+
+
+def may_depend_on(a: Node, b: Node, search_depth: int = 6):
+    """
+    Determine if one node depends on another in a torch.fx.Graph.
+
+    Arguments:
+        a: The node that may have a dependency on b.
+        b: The node that a may have a dependency on.
+        search_depth: In the case of an indirect dependency, this function
+                        searches upto this many nodes away in search of a
+                        data dependency. If none is found, the function
+                        makes the conservative assumption that there is a
+                        dependency.
+
+    Returns:
+        True if a may depend on b, False if it definitely does not.
+    """
+    # Equivalence is defined as dependence.
+    if a == b:
+        return True
+
+    # If a has no inputs, it cannot depend on b.
+    if len(a.all_input_nodes) == 0:
+        return False
+
+    # If the search depth has been exhausted and no conclusion has been
+    # reached, assume that there is a data dependency.
+    if search_depth == 0:
+        return True
+
+    # Recursively check all inputs of a.
+    for inp in a.all_input_nodes:
+        if may_depend_on(inp, b, search_depth - 1):
+            return True
+
+    return False
+
+
+def are_nodes_independent(nodes: list[Node]):
+    """
+    Check if all of the given nodes are pairwise-data independent.
+
+    Arguments:
+        nodes: The nodes to check for data dependencies.
+
+    Returns:
+        True if any pair in nodes has a data dependency.
+    """
+    # For each pair in nodes:
+    for i, j in itertools.combinations(nodes, 2):
+        if may_depend_on(i, j) or may_depend_on(j, i):
+            return False
+
+    return True
+
+
+def merge_matmul(in_mod: torch.nn.Module):
+    """
+    A graph transformation that merges matrix multiplication operations that share the same right-hand
+    side operand into one large matrix multiplication.
+               ____      _________        _________
+      ----    |    |    |         |     M|  A * C  |
+    M| A  |  T| B  | * K|    C    | =    |---------|
+      ---- ,  |    |    |         |     T|  B * C  |
+       K       ----      ---------        ---------
+                K            R                R
+    """
+    gm = symbolic_trace(in_mod)
+
+    rhs_users: dict[Node, list[Node]] = {}
+    lhs_users: dict[Node, list[Node]] = {}
+
+    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
+    # the matmul of which they are the LHS/RHS.
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target is not torch.matmul:
+            continue
+
+        lhs, rhs = node.args
+
+        # TODO: Properly handle aliasing caused by get_attr. For now,
+        # use the attribute name as the operand if the node is a
+        # get_attr.
+        lhs = lhs.target if lhs.op == "get_attr" else lhs
+        rhs = rhs.target if rhs.op == "get_attr" else rhs
+
+        lhs_users.setdefault(lhs, []).append(node)
+        rhs_users.setdefault(rhs, []).append(node)
+
+    for rhs, mms in rhs_users.items():
+        # There must be at least matmuls for a merge to make sense.
+        if len(mms) < 2:
+            continue
+
+        # All matmuls must not depend on each other directly or indirectly
+        # in order for the merge to be possible.
+        if not are_nodes_independent(mms):
+            continue
+
+        lhs_vals = [mm.args[0] for mm in mms]
+
+        # Merge the matmul.
+        # Collect a list of LHS operands and the single RHS operand.
+        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
+        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
+
+        # Concatenate all the LHS operands.
+        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
+
+        # Multiply the concatenated LHS operands with the one RHS. This will produce
+        # the same results as all the individual matmuls involving rhs in the original graph,
+        # but they will all be concatenated together.
+        merge_mm = gm.graph.call_function(
+            torch.matmul,
+            (
+                merge_mm_cat,
+                rhs,
+            ),
+            {},
+        )
+
+        # Split the result of the merged matmul using the shapes of the LHS operands
+        # to ascertain how large each chunk should be.
+        merge_mm_split = gm.graph.call_function(
+            split_result_tensors, (merge_mm, lhs), {}
+        )
+        merge_mm_res = [
+            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
+            for out in range(len(lhs))
+        ]
+
+        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
+        for old, new in zip(mms, merge_mm_res):
+            old.replace_all_uses_with(new)
+            gm.graph.erase_node(old)
+
+        # All of the new nodes created above were inserted at the end, so we need to sort
+        # the nodes topologically to make sure all definitions precede uses.
+        legalize_graph(gm)
+
+    gm.recompile()
+    gm.graph.lint()
+    return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb3adfba8d412a12012cb3148732e0fab42a7b66
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py
@@ -0,0 +1,320 @@
+# mypy: allow-untyped-defs
+import builtins
+import functools
+import warnings
+from collections.abc import Callable
+from typing import Any, Optional, Union
+
+import torch
+import torch.fx
+
+
+def embedding_override(self, input):
+    return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
+
+
+def nn_layernorm_override(self, input):
+    return input
+
+
+def torch_relu_override(x):
+    return x
+
+
+def torch_nn_relu_override(self, x):
+    return x
+
+
+def functional_relu_override(x, inplace=False):
+    assert not inplace, "dont support inplace functional.relu for metatensor analysis"
+    return x
+
+
+def torch_where_override(condition, x, y):
+    # torch.where returns the broadcasted tensor of condition, x, and y,
+    # so hack it by using addition
+    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
+
+
+def torch_abs_override(input, *, out=None):
+    assert out is None, "Dont support in-place abs for MetaTensor analysis"
+    return input
+
+
+manual_meta_overrides: dict[Callable, Callable] = {
+    torch.nn.Embedding: embedding_override,
+    torch.nn.LayerNorm: nn_layernorm_override,
+    torch.relu: torch_relu_override,
+    torch.nn.functional.relu: functional_relu_override,
+    torch.nn.ReLU: torch_nn_relu_override,
+    torch.where: torch_where_override,
+    torch.abs: torch_abs_override,
+}
+
+
+def gen_constructor_wrapper(target):
+    @functools.wraps(target)
+    def wrapper(*args, **kwargs):
+        proxy = None
+
+        def check_has_proxy(v):
+            if isinstance(v, torch.fx.Proxy):
+                nonlocal proxy
+                proxy = v
+
+        torch.fx.node.map_aggregate(args, check_has_proxy)
+        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
+
+        if proxy is not None:
+            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+        else:
+            return target(*args, **kwargs)
+
+    return wrapper, target
+
+
+class MetaProxy(torch.fx.Proxy):
+    def install_tensor_meta(self, tensor_meta):
+        self._tensor_meta = tensor_meta
+
+    def size(self, dim=None):
+        if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
+            return self._tensor_meta.size(*[dim] if dim else [])
+        return self.tracer.create_proxy(
+            "call_method", "size", (self, dim) if dim else (self,), {}
+        )
+
+    def dim(self):
+        if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
+            return self._tensor_meta.dim()
+        return self.tracer.create_proxy("call_method", "dim", (self,), {})
+
+    @property
+    def shape(self):
+        if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
+            return self._tensor_meta.shape
+        return self.tracer.create_proxy(
+            "call_function", builtins.getattr, (self, "shape"), {}
+        )
+
+    @property
+    def dtype(self):
+        if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
+            return self._tensor_meta.dtype
+        return self.tracer.create_proxy(
+            "call_function", builtins.getattr, (self, "dtype"), {}
+        )
+
+    @property
+    def device(self):
+        # Hack so we can track when devices are used. During meta-tensor propagation,
+        # replace these values with a constant 'meta'
+        return MetaDeviceAttribute(self, "device")
+
+    def __getattr__(self, k):
+        if k == "_tensor_meta":
+            return self.__getattribute__(k)
+        # note: not added to the graph yet, if this is a method call
+        # we peephole optimize to the method invocation
+        return MetaAttribute(self, k)
+
+
+class MetaAttribute(MetaProxy):
+    def __init__(self, root, attr: str):
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._node = None
+
+    @property
+    def node(self):  # type: ignore[override]
+        # the node for attributes is added lazily, since most will just be method calls
+        # which do not rely on the getitem call
+        if self._node is None:
+            self._node = self.tracer.create_proxy(
+                "call_function", getattr, (self.root, self.attr), {}
+            ).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy(
+            "call_method", self.attr, (self.root,) + args, kwargs
+        )
+
+
+class MetaDeviceAttribute(MetaAttribute):
+    pass
+
+
+def proxys_to_metas(v):
+    if isinstance(v, MetaDeviceAttribute):
+        return "meta"
+    if isinstance(v, torch.fx.Proxy):
+        assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}"
+        assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta"
+        return v._tensor_meta
+    return v
+
+
+class MetaTracer(torch.fx.Tracer):
+    allow_insert_stateless_mods: bool = True
+
+    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
+
+    def create_proxy(
+        self,
+        kind,
+        target,
+        args,
+        kwargs,
+        name=None,
+        type_expr=None,
+        proxy_factory_fn=None,
+    ):
+        rv = super().create_proxy(
+            kind,
+            target,
+            args,
+            kwargs,
+            name,
+            type_expr,
+            # pyrefly: ignore [bad-argument-type]
+            proxy_factory_fn,
+        )
+
+        if kind == "placeholder" and target in self.meta_args:
+            rv.install_tensor_meta(self.meta_args[target])
+            return rv
+
+        if target in self.orig_fns:
+            # NOTE: tensor constructors in PyTorch define the `device` argument as
+            # *kwargs-only*. That is why this works. If you add methods to
+            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
+            # this will break and you will likely see issues where we cannot infer
+            # the size of the output.
+            if "device" in kwargs:
+                kwargs["device"] = "meta"
+
+        try:
+            args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
+            kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
+
+            if kind == "call_function":
+                meta_target = manual_meta_overrides.get(target, target)
+                # pyrefly: ignore [not-callable]
+                meta_out = meta_target(*args_metas, **kwargs_metas)
+            elif kind == "call_method":
+                meta_target = getattr(args_metas[0], target)  # type: ignore[index]
+                meta_out = meta_target(*args_metas[1:], **kwargs_metas)  # type: ignore[index]
+            elif kind == "call_module":
+                assert hasattr(self, "orig_forward")
+                self._disable_module_getattr = True
+                try:
+                    mod = self.root.get_submodule(target)
+                    mod_type = type(mod)
+                    if mod_type in manual_meta_overrides:
+                        meta_out = manual_meta_overrides[mod_type](
+                            mod, *args_metas, **kwargs_metas
+                        )  # type: ignore[misc, arg-type]
+                    else:
+                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
+                finally:
+                    self._disable_module_getattr = False
+            elif kind == "get_attr":
+                self._disable_module_getattr = True
+                try:
+                    attr_itr = self.root
+                    atoms = target.split(".")
+                    for atom in atoms:
+                        attr_itr = getattr(attr_itr, atom)
+                    assert isinstance(attr_itr, torch.Tensor)
+                    meta_out = attr_itr.to(device="meta")
+                finally:
+                    self._disable_module_getattr = False
+            else:
+                return rv
+
+            # TODO
+            assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet"
+            rv.install_tensor_meta(meta_out)
+        except Exception as e:
+            warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
+
+        return rv
+
+    def getattr(self, attr, attr_val, parameter_proxy_cache):
+        if getattr(self, "_disable_module_getattr", False):
+            return attr_val
+        else:
+            return super().getattr(attr, attr_val, parameter_proxy_cache)
+
+    def call_module(self, m, forward, args, kwargs):
+        self.orig_forward = forward
+        return super().call_module(m, forward, args, kwargs)
+
+    def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
+        """
+        Helper method which tries to insert a module that was not declared as submodule.
+        """
+        idx = 0
+        mod_name = mod.__class__.__name__.lower()
+        path = f"{mod_name}_{idx}"
+        while hasattr(self.root, path):
+            path = f"{mod_name}_{idx}"
+            idx += 1
+
+        self.root.add_module(path, mod)
+        return path
+
+    def path_of_module(self, mod: torch.nn.Module) -> str:
+        try:
+            return super().path_of_module(mod)
+        except NameError:
+            if (
+                self.allow_insert_stateless_mods
+                and len(list(mod.parameters())) == 0
+                and len(list(mod.buffers())) == 0
+            ):
+                path = self._insert_module_as_submodule(mod)
+                self.prev_module = path
+                return path
+            raise
+
+    def proxy(self, node):
+        return MetaProxy(node, self)
+
+    def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
+        assert isinstance(meta_args, dict)
+        self.meta_args = meta_args
+
+        self.patched_torch_methods = {
+            target: gen_constructor_wrapper(getattr(torch, target))
+            for target in self._TORCH_METHODS_TO_PATCH
+        }
+        self.orig_fns = set()
+
+        for name, (wrapper, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, wrapper)
+            self.orig_fns.add(orig)
+
+        try:
+            graph = super().trace(root, concrete_args)
+            graph._tracer_extras = {"meta_args": meta_args}
+            return graph
+        finally:
+            for name, (_, orig) in self.patched_torch_methods.items():
+                setattr(torch, name, orig)
+
+
+def symbolic_trace(
+    root: Union[torch.nn.Module, Callable[..., Any]],
+    meta_args: Optional[dict[str, torch.Tensor]] = None,
+    concrete_args: Optional[dict[str, Any]] = None,
+) -> torch.fx.GraphModule:
+    tracer = MetaTracer()
+    graph = tracer.trace(root, meta_args, concrete_args)  # type: ignore[arg-type]
+    name = (
+        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    )
+    gm = torch.fx.GraphModule(tracer.root, graph, name)
+    return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32781d26e671cb0351e8f8bc6f0b18ddeeeed032
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c5d1e7b2d0b1bd10c87d3b9ca57e632308e1417
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53394d736878ba4ce5e711a21dd9d94dc653fac4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f72c3fa91cff02f6c0bbe7fb5623f634dd268f35
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8429d2764b44da1d30ec9177085c445be5b46d7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..238fa4c281a65669417bcc2fb9c20f7adf7b8d4f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b547db5c92d42b37c40a4bfacae52c86dd440f8
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..277cff75a705c30311871e3b47f51cccaeec2c29
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e46b3a607044a47774db97ec14c0ed40bea3d23d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py
@@ -0,0 +1,637 @@
+# mypy: allow-untyped-defs
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_div,
+    op_eq,
+    op_gt,
+    op_lt,
+    op_mod,
+    op_mul,
+    op_neq,
+    op_sub,
+)
+from torch.fx.tensor_type import Dyn, TensorType
+
+
+class Constraint:
+    pass
+
+
+class Conj(Constraint):
+    def __init__(self, conjuncts):
+        """
+        :param conjuncts: Conjunction of constraints
+        """
+        self.conjucts = conjuncts
+
+    def __eq__(self, other):
+        if isinstance(other, Conj):
+            return self.conjucts == other.conjucts and self.conjucts == other.conjucts
+        else:
+            return False
+
+    def __repr__(self):
+        return f"And({self.conjucts})"
+
+
+class Disj(Constraint):
+    def __init__(self, disjuncts):
+        """
+        :param disjuncts: Disjunction of constraints
+        """
+        self.disjuncts = disjuncts
+
+    def __eq__(self, other):
+        if isinstance(other, Disj):
+            return (
+                self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts
+            )
+        else:
+            return False
+
+    def __repr__(self):
+        return f"Or({self.disjuncts})"
+
+
+class Prod(Constraint):
+    def __init__(self, products):
+        """
+        :param products: lists of dimensions to multiply
+        """
+        self.products = products
+
+    def __eq__(self, other):
+        if isinstance(other, Prod):
+            return self.products == other.products and self.products == other.products
+        else:
+            return False
+
+    def __repr__(self):
+        return f"Product({self.products})"
+
+
+class T(Constraint):
+    """
+    True
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __eq__(self, other):
+        return isinstance(other, T)
+
+    def __repr__(self):
+        return "True"
+
+
+class F(Constraint):
+    """
+    False
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __eq__(self, other):
+        return isinstance(other, F)
+
+    def __repr__(self):
+        return "False"
+
+
+class BinaryConstraint(Constraint):
+    """
+    Represents all binary operations
+    """
+
+    def __init__(self, lhs, rhs, op):
+        """
+        :param lhs: lhs of the constraint
+        :param rhs: rhs of the constraint
+        :param op: string representing the operation
+        """
+        self.lhs = lhs
+        self.rhs = rhs
+        self.op = op
+
+    def __eq__(self, other):
+        if isinstance(other, BinaryConstraint):
+            return (
+                self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
+            )
+        else:
+            return False
+
+    def __repr__(self):
+        return f"({self.lhs} {self.op} {self.rhs})"
+
+
+class BinConstraintT(BinaryConstraint):
+    """
+    Binary constraints about tensors
+    """
+
+    def __init__(self, lhs, rhs, op):
+        assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and (
+            isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn
+        )
+        super().__init__(lhs, rhs, op)
+
+
+class BinConstraintD(BinaryConstraint):
+    """
+    Binary constraints about dimensions
+    """
+
+    def __init__(self, lhs, rhs, op):
+        assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
+        assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
+
+        super().__init__(lhs, rhs, op)
+
+
+class TGreatestUpperBound(Constraint):
+    """
+    Greatest Upper bound for tensors with dynamic type
+    """
+
+    def __init__(self, res, rhs1, rhs2):
+        """
+        :param res: tensor variable that stores the result of the output
+        :param rhs1: tensor or tensor variable
+        :param rhs2: tensor or tensor variabke
+        """
+        self.res = res
+        self.rhs1 = rhs1
+        self.rhs2 = rhs2
+
+    def __repr__(self):
+        return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}"
+
+    def __eq__(self, other):
+        if isinstance(other, TGreatestUpperBound):
+            return (
+                self.res == other.res
+                and self.rhs1 == other.rhs1
+                and self.rhs2 == other.rhs2
+            )
+        else:
+            return False
+
+
+class DGreatestUpperBound(Constraint):
+    """
+    Greatest Upper bound for dimensions
+    """
+
+    def __init__(self, res, rhs1, rhs2):
+        """
+        :param res: Dimension variable to store the result
+        :param rhs1: dimension variable 1
+        :param rhs2: dimension variable 2
+        """
+        assert is_dim(res)
+        assert is_dim(rhs1)
+        assert is_dim(rhs2)
+
+        self.res = res
+        self.rhs1 = rhs1
+        self.rhs2 = rhs2
+
+    def __repr__(self):
+        return f"{self.res} = {self.rhs1}\u2294{self.rhs2}"
+
+    def __eq__(self, other):
+        if isinstance(other, DGreatestUpperBound):
+            return (
+                self.res == other.res
+                and self.rhs1 == other.rhs1
+                and self.rhs2 == other.rhs2
+            )
+        else:
+            return False
+
+
+class CanReshape(Constraint):
+    """
+    can_reshape constraint
+    """
+
+    def __init__(self, src, target):
+        """
+        :param src: tensor variable
+        :param target: tensor
+        """
+        self.src = src
+        self.target = target
+
+    def __repr__(self):
+        return f"can-reshape({self.src}, {self.target})"
+
+    def __eq__(self, other):
+        if isinstance(other, CanReshape):
+            return self.src == other.src and self.target == other.target
+        else:
+            return False
+
+
+class IndexSelect(Constraint):
+    def __init__(self, tensor_size, input_var, dim_replace, index, output):
+        """
+        Args:
+            input_var: input to index_select
+            tensor_size: tensor size we are considering
+            dim_replace: the dimension of the output at "index"
+            index: location of the dimensions to replace in the input
+            output: variable to store the result
+        """
+        assert isinstance(input_var, TVar)
+        assert isinstance(output, TVar)
+        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
+        assert isinstance(index, int)
+
+        self.input_var = input_var
+        self.tensor_size = tensor_size
+        self.dim_replace = dim_replace
+        self.index = index
+        self.output = output
+
+    def __repr__(self):
+        return (
+            f" {self.output} = "
+            f"IndexSelect({self.input_var}, "
+            f"tensor_size: {self.tensor_size}, "
+            f"{self.dim_replace}, "
+            f"{self.index})"
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, IndexSelect):
+            return (
+                self.tensor_size == other.tensor_size
+                and self.dim_replace == other.dim_replace
+                and self.index == other.index
+                and self.output == other.output
+                and self.input_var == other.input_var
+            )
+        else:
+            return False
+
+
+class Transpose(Constraint):
+    def __init__(self, tensor_size, input_var, index1, index2, output):
+        """
+        Args:
+            tensor_size: current tensor size
+            input_var: variable to hold input
+            index1: dimension 1
+            index2: dimension 2
+            output: output that stores result
+        """
+        assert isinstance(input_var, TVar)
+        assert isinstance(output, TVar)
+        assert isinstance(index1, int)
+        assert isinstance(index2, int)
+
+        self.input_var = input_var
+        self.tensor_size = tensor_size
+        self.index1 = index1
+        self.index2 = index2
+        self.output = output
+
+    def __repr__(self):
+        return (
+            f" {self.output} = "
+            f"Transpose({self.input_var}, "
+            f"tensor_size: {self.tensor_size}, "
+            f"{self.index1}, "
+            f"{self.index2})"
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, Transpose):
+            return (
+                self.tensor_size == other.tensor_size
+                and self.index1 == other.index1
+                and self.index2 == other.index2
+                and self.output == other.output
+                and self.input_var == other.input_var
+            )
+        else:
+            return False
+
+
+class GetItem(Constraint):
+    def __init__(self, tensor_size, index, res, input_var):
+        """
+        Constraint for getting item given a tensor size
+        :param tensor_size: actual number
+        :param index: actual number representing the index
+        :param res: dimension variable to carry the item we get
+        :param input_var: a tensor variable from which we will get item
+        """
+        assert isinstance(res, DVar)
+
+        self.res = res
+        self.tensor_size = tensor_size
+        self.index = index
+        self.input_var = input_var
+
+    def __repr__(self):
+        return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})"
+
+    def __eq__(self, other):
+        if isinstance(other, GetItem):
+            return (
+                self.res == other.res
+                and self.tensor_size == other.tensor_size
+                and self.index == other.index
+                and self.input_var == other.input_var
+            )
+        else:
+            return False
+
+
+class GetItemTensor(Constraint):
+    def __init__(self, tensor_size, index_tuple, res, input_var):
+        """
+        Constraint for getting item given a tensor size
+        However, when the argument is a tuple, we will
+        expect a tensor
+        :param tensor_size: actual number representing the rank
+        :param index_tuple: tuple for indexing
+        :param res: tensor variable to carry the item we get
+        :param input_var: a tensor variable from which we will get item
+        """
+        assert isinstance(res, TVar)
+
+        self.res = res
+        self.tensor_size = tensor_size
+        self.index_tuple = index_tuple
+        self.input_var = input_var
+
+    def __repr__(self):
+        return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})"
+
+    def __eq__(self, other):
+        if isinstance(other, GetItemTensor):
+            return (
+                self.res == other.res
+                and self.tensor_size == other.tensor_size
+                and self.index_tuple == other.index_tuple
+                and self.input_var == other.input_var
+            )
+        else:
+            return False
+
+
+class CalcConv(Constraint):
+    def __init__(
+        self,
+        conv_result,
+        input_var,
+        c_out,
+        kernel,
+        padding,
+        stride,
+        dilation,
+        matching_constraint_vars,
+    ):
+        """
+        :param conv_result: the convolution result
+        :param input_var: input to convolution
+        :param c_out: output channel type
+        :param kernel: kernel tuple
+        """
+        self.conv_result = conv_result
+        self.input_var = input_var
+        self.c_out = c_out
+        self.kernel = kernel
+        self.padding = padding
+        self.stride = stride
+        self.dilation = dilation
+        self.matching_constraint = matching_constraint_vars
+
+    def __repr__(self):
+        return (
+            f"{self.conv_result} ="
+            f" calc-conv({self.input_var},"
+            f" {self.c_out}, {self.kernel}, "
+            f"{self.padding}, {self.stride},"
+            f" {self.dilation})"
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, CalcConv):
+            return (
+                self.conv_result == other.conv_result
+                and self.input_var == other.input_var
+                and self.c_out == other.c_out
+                and self.kernel == other.kernel
+                and self.padding == other.padding
+                and self.stride == other.stride
+                and self.dilation == other.dilation
+                and self.matching_constraint == other.matching_constraint
+            )
+        else:
+            return False
+
+
+class CalcMaxPool(Constraint):
+    def __init__(
+        self,
+        maxpool_result,
+        input_var,
+        kernel,
+        padding,
+        stride,
+        dilation,
+        matching_constraint_vars,
+    ):
+        """
+        :param maxpool_result: the result of maxpool
+        :param input_var: input to convolution
+        :param kernel: kernel tuple
+        """
+        self.maxpool_result = maxpool_result
+        self.input_var = input_var
+        self.kernel = kernel
+        self.padding = padding
+        self.stride = stride
+        self.dilation = dilation
+        self.matching_constraint = matching_constraint_vars
+
+    def __repr__(self):
+        return (
+            f"{self.maxpool_result} ="
+            f" calc-maxpool({self.input_var},"
+            f"  {self.kernel}, "
+            f"{self.padding}, {self.stride},"
+            f" {self.dilation})"
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, CalcMaxPool):
+            return (
+                self.maxpool_result == other.maxpool_result
+                and self.input_var == other.input_var
+                and self.kernel == other.kernel
+                and self.padding == other.padding
+                and self.stride == other.stride
+                and self.dilation == other.dilation
+                and self.matching_constraint == other.matching_constraint
+            )
+        else:
+            return False
+
+
+class ApplyBroadcasting(Constraint):
+    def __init__(self, res1, res2, input1, input2):
+        """
+        :param res1: resulting tensor 1
+        :param res2: resulting tensor 2
+        :param input1: tensor variable 1
+        :param input2: tensor variable 2
+        """
+        self.res1 = res1
+        self.res2 = res2
+        self.input1 = input1
+        self.input2 = input2
+
+    def __eq__(self, other):
+        if isinstance(other, ApplyBroadcasting):
+            return (
+                self.res1 == other.res1
+                and self.res2 == other.res2
+                and self.input1 == other.input1
+                and self.input2 == other.input2
+            )
+        else:
+            return False
+
+    def __repr__(self):
+        return (
+            f"{self.res1}, {self.res2} ="
+            f" apply-broadcasting({self.input1},"
+            f" {self.input2})"
+        )
+
+
+class CalcProduct(Constraint):
+    """
+    Given correct dimensions, calculate the product for flatten accounting for Dyn
+    """
+
+    def __init__(self, start, end, flattened, dims_to_flatten):
+        """
+        :param start: start index
+        :param end: end index
+        :param flattened: variable to store the product
+        :param dims_to_flatten: the type which we will flatten
+        """
+        assert isinstance(dims_to_flatten, list)
+        assert isinstance(flattened, TVar)
+        assert isinstance(start, int)
+        assert isinstance(end, int)
+
+        self.start = start
+        self.end = end
+        self.dims_to_flatten = dims_to_flatten
+        self.flattened = flattened
+
+    def __eq__(self, other):
+        if isinstance(other, CalcProduct):
+            return (
+                self.start == other.start
+                and self.end == other.end
+                and self.dims_to_flatten == other.dims_to_flatten
+                and self.flattened == other.flattened
+            )
+
+        else:
+            return False
+
+    def __repr__(self):
+        return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})"
+
+
+class TVar:
+    """
+    Tensor variable with no tensor constructor
+    """
+
+    def __init__(self, tvar):
+        """
+        :param tvar: tensor variable
+        """
+        self.tvar = tvar
+
+    def __repr__(self):
+        return f"TV({self.tvar})"
+
+    def __eq__(self, other):
+        if isinstance(other, TVar):
+            return self.tvar == other.tvar
+        else:
+            return False
+
+
+class DVar:
+    """
+    Dimension variable
+    """
+
+    def __init__(self, c):
+        """
+        :param c: character or number
+        """
+        self.c = c
+
+    def __repr__(self):
+        return f"DV({self.c})"
+
+    def __eq__(self, other):
+        if isinstance(other, DVar):
+            return self.c == other.c
+        else:
+            return False
+
+
+class BVar:
+    """
+    Boolean variable
+    """
+
+    def __init__(self, c):
+        """
+        :param c: character or number
+        """
+        self.c = c
+
+    def __repr__(self):
+        return f"BV({self.c})"
+
+    def __eq__(self, other):
+        if isinstance(other, BVar):
+            return self.c == other.c
+        else:
+            return False
+
+
+def is_algebraic_expression(constraint):
+    if isinstance(constraint, BinConstraintD):
+        return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
+    else:
+        return isinstance(constraint, Prod)
+
+
+def is_bool_expr(constraint):
+    if isinstance(constraint, BinConstraintD):
+        return constraint.op in [op_gt, op_lt, op_neq, op_eq]
+    else:
+        return isinstance(constraint, (BVar, Conj, Disj))
+
+
+def is_dim(d):
+    return isinstance(d, (DVar, int)) or d == Dyn
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e5c7c215e64f0ee61a840f37482c56f988c445
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -0,0 +1,1565 @@
+# mypy: allow-untyped-defs
+import operator
+import warnings
+from collections.abc import Callable, Iterable
+from typing import TypeVar
+from typing_extensions import ParamSpec
+
+import torch
+from torch.fx._symbolic_trace import _assert_is_none
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    ApplyBroadcasting,
+    BinConstraintD,
+    BinConstraintT,
+    CalcConv,
+    CalcMaxPool,
+    CalcProduct,
+    CanReshape,
+    Conj,
+    DGreatestUpperBound,
+    Disj,
+    DVar,
+    F,
+    GetItem,
+    GetItemTensor,
+    IndexSelect,
+    T,
+    TGreatestUpperBound,
+    Transpose,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_consistency,
+    op_div,
+    op_eq,
+    op_gt,
+    op_leq,
+    op_lt,
+    op_matching,
+    op_mul,
+    op_neq,
+    op_precision,
+    op_sub,
+)
+from torch.fx.experimental.migrate_gradual_types.util import (
+    gen_bvar,
+    gen_dvar,
+    gen_nat_constraints,
+    gen_tensor_dims,
+    gen_tvar,
+)
+from torch.fx.node import Node, Target
+from torch.fx.tensor_type import Dyn, TensorType
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.modules.conv import Conv2d
+
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+_INFERENCE_RULES: dict[Target, Callable] = {}
+
+MAX_TENSOR_RANK = 4
+
+__all__ = [
+    "ConstraintGenerator",
+    "adaptive_inference_rule",
+    "add_layer_norm_constraints",
+    "add_linear_constraints",
+    "arange_inference_rule",
+    "assert_inference_rule",
+    "batchnorm_inference_rule",
+    "bmm_inference_rule",
+    "broadcasting_inference_rule",
+    "conv2d_inference_rule",
+    "cumsum_inference_rule",
+    "embedding_inference_rule",
+    "embedding_inference_rule_functional",
+    "eq_inference_rule",
+    "equality_inference_rule",
+    "expand_inference_rule",
+    "flatten_inference_rule",
+    "full_inference_rule",
+    "gen_broadcasting_constraints",
+    "gen_embedding_rules",
+    "gen_layer_norm_constraints",
+    "generate_flatten_constraints",
+    "get_attr_inference_rule",
+    "getitem_inference_rule",
+    "gt_inference_rule",
+    "index_select_inference_rule",
+    "layer_norm_functional",
+    "layer_norm_inference_rule",
+    "linear_constraints",
+    "linear_inference_rule",
+    "lt_inference_rule",
+    "masked_fill_inference_rule",
+    "maxpool_inference_rule",
+    "neq_inference_rule",
+    "range_check",
+    "register_inference_rule",
+    "relu_inference_rule",
+    "reshape_inference_rule",
+    "size_inference_rule",
+    "tensor_inference_rule",
+    "torch_dim_inference_rule",
+    "torch_linear_inference_rule",
+    "transpose_inference_rule",
+    "type_inference_rule",
+    "view_inference_rule",
+]
+
+
+def register_inference_rule(
+    call_target: Target,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    def register(fn: Callable[_P, _T]) -> Callable[_P, _T]:
+        if call_target in _INFERENCE_RULES:
+            raise RuntimeError(f"Inference rule already registered for {call_target}!")
+        _INFERENCE_RULES[call_target] = fn
+        return fn
+
+    return register
+
+
+def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
+    d, counter = gen_tensor_dims(n, counter)
+    c1 = BinConstraintT(input, TensorType(d), op_eq)
+    start_dim = n if start_dim == -1 else abs(start_dim)
+    end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1
+    c2 = CalcProduct(start_dim, end_dim, flattened, d)
+    nat_constraints = gen_nat_constraints(d)
+    return Conj([c1, c2, *nat_constraints]), counter
+
+
+@register_inference_rule(getattr)
+def get_attr_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    If the attribute is "device" then the tensor shape is preserved
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], str)
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    input = symbols[n.args[0]]
+    attr = n.args[1]
+
+    if attr == "device":
+        return [BinConstraintT(input, output, op_eq)], counter
+    else:
+        raise NotImplementedError("Not yet implemented")
+
+
+@register_inference_rule(torch.bmm)
+def bmm_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Constraints that match the input to a size 3 tensor
+    and switch the dimensions according to the rules
+    of batch multiplication
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    bmm_output, counter = gen_tvar(counter)
+    symbols[n] = bmm_output
+
+    bmm_input1 = symbols[n.args[0]]
+    bmm_input2 = symbols[n.args[1]]
+
+    dims_input1, counter = gen_tensor_dims(3, counter)
+    dims_input2, counter = gen_tensor_dims(3, counter)
+
+    inputs_dyn = Conj(
+        [
+            BinConstraintT(bmm_input1, Dyn, op_eq),
+            BinConstraintT(bmm_input2, Dyn, op_eq),
+            BinConstraintT(bmm_output, Dyn, op_eq),
+        ]
+    )
+
+    input1_dyn = Conj(
+        [
+            BinConstraintT(bmm_input1, Dyn, op_eq),
+            BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
+            BinConstraintT(
+                bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq
+            ),
+        ]
+    )
+
+    input2_dyn = Conj(
+        [
+            BinConstraintT(bmm_input2, Dyn, op_eq),
+            BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
+            BinConstraintT(
+                bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq
+            ),
+        ]
+    )
+
+    consistency_constraints = [
+        BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)
+    ]
+
+    batch_size, counter = gen_dvar(counter)
+
+    inputs_are_tensors = Conj(
+        [
+            BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
+            BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
+            BinConstraintT(
+                bmm_output,
+                TensorType([batch_size, dims_input1[1], dims_input2[2]]),
+                op_eq,
+            ),
+            *consistency_constraints,
+            DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]),
+        ]
+    )
+
+    return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
+
+
+@register_inference_rule("index_select")
+def index_select_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We constrain the second argument to a vector or Dyn.
+    The output replaces the input with the shape of the vector
+    at the position given by the index (first argument)
+    """
+    # print(n.args)
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], int)
+    assert isinstance(n.args[2], Node)
+
+    index_select, counter = gen_tvar(counter)
+    symbols[n] = index_select
+
+    dims, counter = gen_tensor_dims(1, counter)
+
+    # equality constraint
+    is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
+    is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
+
+    c2 = Conj(
+        [
+            is_size_1,
+            Disj(
+                [
+                    IndexSelect(
+                        i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select
+                    )
+                    for i in range(MAX_TENSOR_RANK)
+                ]
+            ),
+        ]
+    )
+    c3 = Conj(
+        [
+            is_dyn,
+            Disj(
+                [
+                    IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
+                    for i in range(MAX_TENSOR_RANK)
+                ]
+            ),
+        ]
+    )
+
+    return [Disj([c2, c3])], counter
+
+
+@register_inference_rule("expand")
+def expand_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the exact constraints as we do for tensor additions but we constraint
+    the rank of this expression to be equal to len(n.args[1:]) so that only
+    those cases get considered for the output
+    """
+    assert isinstance(n.args[0], Node)
+
+    # define the output for expand
+    expand, counter = gen_tvar(counter)
+    symbols[n] = expand
+
+    # since we do not have two nodes here, we will construct an argument variable
+    e1 = symbols[n.args[0]]
+    e2, counter = gen_tvar(counter)
+
+    e2_nat_constraints = []
+    for arg in n.args[1:]:
+        assert isinstance(arg, (Node, int))
+        if isinstance(arg, Node):
+            assert isinstance(symbols[arg], DVar)
+            e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))
+
+    e2_constraint = BinConstraintT(
+        e2,
+        TensorType(
+            [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]
+        ),
+        op_eq,
+    )
+
+    constraints, counter = gen_broadcasting_constraints(
+        e1, e2, symbols, counter, expand
+    )
+
+    # constraint the output size
+    dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
+    nat_constraints = gen_nat_constraints(dims)
+    c = [
+        BinConstraintT(expand, TensorType(dims), op_eq),
+        *nat_constraints,
+        e2_constraint,
+        *e2_nat_constraints,
+    ]
+    constraints += c
+
+    return constraints, counter
+
+
+@register_inference_rule(torch.nn.functional.gelu)
+@register_inference_rule(torch.nn.functional.dropout)
+@register_inference_rule(torch.nn.functional.softmax)
+@register_inference_rule("detach")
+@register_inference_rule("to")
+@register_inference_rule("int")
+@register_inference_rule("long")
+@register_inference_rule("contiguous")
+@register_inference_rule(torch.ones)
+@register_inference_rule(torch.zeros)
+def equality_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    if isinstance(n.args[0], Node):
+        input = symbols[n.args[0]]
+        if isinstance(input, TVar):
+            return [BinConstraintT(input, output, op_eq)], counter
+
+        # then we have dimension variables
+        else:
+            for arg in n.args:
+                assert isinstance(symbols[arg], DVar)
+        my_size = [symbols[arg] for arg in n.args]
+        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
+
+    elif isinstance(n.args[0], tuple):
+        # then the tuple is the size
+        assert len(n.args[0]) <= 4
+        my_size = [symbols[arg] for arg in n.args[0]]
+        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
+    else:
+        raise NotImplementedError("Method not yet implemented")
+
+
+@register_inference_rule("transpose")
+def transpose_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Can be considered as a sequence of two index selects, so we generate constraints accordingly
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], int)
+    assert isinstance(n.args[2], int)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    from_arg = symbols[n.args[0]]
+    assert isinstance(from_arg, TVar)
+
+    # input and output are dyn
+    is_dyn = Conj(
+        [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]
+    )
+
+    # or input is a tensor and we actually do the replacement
+    c3 = Disj(
+        [
+            Transpose(i + 1, from_arg, n.args[1], n.args[2], output)
+            for i in range(MAX_TENSOR_RANK)
+        ]
+    )
+
+    return [Disj([is_dyn, c3])], counter
+
+
+@register_inference_rule("type_as")
+def type_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    from_arg = symbols[n.args[0]]
+    to_arg = symbols[n.args[1]]
+
+    assert isinstance(from_arg, TVar)
+    assert isinstance(to_arg, TVar)
+
+    return [
+        BinConstraintT(from_arg, to_arg, op_consistency),
+        BinConstraintT(output, to_arg, op_eq),
+    ], counter
+
+
+@register_inference_rule("masked_fill_")
+def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Similar to addition. For now we implement the constraints when
+    the argument is a boolean tensor. There is also a case for when
+    it is a condition. We will leave this out for now.
+    """
+
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    # We will retrieve the type variables from the symbol table
+    # and confirm they are tensor variables
+
+    e1 = symbols[n.args[0]]
+    e2 = symbols[n.args[1]]
+
+    if isinstance(e1, TVar) and isinstance(e2, TVar):
+        masked_fill_tensor, counter = gen_tvar(counter)
+        symbols[n] = masked_fill_tensor
+        return gen_broadcasting_constraints(
+            e1, e2, symbols, counter, masked_fill_tensor
+        )
+    else:
+        raise NotImplementedError("Not yet implemented")
+
+
+@register_inference_rule(torch.nn.functional.embedding)
+def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    embedding_dim_weights = symbols[n.args[1]]
+
+    # will treat this as a static shape. So we will not use matching.
+    weight_dims, counter = gen_tensor_dims(2, counter)
+    equality_constraint = BinConstraintT(
+        embedding_dim_weights, TensorType(weight_dims), op_eq
+    )
+    embedding_dim = weight_dims[1]
+    constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
+    return [equality_constraint] + constraints, counter
+
+
+@register_inference_rule(torch.nn.modules.sparse.Embedding)
+def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    The output shape differs from the input shape in the last dimension
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
+
+
+def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
+    embedding_output, counter = gen_tvar(counter)
+    symbols[n] = embedding_output
+    embedding_input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
+    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)
+
+    c1 = Conj([input_dyn, output_dyn])
+    c2 = []
+
+    for i in range(1, MAX_TENSOR_RANK):
+        new_dims, counter = gen_tensor_dims(i, counter)
+        nat_constraints = gen_nat_constraints(new_dims)
+
+        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
+                BinConstraintT(
+                    embedding_output, TensorType(new_dims + [embedding_dim]), op_eq
+                ),
+            ]
+            + nat_constraints
+        )
+        c2.append(c_tensor_i)
+
+    return [Disj([c1, Disj(c2)])], counter
+
+
+@register_inference_rule(torch.tensor)
+def tensor_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    If the tensor is a scalar, we will skip it since we
+    do not support scalars yet. We will add support in the future
+    if it's needed. For our examples so far, scalars are not needed.
+    """
+    return [], counter
+
+
+@register_inference_rule("reshape")
+@register_inference_rule("view")
+def view_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Similar to reshape but with an extra condition on the strides
+    """
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    my_view, counter = gen_tvar(counter)
+    symbols[n] = my_view
+
+    src_var = symbols[n.args[0]]
+    t2 = [
+        symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]
+    ]  # target shape
+    t2_type = []
+    num_constraints = []
+
+    for t in t2:
+        if t == -1:
+            var, counter = gen_dvar(counter)
+            t2_type.append(var)
+            # pyrefly: ignore [bad-argument-type]
+            num_constraints.append(BinConstraintD(var, Dyn, op_neq))
+
+        else:
+            # pyrefly: ignore [bad-argument-type]
+            num_constraints.append(BinConstraintD(t, Dyn, op_neq))
+            t2_type.append(t)  # type: ignore[arg-type]
+
+    t2_type = TensorType(t2_type)  # type: ignore[assignment]
+
+    c1 = BinConstraintT(my_view, t2_type, op_eq)
+    c2 = CanReshape(src_var, t2_type)
+
+    # TODO: add the extra check mentioned here:
+    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
+
+    return [c1, c2] + num_constraints, counter  # type: ignore[operator]
+
+
+@register_inference_rule("size")
+def size_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    The constraint is just lhs = rhs.
+    Ex: size = input_ids.size()
+    """
+
+    if len(n.args) == 1:
+        # generate the new variable
+        size, counter = gen_tvar(counter)
+        symbols[n] = size
+        input = symbols[n.args[0]]
+        c = BinConstraintT(input, size, op_eq)
+        return [c], counter
+
+    elif len(n.args) == 2:
+        # TODO: review this rule; should input = dyn; output = dyn be included here?
+        if isinstance(n.args[1], int):
+            # generate the new variable
+            size_index, counter = gen_dvar(counter)
+            symbols[n] = size_index
+            input = symbols[n.args[0]]
+            c2 = [
+                GetItem(i + 1, n.args[1], size_index, input)
+                for i in range(MAX_TENSOR_RANK)
+            ]
+            c3 = BinConstraintD(0, size_index, op_leq)
+
+            input_dyn = BinConstraintT(input, Dyn, op_eq)
+            output_dyn = BinConstraintD(size_index, Dyn, op_eq)
+            c1 = Conj([input_dyn, output_dyn])
+
+            return [Disj([c1, Conj([Disj(c2), c3])])], counter
+
+        else:
+            raise NotImplementedError
+
+    else:
+        raise NotImplementedError
+
+
+def range_check(i, n):
+    """
+    Checks if an index i is within range of a size n list
+    Args:
+        i: index
+        n: list size
+
+    Returns: Boolean
+    """
+    if i >= 0:
+        return T() if i < n else F()
+    else:
+        return T() if i >= n else F()
+
+
+@register_inference_rule(torch.cumsum)
+def cumsum_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal
+    We should verify that the index is valid
+    """
+    assert isinstance(n.args[0], Node)
+    arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
+    assert isinstance(arg_1, int)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+    input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(input, Dyn, op_eq)
+    output_dyn = BinConstraintT(output, Dyn, op_eq)
+    c1 = Conj([input_dyn, output_dyn])
+    c2 = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims)
+
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(input, TensorType(new_dims), op_eq),
+                BinConstraintT(output, TensorType(new_dims), op_eq),
+            ]
+            + [range_check(arg_1, i)]
+            + nat_constraints
+        )
+
+        c2.append(c_tensor_i)
+    dyn_or_tensor = Disj([c1, Disj(c2)])
+    return [dyn_or_tensor], counter
+
+
+@register_inference_rule(_assert_is_none)
+def assert_inference_rule(n: Node, symbols, constraints, counter):
+    assert len(n.users) == 0
+    return [], counter
+
+
+@register_inference_rule(operator.getitem)
+def getitem_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # dimension output case
+    if isinstance(n.args[1], int):
+        # create and store the new dimension variable
+        get_item_output, counter = gen_dvar(counter)
+        symbols[n] = get_item_output
+
+        # retrieve arg variables
+        get_item_arg = symbols[n.args[0]]
+        assert isinstance(get_item_arg, TVar)
+
+        # if the input is dynamic, we accept any index and return
+        # a dynamic dimension as output
+        input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
+        output_dyn = BinConstraintD(get_item_output, Dyn, op_eq)
+        c1 = Conj([input_dyn, output_dyn])
+
+        # if the input is a tensor,
+        # generate a getItem constraint which will be expanded based on the
+        # tensor dimension.
+
+        c2 = [
+            GetItem(i + 1, n.args[1], get_item_output, get_item_arg)
+            for i in range(MAX_TENSOR_RANK)
+        ]
+
+        # since the output is a dimension, we make sure it's a natural number
+        # added as a conjunction to the disjunction of c2
+        c3 = BinConstraintD(0, get_item_output, op_leq)
+        return [Disj([c1, Conj([Disj(c2), c3])])], counter
+
+    # tensor output case
+    elif isinstance(n.args[1], tuple):
+        # create and store the new tensor variable
+        get_item_output, counter = gen_tvar(counter)  # type: ignore[arg-type,assignment]
+        symbols[n] = get_item_output
+
+        # retrieve arg variables
+        if n.args[0] in symbols:
+            get_item_arg = symbols[n.args[0]]
+            assert isinstance(get_item_arg, TVar)
+
+            input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
+            output_dyn = BinConstraintT(get_item_output, Dyn, op_eq)  # type: ignore[assignment]
+            c1 = Conj([input_dyn, output_dyn])
+
+            c2 = [
+                GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)  # type: ignore[misc]
+                for i in range(MAX_TENSOR_RANK)
+            ]
+        else:
+            # TODO: we should figure out why there is a key-error here.
+            return [], counter
+
+        return [Disj([c1, *c2])], counter
+
+    else:
+        raise RuntimeError("Method not yet implemented")
+
+
+@register_inference_rule(operator.gt)
+def gt_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], (Node, int))
+    assert isinstance(n.args[1], (Node, int))
+
+    # We make sure this node will not be used again. We do not
+    # generate a constraint about that node. Only about the operands.
+
+    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
+    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(e1, TVar) and isinstance(e2, TVar):
+            gt_tensor, counter = gen_tvar(counter)
+            symbols[n] = gt_tensor
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)
+
+        elif isinstance(e1, DVar) and isinstance(e2, DVar):
+            # This is meant to be used for flow analysis only
+            gt_constraint = BinConstraintD(e1, e2, op_gt)
+
+            my_gt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise RuntimeError("Sort Mismatch")
+
+    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
+        if isinstance(e1, DVar):
+            # This is meant to be used for flow analysis only
+            gt_constraint = BinConstraintD(e1, e2, op_gt)
+
+            my_gt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        elif isinstance(e1, TVar) and isinstance(e2, int):
+            # then we made the wrong assumption about the argument being a tensor
+            # so we should fix the assumption
+            warnings.warn(
+                f"Made the wrong assumption for node {n}. Correctness not guaranteed."
+            )
+
+            new_e1, counter = gen_dvar(counter)
+            symbols[n.args[0]] = new_e1
+            symbols[n.args[0]]
+
+            gt_constraint = BinConstraintD(new_e1, e2, op_gt)
+
+            my_gt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise NotImplementedError("Method not yet implemented")
+
+    else:
+        raise NotImplementedError("Method not yet implemented")
+
+
+@register_inference_rule(operator.eq)
+def eq_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], (Node, int))
+    assert isinstance(n.args[1], (Node, int))
+
+    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
+    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(e1, TVar) and isinstance(e2, TVar):
+            eq_tensor, counter = gen_tvar(counter)
+            symbols[n] = eq_tensor
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor)
+
+        elif isinstance(e1, DVar) and isinstance(e2, DVar):
+            # This is meant to be used for flow analysis only
+            eq_constraint = BinConstraintD(e1, e2, op_eq)
+
+            my_eq, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise RuntimeError("Sort Mismatch")
+
+    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
+        if isinstance(e1, DVar):
+            # This is meant to be used for flow analysis only
+            eq_constraint = BinConstraintD(e1, e2, op_eq)
+
+            my_eq, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
+            return [equality_constraint], counter
+        else:
+            raise NotImplementedError("Method not yet implemented")
+    else:
+        raise NotImplementedError("Method not yet implemented")
+
+
+@register_inference_rule(operator.ne)
+def neq_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Translates to inconsistent in gradual types.
+    To prove inequality, we should prove that
+    tensors are either different sizes or
+    disagree on at least one dimension
+
+    This is a WIP (works when the condition
+    is false. We are working on making this operation work
+    when the condition is true as well)
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], tuple)
+
+    # implementing for size 3 and 4
+    if len(n.args[1]) == 3:
+        assert isinstance(n.args[1][0], (Node, int))
+        assert isinstance(n.args[1][1], (Node, int))
+        assert isinstance(n.args[1][2], (Node, int))
+
+        lhs = symbols[n.args[0]]
+
+        b, counter = gen_tensor_dims(4, counter)
+        input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq)
+
+        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
+        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
+        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
+
+        # dimensions not equal
+        my_ne, counter = gen_bvar(counter)
+        neq_1 = BinConstraintD(d1, b[0], op_neq)
+        neq_2 = BinConstraintD(d2, b[1], op_neq)
+        neq_3 = BinConstraintD(d3, b[2], op_neq)
+
+        # dimensions inconsistent
+        dims_inconsistent1 = Conj(
+            [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]
+        )
+        dims_inconsistent2 = Conj(
+            [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]
+        )
+        dims_inconsistent3 = Conj(
+            [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]
+        )
+
+        dims_inconsistent = Disj(
+            [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]
+        )
+
+        # we are covering size 3 and 4 only for now
+        ne_constraint = Conj([input_is_size3, dims_inconsistent])
+
+        my_ne, counter = gen_bvar(counter)
+        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
+
+    elif len(n.args[1]) == 4:
+        assert isinstance(n.args[1][0], (Node, int))
+        assert isinstance(n.args[1][1], (Node, int))
+        assert isinstance(n.args[1][2], (Node, int))
+        assert isinstance(n.args[1][3], (Node, int))
+
+        lhs = symbols[n.args[0]]
+
+        b1, counter = gen_dvar(counter)
+        b2, counter = gen_dvar(counter)
+        b3, counter = gen_dvar(counter)
+        b4, counter = gen_dvar(counter)
+
+        input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq)
+
+        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
+        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
+        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
+        d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]]
+
+        # dimensions not equal
+        my_ne, counter = gen_bvar(counter)
+        neq_1 = BinConstraintD(d1, b1, op_neq)
+        neq_2 = BinConstraintD(d2, b2, op_neq)
+        neq_3 = BinConstraintD(d3, b3, op_neq)
+        neq_4 = BinConstraintD(d4, b4, op_neq)
+
+        # dimensions to inconsistent
+        dims_inconsistent1 = Conj(
+            [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]
+        )
+        dims_inconsistent2 = Conj(
+            [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]
+        )
+        dims_inconsistent3 = Conj(
+            [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]
+        )
+        dims_inconsistent4 = Conj(
+            [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]
+        )
+
+        dims_inconsistent = Disj(
+            [
+                dims_inconsistent1,
+                dims_inconsistent2,
+                dims_inconsistent3,
+                dims_inconsistent4,
+            ]
+        )
+
+        ne_constraint = Conj([input_is_size4, dims_inconsistent])
+
+        my_ne, counter = gen_bvar(counter)
+
+        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
+
+    else:
+        raise NotImplementedError("Method not yet implemented")
+
+    return [equality_constraint], counter
+
+
+@register_inference_rule(operator.lt)
+def lt_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], (Node, int))
+    assert isinstance(n.args[1], (Node, int))
+
+    # We make sure this node will not be used again. We do not
+    # generate a constraint about that node. Only about the operands.
+
+    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
+    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(e1, TVar) and isinstance(e2, TVar):
+            lt_tensor, counter = gen_tvar(counter)
+            symbols[n] = lt_tensor
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor)
+
+        elif isinstance(e1, DVar) and isinstance(e2, DVar):
+            # This is meant to be used for flow analysis only
+            lt_constraint = BinConstraintD(e1, e2, op_lt)
+
+            my_lt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise RuntimeError("Sort Mismatch")
+
+    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
+        if isinstance(e1, DVar):
+            # This is meant to be used for flow analysis only
+            lt_constraint = BinConstraintD(e1, e2, op_lt)
+
+            my_lt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
+            return [equality_constraint], counter
+        else:
+            raise NotImplementedError("Method not yet implemented")
+
+    else:
+        raise NotImplementedError("Method not yet implemented")
+
+
+@register_inference_rule(torch.full)
+def full_inference_rule(n: Node, symbols, constraints, counter):
+    full, counter = gen_tvar(counter)
+    symbols[n] = full
+    res = []
+
+    assert isinstance(n.args[0], Iterable)
+    for arg in n.args[0]:
+        dim = arg if isinstance(arg, int) else symbols[arg]
+        res.append(dim)
+    c = BinConstraintT(full, TensorType(list(res)), op_eq)  # type: ignore[arg-type]
+    return [c], counter
+
+
+# TODO normalize index
+@register_inference_rule(torch.arange)
+def arange_inference_rule(n: Node, symbols, constraints, counter):
+    start = 0
+    step = 1
+
+    if len(n.args) == 1:
+        end = symbols[n.args[0]]
+    else:
+        raise NotImplementedError("Not yet implemented")
+
+    # int((end - start) / step)
+    d1, counter = gen_dvar(counter)
+    size_constraint = BinConstraintD(
+        d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq
+    )
+    arange, counter = gen_tvar(counter)
+    symbols[n] = arange
+
+    # either the a parameter is a number or it is Dyn
+    c1 = Disj(
+        [
+            BinConstraintD(end, Dyn, op_eq),
+            BinConstraintD(start, Dyn, op_eq),
+            BinConstraintD(step, Dyn, op_eq),
+        ]
+    )
+    c2 = BinConstraintD(d1, Dyn, op_eq)
+    both_dyn = Conj([c1, c2])
+
+    c11 = Conj(
+        [
+            BinConstraintD(end, Dyn, op_neq),
+            BinConstraintD(start, Dyn, op_neq),
+            BinConstraintD(step, Dyn, op_neq),
+        ]
+    )
+    c22 = BinConstraintD(d1, Dyn, op_neq)
+    both_numbers = Conj([c11, c22, size_constraint])
+
+    return [
+        BinConstraintT(arange, TensorType([d1]), op_eq),
+        Disj([both_dyn, both_numbers]),
+    ], counter
+
+
+def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
+    # additional vars that don't correspond to expressions
+    e11, counter = gen_tvar(counter)
+    e22, counter = gen_tvar(counter)
+
+    # generate constraints
+    c1 = TGreatestUpperBound(output_var, e11, e22)
+    c2 = ApplyBroadcasting(e11, e22, e1, e2)
+    c3 = BinConstraintT(e11, e22, op_consistency)
+    return [c1, c2, c3], counter
+
+
+@register_inference_rule(operator.mul)
+@register_inference_rule(torch.ne)
+@register_inference_rule("ne")
+@register_inference_rule(torch.add)
+@register_inference_rule(operator.add)
+def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
+    op_code = None
+    if n.target is operator.add or n.target is torch.add:
+        op_code = op_add
+    elif n.target is operator.mul:
+        op_code = op_mul
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(symbols[n.args[0]], TVar) and isinstance(
+            symbols[n.args[1]], TVar
+        ):
+            my_output, counter = gen_tvar(counter)
+            symbols[n] = my_output
+            e1 = symbols[n.args[0]]
+            e2 = symbols[n.args[1]]
+
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
+        else:
+            raise NotImplementedError("Method not yet implemented")
+
+    elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
+        if isinstance(symbols[n.args[0]], TVar):
+            my_output, counter = gen_tvar(counter)
+            symbols[n] = my_output
+            e1 = symbols[n.args[0]]
+            return [BinConstraintT(my_output, e1, op_eq)], counter
+        elif isinstance(symbols[n.args[0]], DVar):
+            my_output, counter = gen_dvar(counter)  # type: ignore[arg-type,assignment]
+            symbols[n] = my_output
+            e1 = symbols[n.args[0]]
+
+            # we will propagate the runtime value here since this is regular addition
+            c = Conj(
+                [
+                    BinConstraintD(
+                        my_output, BinConstraintD(e1, n.args[1], op_code), op_eq
+                    ),
+                    BinConstraintD(0, my_output, op_leq),
+                ]
+            )
+            return [c], counter
+
+    elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
+        if isinstance(symbols[n.args[1]], TVar):
+            my_output, counter = gen_tvar(counter)
+            symbols[n] = my_output
+            e2 = symbols[n.args[1]]
+            return [BinConstraintT(my_output, e2, op_eq)], counter
+        elif isinstance(symbols[n.args[1]], DVar):
+            my_output, counter = gen_dvar(counter)  # type: ignore[arg-type,assignment]
+            symbols[n] = my_output
+            e2 = symbols[n.args[1]]
+
+            # we will propagate the runtime value here since this is regular addition
+            c = Conj(
+                [
+                    BinConstraintD(
+                        my_output, BinConstraintD(e2, n.args[0], op_code), op_eq
+                    ),
+                    BinConstraintD(0, my_output, op_leq),
+                ]
+            )
+            return [c], counter
+
+        else:
+            raise NotImplementedError("Method not yet implemented")
+
+    else:
+        # TODO generate add constraints for scalar addition
+        raise NotImplementedError("Addition not yet implemented")
+
+
+@register_inference_rule(torch.flatten)
+def flatten_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    flattened, counter = gen_tvar(counter)
+    symbols[n] = flattened
+
+    input = symbols[n.args[0]]
+
+    # set the default start and end dims
+    start_dim = 1
+    end_dim = -1
+
+    if len(n.args) > 1:
+        assert isinstance(n.args[1], int)
+        start_dim = n.args[1]
+
+    if len(n.args) > 2:
+        assert isinstance(n.args[2], int)
+        end_dim = n.args[2]
+
+    c1 = BinConstraintT(input, Dyn, op_eq)
+    c2 = BinConstraintT(flattened, Dyn, op_eq)
+    both_dyn = Conj([c1, c2])
+
+    const = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        c, counter = generate_flatten_constraints(
+            start_dim, end_dim, input, flattened, i, counter
+        )
+        const.append(c)
+
+    return [Disj([both_dyn, *const])], counter
+
+
+@register_inference_rule(torch.nn.functional.layer_norm)
+def layer_norm_functional(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
+
+
+@register_inference_rule(torch.nn.LayerNorm)
+def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal.
+    Input should be consistent with the normalized_shape
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_layer_norm_constraints(
+        n, module_instance.normalized_shape, symbols, counter
+    )
+
+
+def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+    input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(input, Dyn, op_eq)
+    output_dyn = BinConstraintT(output, Dyn, op_eq)
+
+    c1 = Conj([input_dyn, output_dyn])
+
+    c2 = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs, counter = gen_tensor_dims(i, counter)
+        nat_constraints = gen_nat_constraints(new_dims_rhs)
+
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
+                BinConstraintT(output, TensorType(new_dims_rhs), op_eq),
+            ]
+            + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape))
+            + nat_constraints
+        )
+        c2.append(c_tensor_i)
+    return [Disj([c1, Disj(c2)])], counter
+
+
+@register_inference_rule(torch.nn.Dropout)
+@register_inference_rule(torch.nn.ReLU)
+def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal.
+    """
+    assert isinstance(n.args[0], Node)
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+    input = symbols[n.args[0]]
+    assert isinstance(input, TVar)
+    return [BinConstraintT(input, output, op_eq)], counter
+
+
+@register_inference_rule(torch.nn.Linear)
+def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    Input and output sizes should be the same except for the last dimension
+    If the input is Dyn, then so should the output
+    """
+    assert isinstance(n.args[0], Node)
+    return linear_constraints(
+        n, module_instance.in_features, module_instance.out_features, symbols, counter
+    )
+
+
+@register_inference_rule("dim")
+def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+    my_dim, counter = gen_dvar(counter)
+    symbols[n] = my_dim
+    input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(input, Dyn, op_eq)
+    output_dyn = BinConstraintD(my_dim, Dyn, op_eq)
+
+    c1 = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
+                BinConstraintD(my_dim, i, op_eq),
+            ]
+        )
+        c1.append(c_tensor_i)
+
+    return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
+
+
+@register_inference_rule(torch._C._nn.linear)
+def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+    weight_dims, counter = gen_tensor_dims(2, counter)
+    equality_constraint = BinConstraintT(
+        symbols[n.args[1]], TensorType(weight_dims), op_eq
+    )
+    constraints, counter = linear_constraints(
+        n, weight_dims[1], weight_dims[0], symbols, counter
+    )
+    return [equality_constraint] + constraints, counter
+
+
+def linear_constraints(n: Node, in_features, out_features, symbols, counter):
+    linear_output, counter = gen_tvar(counter)
+    symbols[n] = linear_output
+    linear_input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
+    output_dyn = BinConstraintT(linear_output, Dyn, op_eq)
+
+    c1 = Conj([input_dyn, output_dyn])
+
+    c2 = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
+
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
+                BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq),
+            ]
+            + add_linear_constraints(
+                new_dims_rhs_1, new_dims_rhs_2, in_features, out_features
+            )
+            + nat_constraints
+        )
+        c2.append(c_tensor_i)
+    return [Disj([c1, Disj(c2)])], counter
+
+
+def add_layer_norm_constraints(input_dim, normalized_dim):
+    """
+    The constraints say that the type has te form: [*, 1024, 1024]
+     while the normalized_dim have the form [1024, 1024]
+    Args:
+        input_dim: Input shape of layer norm
+        normalized_dim: normalized_dim parameter of the module instance
+
+    """
+
+    # in this case we return false since there's a pattern mismatch
+    if len(normalized_dim) > len(input_dim):
+        return [F()]
+
+    else:
+        constraints = []
+        for i, n in zip(reversed(input_dim), reversed(normalized_dim)):
+            constraints.append(BinConstraintD(i, n, op_consistency))
+        return constraints
+
+
+def add_linear_constraints(dims1, dims2, in_features, out_features):
+    assert len(dims1) == len(dims2)
+    constraints = []
+    for i in range(len(dims1)):
+        if i == len(dims1) - 1:
+            constraints.append(BinConstraintD(dims1[i], in_features, op_consistency))
+            constraints.append(BinConstraintD(dims2[i], out_features, op_eq))
+        else:
+            constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq))
+
+    return constraints
+
+
+@register_inference_rule(torch.reshape)
+def reshape_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    my_reshape, counter = gen_tvar(counter)
+    symbols[n] = my_reshape
+
+    src_var = symbols[n.args[0]]
+    t2 = n.args[1]
+    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])  # type: ignore[union-attr]
+    c1 = BinConstraintT(my_reshape, t2_type, op_eq)  # type: ignore[union-attr]
+    c2 = CanReshape(src_var, t2_type)
+
+    return [c1, c2], counter
+
+
+@register_inference_rule(BatchNorm2d)
+def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    batchnorm_output, counter = gen_tvar(counter)
+    symbols[n] = batchnorm_output
+    batchnorm_input = symbols[n.args[0]]
+
+    # dim vars
+    d1, counter = gen_dvar(counter)
+    d2, counter = gen_dvar(counter)
+    d3, counter = gen_dvar(counter)
+    d4, counter = gen_dvar(counter)
+
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+
+    c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching)
+    c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq)
+    return [c1, c2, *nat_constraints], counter
+
+
+@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
+def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    avg_pool, counter = gen_tvar(counter)
+
+    symbols[n] = avg_pool
+    input_var = symbols[n.args[0]]
+
+    # dim vars
+    d1, counter = gen_dvar(counter)
+    d2, counter = gen_dvar(counter)
+    d3, counter = gen_dvar(counter)
+    d4, counter = gen_dvar(counter)
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
+    c2 = BinConstraintT(
+        avg_pool,
+        TensorType(
+            [d1, d2, module_instance.output_size[0], module_instance.output_size[1]]
+        ),
+        op_eq,
+    )
+
+    return [c1, c2, *nat_constraints], counter
+
+
+@register_inference_rule(Conv2d)
+def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    my_conv, counter = gen_tvar(counter)
+    symbols[n] = my_conv
+    input_var = symbols[n.args[0]]
+
+    # dim vars
+    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
+
+    # c1 = Matching(input_var, TensorType([d1, d2, d3, d4]))
+    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
+
+    # c2 = DConsistency(module_instance.in_channels, d2)
+    c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
+
+    c3 = CalcConv(
+        my_conv,
+        input_var,
+        module_instance.out_channels,
+        module_instance.kernel_size,
+        module_instance.padding,
+        module_instance.stride,
+        module_instance.dilation,
+        [d1, d2, d3, d4],
+    )
+
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+
+    return [c1, c2, c3, *nat_constraints], counter
+
+
+@register_inference_rule(torch.nn.MaxPool2d)
+def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+    maxpool, counter = gen_tvar(counter)
+    symbols[n] = maxpool
+    input_var = symbols[n.args[0]]
+
+    # dim vars
+    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
+
+    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
+
+    c2 = CalcMaxPool(
+        maxpool,
+        input_var,
+        module_instance.kernel_size,
+        module_instance.padding,
+        module_instance.stride,
+        module_instance.dilation,
+        [d1, d2, d3, d4],
+    )
+
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+
+    return [c1, c2, *nat_constraints], counter
+
+
+class ConstraintGenerator:
+    def __init__(self, traced, graph=None):
+        self.traced = traced  # traced or tracer.root
+        self.traced_params = dict(self.traced.named_parameters())
+        self.constraints = []
+        self.symbol_dict = {}
+        self.graph = traced.graph if hasattr(traced, "graph") else graph
+
+    def generate_constraints(self, counter=0):
+        """
+        Iterate through every node and generate constraints
+        Effect: self.constraints will be populated with the final constraints
+        """
+        graph = self.graph
+
+        all_constraints = []
+
+        # pyrefly: ignore [missing-attribute]
+        for n in graph.nodes:
+            (constraints, counter) = self.generate_constraints_node(n, counter)
+            all_constraints += constraints
+
+        return Conj(all_constraints), counter
+
+    def generate_constraints_node(self, n: Node, counter):
+        """
+        Generate constraints the given node:
+        Currently supported operations:
+        - Reshape
+        - Add
+        - conv2d
+        """
+
+        if n.op == "placeholder":
+            x, counter = gen_tvar(counter)
+            self.symbol_dict[n] = x
+
+            my_type = n.type
+
+            if n.type != Dyn and (not isinstance(n.type, TensorType)):
+                if n.type == torch.nn.parameter.Parameter:
+                    # since we have a parameter, the shape must be static
+                    assert "example_value" in n.meta
+                    my_type = TensorType(n.meta["example_value"].size())
+                else:
+                    my_type = Dyn
+
+            c1 = BinConstraintT(my_type, x, op_precision)
+            c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
+            return [c1, c2], counter
+
+        elif n.op == "call_function":
+            if n.target in _INFERENCE_RULES:
+                return _INFERENCE_RULES[n.target](
+                    n, self.symbol_dict, self.constraints, counter
+                )
+            else:
+                raise RuntimeError(
+                    f"No inference rule registered for target {n.target}!"
+                )
+
+        elif n.op == "call_module":
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _INFERENCE_RULES:
+                return _INFERENCE_RULES[type(module_instance)](
+                    n, module_instance, self.symbol_dict, self.constraints, counter
+                )
+            else:
+                raise RuntimeError(
+                    f"No inference rule registered for class {type(module_instance)}!"
+                )
+
+        elif n.op == "call_method":
+            if n.target in _INFERENCE_RULES:
+                return _INFERENCE_RULES[n.target](
+                    n, self.symbol_dict, self.constraints, counter
+                )
+            else:
+                raise RuntimeError(
+                    f"No inference rule registered for target {n.target}!"
+                )
+
+        elif n.op == "get_attr":
+            t = self.traced_params.get(n.target, None)
+
+            if isinstance(t, torch.Tensor):
+                if len(t.shape) > 0:
+                    res = list(t.shape)
+                    attr_type = TensorType(res)
+                    output, counter = gen_tvar(counter)
+                    self.symbol_dict[n] = output
+                    return [BinConstraintT(output, attr_type, op_eq)], counter
+                else:
+                    # scalar?
+                    return [], counter
+            else:
+                return [], counter
+
+        elif n.op == "output":
+            return [], counter
+
+        else:
+            raise NotImplementedError(f"Method {n.op} not yet implemented")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0782ba5affc9cbbe6b55fbba131066a35f331f5a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -0,0 +1,1322 @@
+# mypy: ignore-errors
+import copy
+import itertools
+from collections.abc import Callable
+
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    ApplyBroadcasting,
+    BinConstraintD,
+    CalcConv,
+    CalcMaxPool,
+    CalcProduct,
+    CanReshape,
+    Conj,
+    Constraint,
+    DGreatestUpperBound,
+    Disj,
+    DVar,
+    F,
+    GetItem,
+    GetItemTensor,
+    IndexSelect,
+    Prod,
+    T,
+    TGreatestUpperBound,
+    Transpose,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
+    BinConstraintT,
+    MAX_TENSOR_RANK,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_consistency,
+    op_div,
+    op_eq,
+    op_leq,
+    op_matching,
+    op_mod,
+    op_mul,
+    op_neq,
+    op_precision,
+    op_sub,
+)
+from torch.fx.experimental.migrate_gradual_types.util import (
+    gen_dvar,
+    gen_nat_constraints,
+    gen_tensor_dims,
+)
+from torch.fx.tensor_type import Dyn, TensorType
+
+
+_TRANSFORMATION_RULES: dict[Constraint, Callable] = {}
+
+
+def register_transformation_rule(call_target):
+    def register(fn):
+        if call_target in _TRANSFORMATION_RULES:
+            raise RuntimeError(
+                f"Transformation rule already registered for {call_target}!"
+            )
+        _TRANSFORMATION_RULES[call_target] = fn
+        return fn
+
+    return register
+
+
+def valid_index(index, dims):
+    """
+    Given a list of dimensions, checks if an index is valid in the list
+    """
+    try:
+        dims[index]
+        return T()
+    except IndexError:
+        return F()
+
+
+@register_transformation_rule(Transpose)
+def transform_transpose(constraint, counter):
+    """
+    Similar to a sequence of two index-selects
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    is_valid_index1 = valid_index(constraint.index1, dims)
+    is_valid_index2 = valid_index(constraint.index2, dims)
+    new_dims = copy.deepcopy(dims)
+    nat_constraints = gen_nat_constraints(dims)
+
+    if is_valid_index1 == T() and is_valid_index2 == T():
+        new_dims[constraint.index1] = dims[constraint.index2]
+        new_dims[constraint.index2] = dims[constraint.index1]
+
+    transformed_constraint = Conj(
+        [
+            BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+            *nat_constraints,
+            is_valid_index1,
+            is_valid_index2,
+            BinConstraintT(constraint.output, TensorType(new_dims), op_eq),
+        ]
+    )
+    return transformed_constraint, counter
+
+
+@register_transformation_rule(IndexSelect)
+def transform_index_select(constraint, counter):
+    """
+    The constraints consider the given tensor size, checks if the index is valid
+    and if so, generates a constraint for replacing the input dimension
+    with the required dimension
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    is_valid_index = valid_index(constraint.index, dims)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # if the index is valid then replace the input dimension with the new dimension
+    # otherwise the dimension will not be replaced and the clause will contain False
+    if is_valid_index == T():
+        new_dims = copy.deepcopy(dims)
+        new_dims[constraint.index] = constraint.dim_replace
+
+    transformed_constraint = Conj(
+        [
+            BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+            *nat_constraints,
+            is_valid_index,
+            BinConstraintT(constraint.output, TensorType(new_dims), op_eq),
+        ]
+    )
+
+    # print(constraints)
+    return transformed_constraint, counter
+
+
+@register_transformation_rule(GetItem)
+def transform_get_item(constraint, counter):
+    """
+    generate an equality of the form:
+    t = [a1, ..., an]
+    then generate constraints that check if the given index is valid
+    given this particular tensor size.
+    If the index is valid, generate a constraint to get the item
+    Note that we already handled the Dyn input case in the previous
+    step.
+    Args:
+        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
+        counter: variable tracking
+    Returns: simplified constraints for GetItem
+
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    nat_constraints = gen_nat_constraints(dims)
+
+    is_valid_index = valid_index(constraint.index, dims)
+
+    all_constraints = [
+        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+        *nat_constraints,
+        is_valid_index,
+    ]
+
+    # if the index is valid, we generate a constraint for getting an item
+    # otherwise this clause will have been UNSAT due to the wrong index
+    if is_valid_index == T():
+        all_constraints.append(
+            BinConstraintD(constraint.res, dims[constraint.index], op_eq)
+        )
+
+    return Conj(all_constraints), counter
+
+
+def valid_index_tensor(index, dims):
+    """
+    if the slice instances exceed the length of the dimensions
+    then this is a type error so we return False
+    """
+    slice_count = 0
+    for s in index:
+        if isinstance(s, slice):
+            slice_count += 1
+    if slice_count > len(dims):
+        return F()
+    else:
+        return T()
+
+
+@register_transformation_rule(GetItemTensor)
+def transform_get_item_tensor(constraint, counter):
+    """
+    When the index is a tuple, then the output will be a tensor
+    TODO: we have to check if this is the case for all HF models
+
+    The cases we are covering here are a tuple with one of:
+     - slice with default argument
+     - None
+
+     None appends 1 to the input tensor dimensions
+     so each occurrence of 'None' increases the rank by 1
+
+     slice with default arguments does not change the rank
+    """
+    assert isinstance(constraint.index_tuple, tuple)
+
+    # generate a result tensor of the expected size
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # generate a place-holder list of the right rank
+    # where "slice" does not contribute to the rank and "None" does
+    none_c = constraint.index_tuple.count(None)
+    resulting_tensor_dims = (none_c + len(dims)) * [None]
+
+    dim_index = 0
+    for i in range(len(constraint.index_tuple)):
+        # append 1 to the right location of the resulting tensor
+        if constraint.index_tuple[i] is None:
+            resulting_tensor_dims[i] = 1
+
+        elif constraint.index_tuple[i] == slice(None, None, None):
+            pass
+
+        else:
+            raise NotImplementedError("Method not yet implemented")
+
+    # append the remaining dimensions to the right location
+    dim_index = 0
+    for i in range(len(resulting_tensor_dims)):
+        if resulting_tensor_dims[i] is None:
+            resulting_tensor_dims[i] = dims[dim_index]
+            dim_index += 1
+
+    # check if the index is valid
+    is_valid_index = valid_index_tensor(constraint.index_tuple, dims)
+
+    # check if the resulting tensor is within bounds
+    if len(resulting_tensor_dims) > 4:
+        return F(), counter
+
+    else:
+        constraints = [
+            BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+            BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
+            *nat_constraints,
+            is_valid_index,
+        ]
+        return Conj(constraints), counter
+
+
+@register_transformation_rule(BinConstraintT)
+def generate_binconstraint_t(constraint, counter):
+    """
+    Transform binary constraints for tensors
+    """
+
+    # precision constraints
+    if constraint.op == op_precision:
+        if constraint.lhs == Dyn:
+            return T(), counter
+        elif isinstance(constraint.lhs, TensorType):
+            is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
+            if is_fully_static:
+                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
+            else:
+                new_dims = []
+
+                for _ in range(len(constraint.lhs.__args__)):
+                    dim, counter = gen_dvar(counter)
+                    new_dims.append(dim)
+
+                new_dim_constraints = (
+                    [
+                        BinConstraintD(old_dim, new_dim, op_precision)
+                        for new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)
+                    ]
+                    + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)]
+                    + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims]
+                )
+                return Conj(new_dim_constraints), counter
+
+    # matching
+    elif constraint.op == op_matching:
+        assert isinstance(constraint.rhs, TensorType)
+        d1 = constraint.rhs.__args__[0]
+        d2 = constraint.rhs.__args__[1]
+        d3 = constraint.rhs.__args__[2]
+        d4 = constraint.rhs.__args__[3]
+
+        conj = [
+            BinConstraintT(constraint.lhs, Dyn, op_eq),
+            BinConstraintD(d1, Dyn, op_eq),
+            BinConstraintD(d2, Dyn, op_eq),
+            BinConstraintD(d3, Dyn, op_eq),
+            BinConstraintD(d4, Dyn, op_eq),
+        ]
+        return (
+            Disj(
+                [
+                    Conj(conj),
+                    BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq),
+                ]
+            ),
+            counter,
+        )
+
+    elif constraint.op == op_consistency:
+        c_dyn = Disj(
+            [
+                BinConstraintT(constraint.lhs, Dyn, op_eq),
+                BinConstraintT(constraint.rhs, Dyn, op_eq),
+            ]
+        )
+        (
+            (
+                c_tensor_1,
+                c_tensor_2,
+                c_tensor_3,
+                c_tensor_4,
+            ),
+            counter,
+        ) = gen_consistency_constraints(constraint, counter)
+
+        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter
+
+    elif constraint.op == op_leq:
+        assert isinstance(constraint.rhs, int)
+        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
+        for i in range(1, constraint.rhs + 1):
+            dims = []
+            for _ in range(1, i + 1):
+                dim_var, counter = gen_dvar(counter)
+                dims.append(dim_var)
+            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
+        return Disj(disj), counter
+    else:
+        return constraint, counter
+
+
+@register_transformation_rule(BinConstraintD)
+def generate_binconstraint_d(constraint, counter):
+    """
+    Transform binary constraints for dimensions
+    """
+    if constraint.op == op_precision:
+        if isinstance(constraint.lhs, int):
+            return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
+        elif constraint.lhs == Dyn:
+            return T(), counter
+
+    elif constraint.op == op_consistency:
+        return (
+            Disj(
+                [
+                    BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
+                    BinConstraintD(constraint.rhs, Dyn, op_eq),
+                    BinConstraintD(constraint.lhs, Dyn, op_eq),
+                ]
+            ),
+            counter,
+        )
+
+    else:
+        return constraint, counter
+
+
+@register_transformation_rule(Conj)
+def generate_conj(constraint, counter):
+    """
+    Transform conjunctions
+    """
+    new = []
+    for c in constraint.conjucts:
+        new_c, counter = transform_constraint(c, counter)
+        new.append(new_c)
+    return Conj(new), counter
+
+
+@register_transformation_rule(Disj)
+def generate_disj(constraint, counter):
+    """
+    Transform disjunctions
+    """
+    new = []
+    for c in constraint.disjuncts:
+        new_c, counter = transform_constraint(c, counter)
+        new.append(new_c)
+    return Disj(new), counter
+
+
+@register_transformation_rule(TGreatestUpperBound)
+def generate_gub(constraint, counter):
+    """
+    Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
+    on dimensions
+    """
+    c1 = Conj(
+        [
+            Disj(
+                [
+                    BinConstraintT(constraint.rhs1, Dyn, op_eq),
+                    BinConstraintT(constraint.rhs2, Dyn, op_eq),
+                ]
+            ),
+            BinConstraintT(constraint.res, Dyn, op_eq),
+        ]
+    )
+
+    [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)
+
+    return Disj([c1, c2, c3, c4, c5]), counter
+
+
+@register_transformation_rule(DGreatestUpperBound)
+def generate_d_gub(constraint, counter):
+    """
+    Transform greatest upper bound for dimensions into equality constraints
+    """
+    c1 = Conj(
+        [
+            BinConstraintD(constraint.rhs1, Dyn, op_eq),
+            BinConstraintD(constraint.res, constraint.rhs2, op_eq),
+        ]
+    )
+    c2 = Conj(
+        [
+            BinConstraintD(constraint.rhs2, Dyn, op_eq),
+            BinConstraintD(constraint.res, constraint.rhs1, op_eq),
+        ]
+    )
+    c3 = Conj(
+        [
+            BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq),
+            BinConstraintD(constraint.res, constraint.rhs1, op_eq),
+        ]
+    )
+    return Disj([c1, c2, c3]), counter
+
+
+@register_transformation_rule(CalcConv)
+def generate_calc_conv(constraint, counter):
+    d, counter = gen_tensor_dims(4, counter)
+    conv_result = TensorType([d[0], d[1], d[2], d[3]])
+
+    # the convolution result is a tensor of size 4
+    c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)
+
+    # the second dimension of the output is equal to the output channels
+    c2 = Conj(
+        [
+            BinConstraintD(d[1], constraint.c_out, op_eq),
+            BinConstraintD(d[1], Dyn, op_neq),
+        ]
+    )
+
+    # the input corresponds to the output in the first dimension of the convolution
+    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
+
+    c4, c5 = calc_last_two_dims(constraint, d)
+
+    leq_constraints = Conj(
+        [
+            BinConstraintD(0, d[0], op_leq),
+            BinConstraintD(0, d[1], op_leq),
+            BinConstraintD(0, d[2], op_leq),
+            BinConstraintD(0, d[3], op_leq),
+        ]
+    )
+
+    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
+
+
+@register_transformation_rule(CalcMaxPool)
+def generate_calc_maxpool(constraint, counter):
+    """
+    Transform maxpool constraints
+    """
+    d, counter = gen_tensor_dims(4, counter)
+    maxpool_result = TensorType([d[0], d[1], d[2], d[3]])
+
+    # the maxpool result is a tensor of size 4
+    c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq)
+
+    # the input corresponds to the output in the first and second dimension of maxpool
+    c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
+    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
+    c4, c5 = calc_last_two_dims(constraint, d)
+
+    leq_constraints = Conj(
+        [
+            BinConstraintD(0, d[0], op_leq),
+            BinConstraintD(0, d[1], op_leq),
+            BinConstraintD(0, d[2], op_leq),
+            BinConstraintD(0, d[3], op_leq),
+        ]
+    )
+
+    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
+
+
+@register_transformation_rule(CalcProduct)
+def generate_calc_product(constraint, counter):
+    """
+    Transform flatten constraints
+    """
+    start = constraint.start
+    end = constraint.end
+    dims = constraint.dims_to_flatten
+    flattened = constraint.flattened
+    n = len(constraint.dims_to_flatten)
+
+    # this will be evaluated right here
+    boundary_check = 0 <= start and start < end and end <= n
+
+    c_boundary = T() if boundary_check else F()
+
+    lhs = dims[0:start]
+    rhs = dims[end:]
+    mid = dims[start:end]
+
+    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)
+
+    all_constraints = []
+
+    for p in all_possibilities:
+        p = list(p)
+        # this tells us there is a dynamic variable
+        contains_dyn = not all(constraint.op == op_neq for constraint in p)
+        if contains_dyn:
+            mid_var = [Dyn]
+            total_constraints = lhs + mid_var + rhs
+            if len(total_constraints) > 4:
+                all_constraints.append(F())
+            else:
+                all_constraints.append(
+                    Conj(
+                        [
+                            BinConstraintT(
+                                flattened, TensorType(lhs + mid_var + rhs), op_eq
+                            )
+                        ]
+                        + p
+                    )
+                )
+        else:
+            new_var, counter = gen_dvar(counter)
+            mid_eq_prod = Conj(
+                [
+                    BinConstraintD(new_var, Prod(mid), op_eq),
+                    BinConstraintD(new_var, Dyn, op_neq),
+                ]
+            )
+            mid_var = [new_var]
+            total_constraints = lhs + mid_var + rhs
+            if len(total_constraints) > 4:
+                all_constraints.append(F())
+            else:
+                all_constraints.append(
+                    Conj(
+                        [
+                            BinConstraintT(
+                                flattened, TensorType(lhs + mid_var + rhs), op_eq
+                            ),
+                            mid_eq_prod,
+                        ]
+                        + p
+                    )
+                )
+
+    return Conj([Disj(all_constraints), c_boundary]), counter
+
+
+@register_transformation_rule(CanReshape)
+def generate_reshape(constraint, counter):
+    """
+    Transform reshape constraints
+    """
+    d, counter = gen_tensor_dims(4, counter)
+
+    d1 = d[0]
+    d2 = d[1]
+    d3 = d[2]
+    d4 = d[3]
+
+    target = constraint.target.__args__
+
+    is_fully_static = all(d != Dyn for d in target)
+
+    # dynamic tensor
+    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
+    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
+    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
+    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
+    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)
+
+    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
+    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)
+
+    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
+    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)
+
+    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
+    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
+
+    d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
+    d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
+
+    nat_d1 = BinConstraintD(0, d1, op_leq)
+    nat_d2 = BinConstraintD(0, d2, op_leq)
+    nat_d3 = BinConstraintD(0, d3, op_leq)
+    nat_d4 = BinConstraintD(0, d4, op_leq)
+
+    if is_fully_static:
+        # size 1 tensor
+        c3_tensor1 = Disj(
+            [d1_eq_dyn, (Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]))]
+        )
+        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
+
+        # size 2 tensor
+        all_tensor_2 = Conj(
+            [c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]
+        )
+
+        # size 3 tensor
+        all_tensor_3 = Conj(
+            [c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]
+        )
+
+        # size 4 tensor
+        all_tensor_4 = Conj(
+            [c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]
+        )
+
+        return (
+            Conj(
+                [
+                    Disj(
+                        [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]
+                    ),
+                    nat_d1,
+                    nat_d2,
+                    nat_d3,
+                    nat_d4,
+                ]
+            ),
+            counter,
+        )
+
+    # then there must be exactly one occurrence of dyn
+    else:
+        new_target = [n for n in target if n != Dyn]
+
+        # tensor 1
+        c3_tensor1 = Disj(
+            [d1_eq_dyn, (Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]))]
+        )
+        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
+
+        # tensor 2
+        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
+        c22 = Conj(
+            [d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]
+        )
+        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])
+
+        # tensor 3
+        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
+        c32 = Conj(
+            [
+                d1_neq_dyn,
+                d2_neq_dyn,
+                d3_neq_dyn,
+                is_dim_div_by_target(new_target, Prod([d1, d2, d3])),
+            ]
+        )
+        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])
+
+        # tensor 4
+        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
+        c42 = Conj(
+            [
+                d1_neq_dyn,
+                d2_neq_dyn,
+                d3_neq_dyn,
+                d4_neq_dyn,
+                is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4])),
+            ]
+        )
+        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])
+
+        return (
+            Conj(
+                [
+                    Disj(
+                        [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]
+                    ),
+                    nat_d1,
+                    nat_d2,
+                    nat_d3,
+                    nat_d4,
+                ]
+            ),
+            counter,
+        )
+
+
+@register_transformation_rule(ApplyBroadcasting)
+def generate_broadcasting(constraint, counter):
+    """
+    Transform broadcasting constraints
+    """
+    e11, e12 = constraint.res1, constraint.res2
+    e1, e2 = constraint.input1, constraint.input2
+
+    e1_dyn = BinConstraintT(e1, Dyn, op_eq)
+    e2_dyn = BinConstraintT(e2, Dyn, op_eq)
+
+    # Introduce dimensions
+    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
+    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)
+
+    # dyn possibility
+    e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
+    e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])
+
+    # tensor possibility
+    # generate dimensions to create tensors of size 1
+    final_tensor_1_constraint, _, _, nat_dims_1, counter = gen_broadcasting_constraints(
+        e1, e2, e11, e12, 1, counter
+    )
+
+    # generate dimensions to create tensors of size 2
+    (
+        final_tensor_2_constraint_no_padding,
+        final_tensor_2_constraint_padding_arg1,
+        final_tensor_2_constraint_padding_arg2,
+        nat_dims_2,
+        counter,
+    ) = gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)
+
+    # generate dimensions to create tensors of size 3
+    (
+        final_tensor_3_constraint_no_padding,
+        final_tensor_3_constraint_padding_arg1,
+        final_tensor_3_constraint_padding_arg2,
+        nat_dims_3,
+        counter,
+    ) = gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)
+
+    # generate dimensions to create tensors of size 4
+    (
+        final_tensor_4_constraint_no_padding,
+        final_tensor_4_constraint_padding_arg1,
+        final_tensor_4_constraint_padding_arg2,
+        nat_dims_4,
+        counter,
+    ) = gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)
+
+    final_result = Disj(
+        [
+            e1_dyn_constraint,
+            e2_dyn_constraint,
+            final_tensor_1_constraint,
+            final_tensor_2_constraint_no_padding,
+            final_tensor_2_constraint_padding_arg1,
+            final_tensor_2_constraint_padding_arg2,
+            final_tensor_3_constraint_no_padding,
+            final_tensor_3_constraint_padding_arg1,
+            final_tensor_3_constraint_padding_arg2,
+            final_tensor_4_constraint_no_padding,
+            final_tensor_4_constraint_padding_arg1,
+            final_tensor_4_constraint_padding_arg2,
+        ]
+    )
+
+    return (
+        Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]),
+        counter,
+    )
+
+
+def transform_constraint(constraint: Constraint, counter: int):
+    """
+    Transforms a constraint into a simpler constraint.
+    Ex: precision and consistency are transformed to equality
+    Args:
+        constraint: constraint to be transformed
+        counter: for variable tracking
+
+    Returns: Constraint
+
+    """
+    if type(constraint) in _TRANSFORMATION_RULES:
+        return _TRANSFORMATION_RULES[type(constraint)](constraint, counter)
+
+    else:
+        return constraint, counter
+
+
+def calc_last_two_dims(constraint, d: list[DVar]):
+    """
+    Generates constraints for the last two dimensions of a convolution or a maxpool output
+    Args:
+        constraint: CalcConv or CalcMaxPool
+        d: The list of output dimensions
+
+    Returns: Constraints for calculating the last two dimensions of the output
+
+    """
+
+    assert isinstance(constraint, (CalcConv, CalcMaxPool))
+
+    b3 = constraint.matching_constraint[2]
+    b4 = constraint.matching_constraint[3]
+
+    b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
+    b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])
+
+    d3_not_dyn = Conj(
+        [BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]
+    )
+    d4_not_dyn = Conj(
+        [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]
+    )
+
+    # transform parameters into tuples in case they are not already
+    padding = (
+        (constraint.padding, constraint.padding)
+        if isinstance(constraint.padding, int)
+        else constraint.padding
+    )
+    kernel = (
+        (constraint.kernel, constraint.kernel)
+        if isinstance(constraint.kernel, int)
+        else constraint.kernel
+    )
+    stride = (
+        (constraint.stride, constraint.stride)
+        if isinstance(constraint.stride, int)
+        else constraint.stride
+    )
+    dilation = (
+        (constraint.dilation, constraint.dilation)
+        if isinstance(constraint.dilation, int)
+        else constraint.dilation
+    )
+
+    f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
+    f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
+    f3 = BinConstraintD(
+        BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div
+    )
+    f4 = BinConstraintD(f3, 1, op_add)
+
+    c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])
+
+    f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
+    f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
+    f33 = BinConstraintD(
+        BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div
+    )
+    f44 = BinConstraintD(f33, 1, op_add)
+
+    c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])
+
+    return c4, c5
+
+
+def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]):
+    """
+    Generate all possibilities of being equal or not equal to dyn for my_list
+    Args:
+        my_list: List of tensor dimensions
+
+    Returns: A list of a list of constraints. Each list of constraints corresponds to
+    one possibility about the values of the dimension variables
+    """
+    # generate all possibilities of being equal or not equal to dyn for my_list
+    eq_possibilities = [
+        BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))
+    ]
+    neq_possibilities = [
+        BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))
+    ]
+
+    d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)]
+    all_possibilities = list(itertools.product(*d_possibilities))
+    return all_possibilities
+
+
+def is_target_div_by_dim(target: list[int], dim: list[DVar]):
+    """
+    Generate constraints to check if the target dimensions are divisible by the input dimensions
+    Args:
+        target: Target dimensions
+        dim: Input dimensions
+
+    Returns: Constraints to check divisibility
+
+    """
+    return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
+
+
+def is_dim_div_by_target(target: list[int], dim: list[DVar]):
+    """
+    Generate constraints to check if the input dimensions is divisible by the target dimensions
+    Args:
+        target: Target dimensions
+        dim:  Input dimensions
+
+    Returns: Constraints to check divisibility
+
+    """
+    return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq)
+
+
+def gen_all_reshape_possibilities(list_of_dims, target):
+    """
+    Consider all possibilities what the input dimensions could be (number or dynamic)
+    Then generate the appropriate constraints using multiplication or mod depending on the possibility
+    The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
+    for the input. Target is fixed because at most one dimension could be dyn.
+    We have different cases for this.
+
+    Args:
+        list_of_dims: The input list of dimensions
+        target: The tensor we want to reshape to
+
+    Returns: A disjunction of transformed reshape constraints
+
+    """
+    all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)
+
+    all_constraints = []
+
+    for p in all_possibilities:
+        to_multiply = []
+
+        p = list(p)
+
+        for constraint in p:
+            assert isinstance(constraint, BinConstraintD)
+            if constraint.op == op_neq:
+                to_multiply.append(constraint.lhs)
+
+        if not to_multiply:
+            all_constraints.append(Conj(p))
+
+        elif len(to_multiply) < len(list_of_dims):
+            all_constraints.append(
+                Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])
+            )
+        else:
+            all_constraints.append(
+                Conj(p + [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)])
+            )
+
+    return Disj(all_constraints)
+
+
+def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
+    """
+    Apply broadcasting to the 'index' dimension of tensor_input1.
+    Args:
+        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
+        tensor_input2: represents the second input
+        res1: broadcasted result 1
+        res2: broadcasted result 2
+        index: the index to broadcast
+        padding: If padding was used, then tensor_input1[index] does not exist
+
+    Returns:
+
+    """
+    if tensor_input1[index] is None:
+        assert padding
+
+    if not padding:
+        # then the inputs are the same length so they all have dimensions at "index"
+        return Conj(
+            [
+                BinConstraintD(tensor_input1[index], 1, op_eq),
+                BinConstraintD(res1[index], res2[index], op_eq),
+                BinConstraintD(res2[index], tensor_input2[index], op_eq),
+            ]
+        )
+
+    else:
+        # we don't set the input dimension to 1, since it doesn't exist.
+        return Conj(
+            [
+                BinConstraintD(res1[index], res2[index], op_eq),
+                BinConstraintD(res2[index], tensor_input2[index], op_eq),
+            ]
+        )
+
+
+def apply_padding(
+    e1_var: TVar,
+    e11: BinConstraintT,
+    e2: BinConstraintT,
+    e12: BinConstraintT,
+    d2: list[DVar],
+    d11: list[DVar],
+    d12: list[DVar],
+    counter: int,
+):
+    """
+    We are considering the possibility where one input has less dimensions than
+    another input, so we apply padding to the broadcasted results
+
+    Args:
+        e1_var: Variable representing the first input where padding will be
+        e11: constraint of the form e11 = Tensortype[d1, ..., dn]
+        e2:  constraint of the form e2 = Tensortype[d1, ..., dn]
+        e12: constraint of the form e11 = Tensortype[d1, ..., dn]
+        d2: Tensor variables for the second input
+        d11: Tensor variables for the broadcasted first input
+        d12: Tensor variables for the broadcasted second input
+        counter: variable tracking
+
+    Returns: A new constraint whose goal is to apply padding to the broadcasted result
+
+    """
+
+    res = []
+
+    # pad the shorter input with None so we can pass it to the broadcasting helper function
+    for i in range(1, len(d2)):
+        d1, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
+
+        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)
+
+        simulate_padding = [None] * (len(d2) - i)
+
+        assert len(simulate_padding + d1) == len(d2)
+
+        # for every padding size, we also consider broadcasting
+        broadcast_padding = [
+            broadcast_dim(simulate_padding, d2, d11, d12, j, True)
+            for j in range(len(d2) - i)
+        ]
+
+        # we consider the possibilities for broadcasting for every dimension. Since we already
+        # padded d1, we do not consider it while broadcasting
+        all_broadcasting_possibilities = (
+            generate_all_broadcasting_possibilities_no_padding(
+                d1, d2[(len(d2) - i) :], d11[(len(d2) - i) :], d12[(len(d2) - i) :]
+            )
+        )
+        # combine all constraints into a conjunction
+        c = Conj(
+            [
+                e1,
+                e11,
+                e2,
+                e12,
+                *broadcast_padding,
+                all_broadcasting_possibilities,
+                *nat_constraints,
+            ]
+        )
+        res.append(c)
+
+    return Disj(res), counter
+
+
+def no_broadcast_dim_with_index(
+    d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int
+):
+    """
+    Args:
+        d1: input 1
+        d2: input 2
+        d3: simulated broadcasting for input 1
+        d4: simulated broadcasting for input 2
+        i: the rank of the resulting tensor addition
+
+    Returns: Constraints for when no broadcasting occurs
+    """
+    return Conj(
+        [
+            Disj(
+                [
+                    Conj(
+                        [
+                            BinConstraintD(d1[i], 1, op_eq),
+                            BinConstraintD(d2[i], 1, op_eq),
+                        ]
+                    ),
+                    Conj(
+                        [
+                            BinConstraintD(d1[i], 1, op_neq),
+                            BinConstraintD(d2[i], 1, op_neq),
+                        ]
+                    ),
+                ]
+            ),
+            BinConstraintD(d1[i], d3[i], op_eq),
+            BinConstraintD(d2[i], d4[i], op_eq),
+        ]
+    )
+
+
+def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
+    """
+    Generate lists of DVar to represent tensor dimensions
+    Args:
+        num_tensors: the required number of tensors
+        dim_size: the number of dimensions for each tensor
+        counter: variable tracking
+
+    Returns: A list of a list of tensor dimensions
+
+    """
+    res = []
+
+    for _ in range(num_tensors):
+        dims, counter = gen_tensor_dims(dim_size, counter)
+        res.append(dims)
+
+    return res, counter
+
+
+def create_equality_constraints_for_broadcasting(
+    e1: TVar,
+    e2: TVar,
+    e11: TVar,
+    e12: TVar,
+    d1: list[DVar],
+    d2: list[DVar],
+    d11: list[DVar],
+    d12: list[DVar],
+):
+    """
+    Create equality constraints for when no broadcasting occurs
+    Args:
+        e1: Input 1
+        e2: Input 2
+        e11: Broadcasted input 1
+        e12: Broadcasted input 2
+        d1: Variables that store dimensions for e1
+        d2: Variables that store dimensions for e2
+        d11: Variables that store dimensions for e11
+        d12: Variables that store dimensions for e22
+
+    Returns: Four equality constraints
+
+    """
+
+    e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq)
+    e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq)
+    e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq)
+    e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq)
+    return [e1_tensor, e11_tensor, e2_tensor, e12_tensor]
+
+
+def gen_consistency_constraints(constraint: Constraint, counter: int):
+    """
+    Args:
+        constraint: Consistency constraint on tensors
+        counter: for variable tracking
+
+    Returns: Equality and consistency constraints on dimensions
+
+    """
+
+    all_constraints = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
+
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
+                BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq),
+            ]
+            + [
+                BinConstraintD(d1, d2, op_consistency)
+                for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)
+            ]
+            + nat_constraints
+        )
+
+        all_constraints.append(c_tensor_i)
+
+    return all_constraints, counter
+
+
+def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
+    """
+    Args:
+        constraint: Greatest upper bound on tensors
+        counter: variable tracking
+
+    Returns: A set of equality constraints and DGreatestUpperBound constraints
+
+    """
+
+    all_constraints = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        c = []
+        dims1, counter = gen_tensor_dims(i, counter)
+        c1tensor = TensorType(dims1)
+
+        dims2, counter = gen_tensor_dims(i, counter)
+        c2tensor = TensorType(dims2)
+
+        dims3, counter = gen_tensor_dims(i, counter)
+        c3tensor = TensorType(dims3)
+
+        c += [
+            BinConstraintT(constraint.rhs1, c1tensor, op_eq),
+            BinConstraintT(constraint.rhs2, c2tensor, op_eq),
+            BinConstraintT(constraint.res, c3tensor, op_eq),
+        ] + gen_nat_constraints(dims1 + dims2 + dims3)
+
+        assert (
+            len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
+        )
+        for i in range(len(c3tensor.__args__)):
+            c.append(
+                DGreatestUpperBound(
+                    c3tensor.__args__[i], c1tensor.__args__[i], c2tensor.__args__[i]
+                )
+            )
+
+        all_constraints.append(Conj(c))
+    return all_constraints, counter
+
+
+def generate_all_broadcasting_possibilities_no_padding(
+    d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar]
+):
+    """
+    Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
+    We look at all combinations for all dimensions in d1 and d2
+    Args:
+        d1: input1 dimensions
+        d2: input2 dimensions
+        d11: broadcasted input1 dimensions
+        d12: broadcasted input2 dimensions
+
+    Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions
+
+    """
+
+    size = len(d1)
+
+    res2 = []
+
+    for i in range(size):
+        t1 = broadcast_dim(d1, d2, d11, d12, i)
+        t2 = broadcast_dim(d2, d1, d12, d11, i)
+        t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)
+
+        res2.append(Disj([t1, t2, t3]))
+
+    return Conj(res2)
+
+
+def gen_broadcasting_constraints(
+    e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int
+):
+    """
+    Simulates broadcasting on e1 and e2 and returns the results
+    respectively in e11 and e12. Because of gradual types,
+    e1 and e2 may not be equal. Similarly, e11 and e12 may not
+    be equal. e11 and e12 should be guaranteed to be consistent
+    as they represent the shapes of the tensors to be added after
+    broadcasting.
+    Args:
+        e1: TVar representing the type of input 1
+        e2: TVar representing the type of input 2
+        e11: TVar representing the representing broadcasted input 1
+        e12: TVar representing the representing broadcasted input 2
+        i: The rank of the resulting type of addition
+        counter: for variable tracking
+
+    Returns: Simplified broadcasting constraints
+
+    """
+    dims, counter = gen_lists_of_dims(4, i, counter)
+    [d1, d2, d3, d4] = dims
+    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
+
+    initialize_tensors_constraints = create_equality_constraints_for_broadcasting(
+        e1, e2, e11, e12, d1, d2, d3, d4
+    )
+
+    [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints
+
+    # without padding, broadcast all possibilities for tensors of size i
+    final_tensor_constraint_no_padding = Conj(
+        [
+            *initialize_tensors_constraints,
+            generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4),
+        ]
+    )
+
+    # with padding, broadcast all possibilities for tensors of size i
+    final_tensor_constraint_padding_arg1, counter = apply_padding(
+        e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter
+    )
+
+    final_tensor_constraint_padding_arg2, counter = apply_padding(
+        e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter
+    )
+
+    return (
+        final_tensor_constraint_no_padding,
+        final_tensor_constraint_padding_arg1,
+        final_tensor_constraint_padding_arg2,
+        nat_dims_i,
+        counter,
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py
new file mode 100644
index 0000000000000000000000000000000000000000..267100c8545c8b2310299337ecf64211f633f6ce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py
@@ -0,0 +1,14 @@
+op_add = "+"
+op_sub = "-"
+op_mul = "*"
+op_div = "/"
+op_eq = "="
+op_neq = "!="
+op_imp = "=>"
+op_matching = "\u22b3"  # (contains)
+op_consistency = "~"
+op_precision = "\u2291"  # (square image of or equal to)
+op_leq = "\u2264"  # less-than or equal to
+op_lt = "<"
+op_gt = ">"
+op_mod = "%"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f9f33965e07551c651fa560a80c5e263dd5b85
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
@@ -0,0 +1,446 @@
+# mypy: allow-untyped-defs
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    BinConstraintD,
+    BinConstraintT,
+    BVar,
+    Conj,
+    Disj,
+    DVar,
+    F,
+    is_algebraic_expression,
+    is_bool_expr,
+    is_dim,
+    Prod,
+    T,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
+    ConstraintGenerator,
+)
+from torch.fx.experimental.migrate_gradual_types.constraint_transformation import (
+    transform_constraint,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_div,
+    op_eq,
+    op_gt,
+    op_leq,
+    op_lt,
+    op_mod,
+    op_mul,
+    op_neq,
+    op_sub,
+)
+from torch.fx.tensor_type import Dyn, TensorType
+
+
+try:
+    import z3  # type: ignore[import]
+
+    from torch.fx.experimental.migrate_gradual_types.z3_types import (
+        D,
+        tensor_type,
+        z3_dyn,
+    )
+
+    HAS_Z3 = True
+
+    def transform_to_z3(constraint, counter, dimension_dict):
+        if isinstance(constraint, Conj):
+            conjuncts = []
+            for c in constraint.conjucts:
+                new_c, counter = transform_to_z3(c, counter, dimension_dict)
+                conjuncts.append(new_c)
+            return z3.And(conjuncts), counter
+
+        elif isinstance(constraint, Disj):
+            disjuncts = []
+            for c in constraint.disjuncts:
+                new_c, counter = transform_to_z3(c, counter, dimension_dict)
+                disjuncts.append(new_c)
+            return z3.Or(disjuncts), counter
+
+        elif isinstance(constraint, T):
+            return True, counter
+
+        elif isinstance(constraint, F):
+            return False, counter
+
+        elif isinstance(constraint, BinConstraintT):
+            if constraint.op == op_eq:
+                lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
+                return (lhs == rhs), counter
+
+            else:
+                raise NotImplementedError("Method not yet implemented")
+
+        elif isinstance(constraint, BinConstraintD):
+            if constraint.op == op_eq:
+                if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
+                    transformed_rhs, counter = transform_to_z3(
+                        constraint.rhs, counter, dimension_dict
+                    )
+                    transformed_lhs = z3.Bool(constraint.lhs.c)
+                    return transformed_lhs == transformed_rhs, counter
+
+                elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
+                    # with dimension transformations we consider the encoding
+                    lhs, counter = transform_dimension(
+                        constraint.lhs, counter, dimension_dict
+                    )
+                    rhs, counter = transform_dimension(
+                        constraint.rhs, counter, dimension_dict
+                    )
+                    return lhs == rhs, counter
+
+                else:
+                    # then we have an algebraic expression which means that we disregard the
+                    # first element of the encoding
+                    lhs, counter = transform_algebraic_expression(
+                        constraint.lhs, counter, dimension_dict
+                    )
+                    rhs, counter = transform_algebraic_expression(
+                        constraint.rhs, counter, dimension_dict
+                    )
+                    return lhs == rhs, counter
+
+            # The assumption here is that the LHS and RHS must be dimensions
+            elif constraint.op == op_neq:
+                assert is_dim(constraint.lhs)
+                assert is_dim(constraint.rhs)
+                lhs, counter = transform_dimension(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_dimension(
+                    constraint.rhs, counter, dimension_dict
+                )
+                if constraint.rhs == Dyn or constraint.lhs == Dyn:
+                    if constraint.rhs == Dyn:
+                        return lhs.arg(0) == 1, counter
+                    elif constraint.lhs == Dyn:
+                        return rhs.arg(0) == 1, counter
+
+                # if one of the instances is a number
+                elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
+                    if isinstance(constraint.lhs, int):
+                        return (
+                            z3.Or(
+                                [
+                                    rhs.arg(0) == 0,
+                                    z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
+                                ]
+                            ),
+                            counter,
+                        )
+
+                    elif isinstance(constraint.rhs, int):
+                        return (
+                            z3.Or(
+                                [
+                                    lhs.arg(0) == 0,
+                                    z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
+                                ]
+                            ),
+                            counter,
+                        )
+
+                else:
+                    return (
+                        z3.Or(
+                            [
+                                z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
+                                z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
+                                z3.And(
+                                    [
+                                        lhs.arg(0) != 0,
+                                        rhs.arg(0) != 0,
+                                        lhs.arg(1) != rhs.arg(1),
+                                    ]
+                                ),
+                            ]
+                        ),
+                        counter,
+                    )
+
+            elif constraint.op == op_leq:
+                # if the dimensions are not dyn, this will come into effect
+                # there would have been another constraint specifying if a given dimension
+                # is dyn or not
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_algebraic_expression(
+                    constraint.rhs, counter, dimension_dict
+                )
+                return lhs <= rhs, counter
+
+            elif constraint.op == op_gt:
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_algebraic_expression(
+                    constraint.rhs, counter, dimension_dict
+                )
+                return lhs > rhs, counter
+
+            elif constraint.op == op_lt:
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(
+                    constraint.lhs, counter, dimension_dict
+                )
+                rhs, counter = transform_algebraic_expression(
+                    constraint.rhs, counter, dimension_dict
+                )
+                return lhs < rhs, counter
+
+            else:
+                raise NotImplementedError("operation not yet implemented")
+
+        else:
+            raise NotImplementedError("Operation not yet implemented")
+
+    def transform_var(tensor, counter, dimension_dict):
+        """
+        Transforms tensor variables to a format understood by z3
+        Args:
+            tensor: Tensor variable or a tensor type potentially with variable dimensions
+        Returns: Transformed variable to a z3 format
+
+        """
+        if isinstance(tensor, TensorType):
+            res = []
+            for t in tensor.__args__:
+                transformed, counter = transform_dimension(t, counter, dimension_dict)
+                res.append(transformed)
+
+            assert len(res) <= 4
+            if len(tensor.__args__) == 1:
+                return tensor_type.tensor1(res[0]), counter
+            elif len(tensor.__args__) == 2:
+                return tensor_type.tensor2(res[0], res[1]), counter
+            elif len(tensor.__args__) == 3:
+                return tensor_type.tensor3(res[0], res[1], res[2]), counter
+            elif len(tensor.__args__) == 4:
+                return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
+
+        elif tensor == Dyn:
+            return z3_dyn, counter
+
+        elif isinstance(tensor, TVar):
+            return z3.Const(tensor.tvar, tensor_type), counter
+
+    def transform_dimension(dimension, counter, dimension_dict):
+        """
+        Takes a dimension variable or a number and transforms it to a tuple
+        according to our scheme
+        Args:
+            dimension: The dimension to be transformed
+            counter: variable tracking
+
+        Returns:  tuple and the current counter
+
+        """
+        if dimension == Dyn:
+            counter += 1
+            return D(0, z3.Int(counter)), counter
+        elif isinstance(dimension, int):
+            return D(1, dimension), counter
+        elif isinstance(dimension, DVar):
+            if dimension.c in dimension_dict:
+                return (
+                    D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)),
+                    counter,
+                )
+            else:
+                counter += 1
+                dimension_dict[dimension.c] = counter
+                return D(z3.Int(counter), z3.Int(dimension.c)), counter
+
+    def transform_algebraic_expression(expr, counter, dimension_dict):
+        """
+        Transforms an algebraic expression to z3 format
+        Args:
+            expr: An expression is either a dimension variable or an algebraic-expression
+
+
+        Returns: the transformed expression
+
+        """
+        assert is_algebraic_expression(expr) or is_dim(expr)
+
+        if is_dim(expr):
+            transformed, counter = transform_dimension(expr, counter, dimension_dict)
+            return transformed.arg(1), counter
+
+        elif isinstance(expr, Prod):
+            dims = []
+            for dim in expr.products:
+                assert is_dim(dim)
+                d, counter = transform_dimension(dim, counter, dimension_dict)
+                dims.append(d.arg(1))
+            return z3.Product(dims), counter
+
+        elif is_algebraic_expression(expr):
+            lhs, counter = transform_algebraic_expression(
+                expr.lhs, counter, dimension_dict
+            )
+            rhs, counter = transform_algebraic_expression(
+                expr.rhs, counter, dimension_dict
+            )
+
+            if expr.op == op_sub:
+                c = lhs - rhs
+
+            elif expr.op == op_add:
+                c = lhs + rhs
+
+            elif expr.op == op_div:
+                c = lhs / rhs
+
+            elif expr.op == op_mul:
+                c = lhs * rhs
+
+            elif expr.op == op_mod:
+                c = lhs % rhs
+
+            else:
+                raise NotImplementedError("operation not yet implemented")
+
+            return c, counter
+
+        else:
+            raise RuntimeError
+
+    def transform_all_constraints(traced, counter=0):
+        """
+        Given a trace, generates constraints and transforms them to z3 format
+
+        """
+        dimension_dict = {}  # type: ignore[var-annotated]
+
+        generator = ConstraintGenerator(traced)
+        new_constraints, counter = generator.generate_constraints(counter)
+
+        # print(new_constraints.conjucts[0])
+        # print(*new_constraints.conjucts, sep='\n')
+
+        # transform precision, matching, consistency till obtaining a fixed point
+        new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
+        # print(new_constraints)
+        # print(new_constraints.conjucts)
+        # new_constraints.conjucts = new_constraints.conjucts[:-1]
+        # print(*new_constraints.conjucts, sep='\n')
+
+        transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
+        # print(transformed)
+        return transformed
+
+    def iterate_till_fixed_point(constraints, counter):
+        """
+        Transform constraints till reaching a fixed point
+        """
+        old_c = None
+        while old_c != constraints:
+            old_c = constraints
+            constraints, counter = transform_constraint(constraints, counter)
+        return constraints, counter
+
+    def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
+        """
+        Takes a node and a graph and generates two sets of constraints.
+        One set constraints the node's constraints and another set
+        constraints the negation of the node's constraints
+        Args:
+            tracer_root: the root for getting the module instances
+            graph: the graph so far in the tracing process
+            node: node that represents a conditional
+            counter: variable tracking
+
+        Returns: Two sets of constraints. One with a conjunction with the
+        the conditional constraint and the other with a conjunction with
+        its negation.
+
+        """
+        dimension_dict = {}  # type: ignore[var-annotated]
+
+        generator = ConstraintGenerator(tracer_root, graph)
+        new_constraints, counter = generator.generate_constraints(counter)
+
+        condition_constraint = new_constraints.conjucts[-1]
+
+        # we know the constraint is a conjunction where the last constraint is about the conditional
+        # so remove the last constraint
+        new_constraints.conjucts = new_constraints.conjucts[:-1]
+
+        # transform precision, matching, consistency till obtaining a fixed point
+        new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
+
+        # since the function returns a list of one element, we get the first element
+        # we are only interested in the RHS in this case because the LHS just stores
+        # the result
+
+        # we make sure the constraint is of the form:
+        # c = b where b is a boolean expression
+        # and we consider b (constraint.rhs) for transformation
+        assert isinstance(condition_constraint.lhs, BVar)
+        assert is_bool_expr(condition_constraint.rhs)
+        condition_constraint_rhs = condition_constraint.rhs
+
+        # transform the condition constraint
+        condition_constraint_rhs, counter = iterate_till_fixed_point(
+            condition_constraint_rhs, counter
+        )
+
+        transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
+
+        transformed_condition_constraint, counter = transform_to_z3(
+            condition_constraint_rhs, counter, dimension_dict
+        )
+
+        negation_transformed_condition_constraint = z3.Not(
+            transformed_condition_constraint
+        )
+
+        return z3.And([transformed, transformed_condition_constraint]), z3.And(
+            [transformed, negation_transformed_condition_constraint]
+        )
+
+    def evaluate_conditional_with_constraints(
+        tracer_root, graph, node, counter=0, user_constraints=None
+    ):
+        """
+        Given an IR and a node representing a conditional, evaluate the conditional
+        and its negation
+        Args:
+            tracer_root: Tracer root for module instances
+            node: The node to be evaluated
+
+        Returns: the results of evaluating the condition and the negation with
+        the rest of the constraints
+
+        """
+
+        (
+            transformed_positive,
+            transformed_negative,
+        ) = transform_all_constraints_trace_time(tracer_root, graph, node, counter)
+
+        s = z3.Solver()
+        s.add(transformed_positive)
+        if user_constraints is not None:
+            s.add(user_constraints)
+        condition = s.check()
+
+        s = z3.Solver()
+        s.add(transformed_negative)
+        if user_constraints is not None:
+            s.add(user_constraints)
+        negation = s.check()
+        return condition, negation
+
+except ImportError:
+    HAS_Z3 = False
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b160ec8de70f950db66cbe51d3657fbaf6b3aaf1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
@@ -0,0 +1,58 @@
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    BinConstraintD,
+    BVar,
+    DVar,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import op_leq
+
+
+def gen_tvar(curr: int) -> tuple[TVar, int]:
+    """
+    Generate a tensor variable
+    :param curr: The current counter
+    :return: a tensor variable and the updated counter
+    """
+    curr += 1
+    return TVar(curr), curr
+
+
+def gen_dvar(curr: int) -> tuple[DVar, int]:
+    """
+    Generate a dimension variable
+    :param curr: the current counter
+    :return: a dimension variable and an updated counter
+    """
+    curr += 1
+    return DVar(curr), curr
+
+
+def gen_bvar(curr: int) -> tuple[BVar, int]:
+    """
+    Generate a boolean variable
+    :param curr: the current counter
+    :return: a boolean variable and an updated counter
+    """
+    curr += 1
+    return BVar(curr), curr
+
+
+def gen_tensor_dims(n: int, curr: int) -> tuple[list[DVar], int]:
+    """
+    Generate a list of tensor dimensions
+    :param n:  the number of dimensions
+    :param curr: the current counter
+    :return: a list of dimension variables and an updated counter
+    """
+    dims = []
+    for _ in range(n):
+        dvar, curr = gen_dvar(curr)
+        dims.append(dvar)
+    return dims, curr
+
+
+def gen_nat_constraints(list_of_dims: list[DVar]) -> list[BinConstraintD]:
+    """
+    Generate natural number constraints for dimensions
+    """
+    return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..939f4865ab7d982289303093db2024eda6603521
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py
@@ -0,0 +1,30 @@
+try:
+    import z3  # type: ignore[import]
+
+    HAS_Z3 = True
+    # dynamic type
+    dyn = z3.DeclareSort("Dyn")
+    dyn_type = z3.Const("dyn", dyn)
+
+    # dimension
+    dim = z3.Datatype("dim")
+    dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort()))
+    dim = dim.create()
+
+    # tensors
+    tensor_type = z3.Datatype("TensorType")
+    tensor_type.declare("Dyn", ("dyn", dyn))
+    tensor_type.declare("tensor1", ("0", dim))
+    tensor_type.declare("tensor2", ("0", dim), ("1", dim))
+    tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim))
+    tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim))
+    tensor_type = tensor_type.create()
+
+    # create dimension
+    D = dim.dim
+
+    z3_dyn = tensor_type.Dyn(dyn_type)
+
+
+except ImportError:
+    HAS_Z3 = False
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/normalize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2dd3c962bbe4d274284d8db26bac70a1a170bed
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/normalize.py
@@ -0,0 +1,164 @@
+# mypy: allow-untyped-defs
+import operator
+from collections.abc import Callable
+from typing import Any, Optional
+
+import torch
+import torch.fx
+import torch.fx as fx
+from torch.fx import Proxy, Transformer
+from torch.fx.node import Argument, map_aggregate, Node, Target
+from torch.fx.operator_schemas import (
+    create_type_hint,
+    normalize_function,
+    normalize_module,
+)
+
+from .schema_type_annotation import AnnotateTypesWithSchema
+
+
+class NormalizeArgs(Transformer):
+    """
+    Normalize arguments to Python targets. This means that
+    `args/kwargs` will be matched up to the module/functional's
+    signature and rewritten to exclusively kwargs in positional order
+    if `normalize_to_only_use_kwargs` is true. Also populates default
+    values. Does not support positional-only parameters or varargs
+    parameters (*args, **kwargs).
+
+    If the nodes have 'type' metadata, it will use it to disambiguate
+    overloads. Otherwise, it will throw an error.
+
+    Example usage:
+        m = torchvision.models.resnet18()
+        traced = torch.fx.symbolic_trace(m)
+        traced = NormalizeArgs(traced).transform()
+    """
+
+    def __init__(
+        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
+    ):
+        super().__init__(module)
+        self.node_map: dict[Proxy, Node] = {}
+        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
+
+    def run_node(self, n: Node) -> Any:
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+
+        def get_type(arg):
+            if isinstance(arg, fx.Node):
+                return n.meta.get("type")
+            return type(arg)
+
+        arg_types = map_aggregate(n.args, get_type)
+        assert isinstance(arg_types, tuple)
+        arg_types = tuple(create_type_hint(i) for i in arg_types)
+        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
+        if n.op == "call_function":
+            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
+        else:
+            out = super().run_node(n)
+        if n.op != "output":
+            self.node_map[out] = n
+            out.node.meta = n.meta
+            out.node.type = n.type
+        return out
+
+    def call_function(
+        self,
+        target: Target,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Any],
+        arg_types: Optional[tuple[Any, ...]] = None,
+        kwarg_types: Optional[dict[str, Any]] = None,
+    ):
+        assert callable(target)
+        new_args_and_kwargs = normalize_function(
+            target,
+            args,  # type: ignore[arg-type]
+            kwargs,
+            arg_types,  # type: ignore[arg-type]
+            kwarg_types,
+            self.normalize_to_only_use_kwargs,
+        )
+        if new_args_and_kwargs:
+            new_args, new_kwargs = new_args_and_kwargs
+            return self.tracer.create_proxy(
+                "call_function", target, new_args, new_kwargs
+            )
+        else:
+            return super().call_function(target, args, kwargs)
+
+    def call_module(
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+    ):
+        assert isinstance(target, str)
+        new_args_and_kwargs = normalize_module(
+            self.module,
+            target,
+            args,  # type: ignore[arg-type]
+            kwargs,
+            self.normalize_to_only_use_kwargs,
+        )
+        if new_args_and_kwargs:
+            new_args, new_kwargs = new_args_and_kwargs
+            return super().call_module(target, new_args, new_kwargs)
+        else:
+            return super().call_module(target, args, kwargs)
+
+
+class NormalizeOperators(AnnotateTypesWithSchema):
+    """
+    Normalize callsites that are different ways of "spelling" the same
+    invocation into a single, canonical call. Currently supports:
+
+    1. Normalize operators (e.g. operator.add) to the `torch` ops they
+       ultimately invoke (e.g. torch.add) when it is possible to statically
+       reason that
+
+    Example usage:
+
+        m = torchvision.models.resnet18()
+
+        traced = torch.fx.symbolic_trace(m)
+
+        traced = NormalizeOperators(traced).transform()
+    """
+
+    binary_magic_method_remap: dict[
+        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
+    ] = {
+        torch.add: operator.add,
+        torch.mul: operator.mul,
+        torch.sub: operator.sub,
+        torch.div: operator.truediv,
+        torch.floor_divide: operator.floordiv,
+        torch.remainder: operator.mod,
+        torch.eq: operator.eq,
+        torch.ne: operator.ne,
+        torch.lt: operator.lt,
+        torch.le: operator.le,
+        torch.gt: operator.gt,
+        torch.ge: operator.ge,
+    }
+
+    def call_function(
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+    ):
+        # Normalize operators according to the magic methods implemented on tensors here:
+        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
+
+        assert callable(target)
+
+        if target in self.binary_magic_method_remap:
+            if len(args) != 2:
+                return super().call_function(target, args, kwargs)
+            lhs, rhs = args
+
+            return super().call_function(
+                target=self.binary_magic_method_remap[target],
+                args=(lhs, rhs),
+                kwargs={},
+            )
+
+        return super().call_function(target, args, kwargs)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/optimization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..219e6f66c7bf52d8f4bf6384b871dee4a9a494d1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/optimization.py
@@ -0,0 +1,490 @@
+# mypy: allow-untyped-defs
+import copy
+import logging
+import operator
+import time
+from collections import defaultdict
+from collections.abc import Iterable
+from enum import Enum
+from typing import Any, cast, Optional
+
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.mkldnn as th_mkldnn
+from torch.fx.node import Argument, Target
+from torch.fx.passes.shape_prop import ShapeProp
+from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval
+
+
+__all__ = [
+    "matches_module_pattern",
+    "replace_node_module",
+    "fuse",
+    "remove_dropout",
+    "extract_subgraph",
+    "modules_to_mkldnn",
+    "reset_modules",
+    "MklSubgraph",
+    "gen_mkl_autotuner",
+    "use_mkl_length",
+    "UnionFind",
+    "optimize_for_inference",
+]
+
+
+def _parent_name(target: str) -> tuple[str, str]:
+    """
+    Splits a qualname into parent path and last atom.
+    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
+    """
+    *parent, name = target.rsplit(".", 1)
+    return parent[0] if parent else "", name
+
+
+# Works for length 2 patterns with 2 modules
+def matches_module_pattern(
+    pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
+):
+    if len(node.args) == 0:
+        return False
+    nodes: tuple[Any, fx.Node] = (node.args[0], node)
+    for expected_type, current_node in zip(pattern, nodes):
+        if not isinstance(current_node, fx.Node):
+            return False
+        if current_node.op != "call_module":
+            return False
+        if not isinstance(current_node.target, str):
+            return False
+        if current_node.target not in modules:
+            return False
+        if type(modules[current_node.target]) is not expected_type:
+            return False
+    return True
+
+
+def replace_node_module(
+    node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
+):
+    assert isinstance(node.target, str)
+    parent_name, name = _parent_name(node.target)
+    modules[node.target] = new_module
+    setattr(modules[parent_name], name, new_module)
+
+
+def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
+    """
+    Fuses convolution/BN and linear/BN layers for inference purposes.
+    Will deepcopy your model by default, but can modify the model inplace as well.
+    """
+    patterns = [
+        (nn.Conv1d, nn.BatchNorm1d),
+        (nn.Conv2d, nn.BatchNorm2d),
+        (nn.Conv3d, nn.BatchNorm3d),
+        (nn.Linear, nn.BatchNorm1d),
+    ]
+    if not inplace:
+        model = copy.deepcopy(model)
+    if not no_trace or not isinstance(model, torch.fx.GraphModule):
+        fx_model = fx.symbolic_trace(model)
+    else:
+        fx_model = model
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    for pattern in patterns:
+        for node in new_graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:
+                    # Output of conv/linear is used by other nodes
+                    continue
+                first_layer = modules[node.args[0].target]
+                bn = modules[node.target]
+                if not bn.track_running_stats:
+                    continue
+                if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
+                    fused_layer = fuse_conv_bn_eval(first_layer, bn)
+                else:  # nn.Linear
+                    fused_layer = fuse_linear_bn_eval(first_layer, bn)
+                replace_node_module(node.args[0], modules, fused_layer)
+                node.replace_all_uses_with(node.args[0])
+                new_graph.erase_node(node)
+    return fx.GraphModule(fx_model, new_graph)
+
+
+def remove_dropout(model: nn.Module) -> nn.Module:
+    """
+    Removes all dropout layers from the module.
+    """
+    fx_model = fx.symbolic_trace(model)
+
+    class DropoutRemover(torch.fx.Transformer):
+        def call_module(
+            self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+        ) -> Any:
+            if isinstance(self.submodules[target], nn.Dropout):
+                assert len(args) == 1
+                return args[0]
+            else:
+                return super().call_module(target, args, kwargs)
+
+    return DropoutRemover(fx_model).transform()
+
+
+def extract_subgraph(
+    orig_module: nn.Module,
+    nodes: list[fx.Node],
+    inputs: list[fx.Node],
+    outputs: list[fx.Node],
+):
+    """
+    Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
+    """
+    new_graph = fx.Graph()
+    env: dict[fx.Node, fx.Node] = {}
+    for input in inputs:
+        new_node = new_graph.placeholder(input.name)
+        env[input] = new_node
+    for node in nodes:
+        new_node = new_graph.node_copy(node, lambda x: env[x])
+        env[node] = new_node
+    new_graph.output([env[output] for output in outputs])
+    new_graph.lint()
+    return fx.GraphModule(orig_module, new_graph)
+
+
+mkldnn_supported = [
+    nn.Conv2d,
+    nn.Linear,
+    nn.BatchNorm2d,
+    nn.ReLU,
+    nn.MaxPool2d,
+    nn.AvgPool2d,
+    nn.AdaptiveAvgPool2d,
+    torch.relu,
+    torch.transpose,
+    torch.sigmoid,
+    F.relu,
+    F.avg_pool2d,
+    F.adaptive_avg_pool2d,
+]
+# These are operators that may not be convertible into MKLDNN ops (e.g. the
+# args are scalar values). Thus, we only include them in the subgraph if their
+# arguments are already in MKLDNN.
+# TODO: Determine whether this can be removed after type inference.
+mkldnn_supported_unknown = [operator.add, operator.mul]
+mkldnn_map = {
+    nn.Conv2d: th_mkldnn.MkldnnConv2d,
+    nn.Linear: th_mkldnn.MkldnnLinear,
+    nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
+}
+
+
+def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
+    """
+    For each node, if it's a module that can be preconverted into MKLDNN,
+    then we do so and create a mapping to allow us to convert from the MKLDNN
+    version of the module to the original.
+    """
+    old_modules: dict[nn.Module, nn.Module] = {}
+    for node in nodes:
+        if node.op == "call_module":
+            assert isinstance(node.target, str)
+            cur_module = modules[node.target]
+            if type(cur_module) in mkldnn_map:
+                # pyrefly: ignore [index-error]
+                new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
+                assert isinstance(new_module, nn.Module)
+                old_modules[new_module] = copy.deepcopy(cur_module)
+                replace_node_module(node, modules, new_module)
+    return old_modules
+
+
+def reset_modules(
+    nodes: list[fx.Node],
+    modules: dict[str, nn.Module],
+    old_modules: dict[nn.Module, nn.Module],
+):
+    """
+    Maps each module that's been changed with `modules_to_mkldnn` back to its
+    original.
+    """
+    for node in nodes:
+        if node.op == "call_module":
+            assert isinstance(node.target, str)
+            cur_module = modules[node.target]
+            if cur_module in old_modules:
+                replace_node_module(node, modules, old_modules[cur_module])
+
+
+class MklSubgraph:
+    def __init__(self, fx_graph: fx.Graph):
+        self.fx_graph = fx_graph
+        self.nodes: list[fx.Node] = []
+        self.start_nodes: list[fx.Node] = []
+        self.end_nodes: list[fx.Node] = []
+
+
+def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
+    """
+    This generates a heuristic that can be passed into `optimize_for_inference` that
+    determines whether a subgraph should be run in MKL by running it with the example_inputs.
+
+    Example usage:
+        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
+        fast_model = optimization.optimize_for_inference(model, heuristic)
+    """
+    fx_model = None
+    old_modules = None
+
+    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
+        nonlocal fx_model, old_modules
+        input_nodes = graph.start_nodes
+        if fx_model is None:
+            fx_model = graph.fx_graph.owning_module
+            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
+            ShapeProp(fx_model).propagate(example_inputs)
+        sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
+        output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
+        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
+
+        def benchmark(f):
+            for _ in range(warmup):
+                f()
+            begin = time.time()
+            for _ in range(iters):
+                f()
+            return time.time() - begin
+
+        mkl_time = benchmark(
+            lambda: [
+                i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
+            ]
+        )
+
+        reset_modules(
+            submodule.graph.nodes,
+            dict(submodule.named_modules()),
+            # pyrefly: ignore [bad-argument-type]
+            old_modules,
+        )
+        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
+        return mkl_time < no_mkl_time
+
+    return use_mkl_heuristic
+
+
+def use_mkl_length(graph: MklSubgraph) -> bool:
+    """
+    This is a heuristic that can be passed into `optimize_for_inference` that
+    determines whether a subgraph should be run in MKL by checking if there
+    are more than 2 nodes in it
+    """
+    return len(graph.nodes) > 2
+
+
+class UnionFind:
+    def __init__(self, n):
+        self.parent: list[Optional[int]] = [None] * n
+        self.size: list[int] = [0] * n
+
+    def make_set(self, v: int):
+        self.parent[v] = v
+        self.size[v] = 1
+
+    def find(self, v: int) -> int:
+        par = self.parent[v]
+        if v == par:
+            return v
+        assert par is not None
+        self.parent[v] = self.find(par)
+        return cast(int, self.parent[v])
+
+    def join(self, a: int, b: int):
+        a, b = self.find(a), self.find(b)
+        if a == b:
+            return a
+        if self.size[a] < self.size[b]:
+            a, b = b, a
+        self.parent[b] = a
+        self.size[a] += self.size[b]
+
+
+def optimize_for_inference(
+    model: torch.nn.Module,
+    pass_config: Optional[dict[str, Any]] = None,
+    tracer: type[fx.Tracer] = fx.Tracer,
+) -> torch.nn.Module:
+    """
+    Performs a set of optimization passes to optimize a model for the
+    purposes of inference. Specifically, the passes that are run are:
+    1. Conv/BN fusion
+    2. Dropout removal
+    3. MKL layout optimizations
+
+    The third optimization takes a function `use_mkl_heuristic` that's used
+    to determine whether a subgraph should be explicitly run in MKL layout.
+
+    Note: As FX does not currently handle aliasing, this pass currently
+    assumes nothing aliases. If that isn't true, use at your own risk.
+    """
+    default_pass_config = {
+        "conv_bn_fuse": True,
+        "remove_dropout": True,
+        "mkldnn_layout_optimize": {"heuristic": use_mkl_length},
+    }
+    if pass_config is None:
+        pass_config = {}
+    default_pass_config.update(pass_config)
+
+    if default_pass_config["conv_bn_fuse"]:
+        model = fuse(model)
+    if default_pass_config["remove_dropout"]:
+        model = remove_dropout(model)
+    if default_pass_config["mkldnn_layout_optimize"] is False:
+        return model
+    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
+        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
+    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
+        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
+    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]
+
+    cur_tracer = tracer()
+    fx_graph = cur_tracer.trace(copy.deepcopy(model))
+    fx.GraphModule(cur_tracer.root, fx_graph)
+    modules: dict[str, nn.Module] = dict(model.named_modules())
+
+    class MklSupport(Enum):
+        NO = 1
+        YES = 2
+        UNKNOWN = 3
+
+    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
+    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
+    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
+    # a MKLDNN node if its inputs are MKLDNN nodes.
+    for node in list(fx_graph.nodes):
+        supports_mkldnn = MklSupport.NO
+        if node.op == "call_module":
+            cur_module = modules[node.target]
+            if type(cur_module) in mkldnn_supported:
+                supports_mkldnn = MklSupport.YES
+                sample_parameter = next(cur_module.parameters(), None)
+                if sample_parameter is not None:
+                    assert sample_parameter.dtype == torch.float, (
+                        "this pass is only for torch.float modules"
+                    )
+                    assert sample_parameter.device == torch.device("cpu"), (
+                        "this pass is only for CPU modules"
+                    )
+        elif node.op == "call_function":
+            if node.target in mkldnn_supported:
+                supports_mkldnn = MklSupport.YES
+            elif node.target in mkldnn_supported_unknown:
+                supports_mkldnn = MklSupport.UNKNOWN
+
+        if supports_mkldnn != MklSupport.NO:
+            if supports_mkldnn == MklSupport.UNKNOWN:
+                if not any(arg.target == "to_dense" for arg in node.args):
+                    continue
+            with fx_graph.inserting_before(node):
+                mkldnn_args = fx.map_arg(
+                    node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
+                )
+
+            node.args = cast(tuple[fx.node.Argument], mkldnn_args)
+
+            with fx_graph.inserting_after(node):
+                dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
+                node.replace_all_uses_with(dense_x)
+                dense_x.args = (node,)
+
+    # Does pre-conversion of all modules into MKLDNN (when possible)
+    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
+    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]
+
+    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
+    for node in fx_graph.nodes:
+        if node.op == "call_method" and node.target == "to_dense":
+            prv_node = node.args[0]
+            users = list(node.users)
+            for user in users:
+                if user.op == "call_method" and user.target == "to_mkldnn":
+                    user.replace_all_uses_with(prv_node)
+                    fx_graph.erase_node(user)
+            if len(node.users) == 0:
+                fx_graph.erase_node(node)
+
+    num_nodes = len(fx_graph.nodes)
+    uf = UnionFind(num_nodes)
+
+    def get_color(n):
+        if hasattr(n, "color"):  # Current node is part of a MKL subgraph
+            return uf.find(n.color)
+        if hasattr(n, "start_color"):  # Current node is input to MKL subgraph
+            return uf.find(n.start_color)
+        return None
+
+    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
+    # of input nodes (which are only `to_mkldnn` calls), output nodes
+    # (`to_dense` calls), and intermediate nodes, which are run entirely on
+    # MKLDNN layout tensors.
+    #
+    # Specifically, this code does a flood fill on a directed acyclic graph
+    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
+    # If every node only had one input, this would be sufficient. However, in
+    # the case that a node has multiple inputs coming from different start
+    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
+    # using a Disjoint Set Union.
+    for cur_idx, node in enumerate(fx_graph.nodes):
+        if node.op == "call_method" and node.target == "to_mkldnn":
+            node.start_color = cur_idx
+            uf.make_set(cur_idx)
+        elif node.op == "call_method" and node.target == "to_dense":
+            assert get_color(node.args[0]) is not None
+            node.end_color = get_color(node.args[0])
+        else:
+            cur_colors = [
+                get_color(i)
+                for i in node.all_input_nodes
+                if isinstance(i, fx.Node)
+                if get_color(i) is not None
+            ]
+
+            if len(cur_colors) == 0:
+                continue
+            assert not any(i is None for i in cur_colors)
+            cur_colors = sorted(cur_colors)
+            node.color = cur_colors[0]
+            for other_color in cur_colors[1:]:
+                uf.join(cur_colors[0], other_color)
+
+    mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
+    for node in fx_graph.nodes:
+        if hasattr(node, "color"):
+            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
+        if hasattr(node, "start_color"):
+            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
+        if hasattr(node, "end_color"):
+            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
+
+    # Now that we have all the subgraphs, we need to decide which MKLDNN
+    # subgraphs we actually want to keep in MKLDNN.
+    for graph in mkldnn_graphs.values():
+        if not use_mkl_heuristic(graph):
+            for node in graph.start_nodes + graph.end_nodes:
+                prv = node.args[0]
+                node.replace_all_uses_with(prv)  # type: ignore[arg-type]
+                fx_graph.erase_node(node)
+            reset_modules(graph.nodes, modules, old_modules)
+
+    mkldnn_conversions = 0
+    for node in fx_graph.nodes:
+        if node.target == "to_mkldnn" or node.target == "to_dense":
+            mkldnn_conversions += 1
+
+    logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
+    fx_graph.lint()
+    result = fx.GraphModule(model, fx_graph)
+    return result
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3658dd1a9ce96aff26adbc5f47818e9e57e13d35
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py
@@ -0,0 +1,317 @@
+# mypy: allow-untyped-defs
+from enum import Enum
+from typing import NamedTuple
+
+from torch.fx.node import map_arg, Node
+
+
+class Partition:
+    """Partition class contains all the information about an individual partition.
+    It also provides necessary methods for manipulation the partition.
+    """
+
+    def __init__(self, partition_id: int) -> None:
+        self.nodes: set[Node] = set()
+        self.partition_id = partition_id
+        self.parents: set[Partition] = set()
+        self.children: set[Partition] = set()
+        self.bfs_level: int = -1
+        self.used_mem_bytes: int = 0
+        self.logical_device_ids: list[int] = []
+
+    def __str__(self):
+        return str(self.partition_id)
+
+    def recalculate_mem_size(self):
+        self.used_mem_bytes = 0
+        for node in self.nodes:
+            self.used_mem_bytes += get_extra_size_of(node, self.nodes)
+
+    def add_node(self, node):
+        input_nodes: dict[Node, None] = {}
+        map_arg(node.args, input_nodes.setdefault)
+        map_arg(node.kwargs, input_nodes.setdefault)
+        # Add current node's input nodes if they are placeholder or constants
+        for n in input_nodes:
+            if n.op in {"placeholder", "get_attr"}:
+                self.nodes.add(n)
+        self.nodes.add(node)
+        self.recalculate_mem_size()
+
+    def remove_node(self, node):
+        # Remove a node only if the node is in the partition
+        if node in self.nodes:
+            self.nodes.remove(node)
+            # Collect the node's input nodes
+            input_nodes: dict[Node, None] = {}
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
+            # Check if an input node is a placeholder or get_attr,
+            # and this input node is not used by some other nodes in this partition,
+            # the remove this input node
+            for input_node in input_nodes:
+                if all(
+                    n not in self.nodes for n in input_node.users
+                ) and input_node.op in {"placeholder", "get_attr"}:
+                    self.nodes.remove(input_node)
+            self.recalculate_mem_size()
+
+
+class Device(NamedTuple):
+    name: str
+    available_mem_bytes: int
+    logical_id: int
+
+
+class NodeLatency(NamedTuple):
+    # Latency due to the memory bandwidth
+    mem_latency_sec: float
+    # Latency due to the computation
+    computer_latency_sec: float
+
+
+class PartitionLatency(NamedTuple):
+    # Sum of all nodes' memory latency on the critical path
+    mem_latency_sec: float
+    # Sum of all nodes' compute latency on the critical path
+    computer_latency_sec: float
+    # Latency of the critical path
+    overall_latency_sec: float
+
+
+class PartitionMode(Enum):
+    size_based = 0
+    sparse_nn = 1
+    cost_aware = 2
+    kl_based = 3
+    aot_based = 4
+
+
+class PartitionerConfig(NamedTuple):
+    devices: list[Device]
+    mode: PartitionMode = PartitionMode.size_based
+    transfer_rate_bytes_per_sec: float = 0.0
+    node_to_latency_mapping: dict[Node, NodeLatency] = {}
+    node_to_partition_mapping: dict[Node, int] = {}
+    partition_to_logical_device_mapping: dict[int, list[int]] = {}
+    # Saturate host by replicating partitions to the remaining idle devices.
+    saturate_host: bool = False
+
+
+def get_extra_size_of(node: Node, nodes: set[Node]) -> int:
+    """Given a node and a set of nodes,
+    this function return the extra size that needed
+    if this node is included in this set.
+    """
+    # Find all its input nodes
+    input_nodes: dict[Node, None] = {}
+    map_arg(node.args, input_nodes.setdefault)
+    map_arg(node.kwargs, input_nodes.setdefault)
+    # Calculate total size of related nodes
+    total_size_of_input_nodes = 0
+    for n in input_nodes:
+        # Make sure this node hasn't been in this set yet
+        if n not in nodes:
+            size_bytes = getattr(n, "size_bytes", None)
+            if size_bytes:
+                total_size_of_input_nodes += size_bytes.output_size
+            else:
+                raise RuntimeError("node has no size_bytes attr")
+    # Don't forget the op node itself
+    size_bytes = getattr(node, "size_bytes", None)
+    if size_bytes:
+        total_size_of_input_nodes += size_bytes.total_size
+    else:
+        raise RuntimeError("node has no size_bytes attr")
+    return total_size_of_input_nodes
+
+
+def get_latency_of_one_partition(
+    partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency]
+) -> PartitionLatency:
+    """Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
+
+    def get_top_nodes(partition: Partition) -> list[Node]:
+        """Given a partition, return a list of nodes on the top bfs level"""
+        top_nodes: list[Node] = []
+        for node in partition.nodes:
+            # Skip placeholder and get_attr nodes
+            if node.op in {"placeholder", "get_attr"}:
+                continue
+            input_nodes: dict[Node, None] = {}
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
+            # If a node has no input nodes in this partition,
+            # or its input nodes in this partition are placeholders and get_attrs
+            # this node is on the top bfs level in this partition
+            if not any(
+                n in partition.nodes and n.op not in {"placeholder", "get_attr"}
+                for n in input_nodes
+            ):
+                top_nodes.append(node)
+        return top_nodes
+
+    def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
+        """Given a top node of a partition, this function returns
+        the latency of the critical path in the partition
+        """
+        node_latency = node_to_latency_mapping[node]
+        # Calculate the current overall latency of the partition
+        overall_latency_sec = partition_latency.overall_latency_sec + max(
+            node_latency.computer_latency_sec, node_latency.mem_latency_sec
+        )
+        # Update the mem latency of this path
+        mem_latency_sec = (
+            partition_latency.mem_latency_sec + node_latency.mem_latency_sec
+        )
+        # Update the compute latency of this path
+        computer_latency_sec = (
+            partition_latency.computer_latency_sec + node_latency.computer_latency_sec
+        )
+        # Get all users of this node that are in this partition
+        users = set(node.users).intersection(partition.nodes)
+        if users:
+            max_latency = PartitionLatency(
+                mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
+            )
+            for n in users:
+                # Get new partition latency recursively
+                new_partition_latency = dfs_helper(
+                    n,
+                    PartitionLatency(
+                        mem_latency_sec, computer_latency_sec, overall_latency_sec
+                    ),
+                )
+                if (
+                    new_partition_latency.overall_latency_sec
+                    > max_latency.overall_latency_sec
+                ):
+                    max_latency = new_partition_latency
+            return max_latency
+        # If there is no user, the node is at bottom of the partition
+        return PartitionLatency(
+            mem_latency_sec, computer_latency_sec, overall_latency_sec
+        )
+
+    # Main part starts
+    # Get all top level nodes of this partition
+    top_nodes = get_top_nodes(partition)
+    critical_path_latency = PartitionLatency(
+        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
+    )
+    # Go through all top nodes and find the largest latency (critical pass latency)
+    for node in top_nodes:
+        partition_latency = dfs_helper(
+            node,
+            PartitionLatency(
+                mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
+            ),
+        )
+        if (
+            partition_latency.overall_latency_sec
+            > critical_path_latency.overall_latency_sec
+        ):
+            critical_path_latency = partition_latency
+    return critical_path_latency
+
+
+def get_partition_to_latency_mapping(
+    partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency]
+) -> dict[Partition, PartitionLatency]:
+    """Given all the partitions and node_to_latency_mapping dictionary,
+    return a mapping dictionary of each partition to its overall latency
+    """
+    partition_to_latency_mapping: dict[Partition, PartitionLatency] = {}
+    # Go through each partition and get its latency
+    for partition in partitions:
+        partition_latency = get_latency_of_one_partition(
+            partition, node_to_latency_mapping
+        )
+        partition_to_latency_mapping[partition] = partition_latency
+    return partition_to_latency_mapping
+
+
+def get_comm_latency_between(
+    parent_partition: Partition,
+    child_partition: Partition,
+    transfer_rate_bytes_per_sec: float,
+):
+    """Given two partitions (parent and child),
+    calculate the communication latency between the two.
+    """
+    # If two partitions are on the same device, the comm latency is 0.
+    if (
+        parent_partition.logical_device_ids != []
+        and child_partition.logical_device_ids != []
+        and parent_partition.logical_device_ids == child_partition.logical_device_ids
+    ):
+        return 0.0
+    # Keep tracking the communication size between parent and child
+    comm_size = 0
+    # Keep tracking all the counted node
+    visited_nodes = set()
+    # Go through all nodes in the child partition
+    # If a node has input nodes from the parent partition,
+    # the output size of those input nodes will be counted
+    # and added to comm_size
+    for node in child_partition.nodes:
+        input_nodes: dict[Node, None] = {}
+        map_arg(node.args, input_nodes.setdefault)
+        map_arg(node.kwargs, input_nodes.setdefault)
+        for n in input_nodes:
+            if n in parent_partition.nodes and n not in visited_nodes:
+                size_bytes = getattr(n, "size_bytes", None)
+                if size_bytes is not None:
+                    comm_size += size_bytes.output_size
+                visited_nodes.add(n)
+    return comm_size / transfer_rate_bytes_per_sec
+
+
+def get_latency_of_partitioned_graph(
+    partitions: list[Partition],
+    partition_to_latency_mapping: dict[Partition, PartitionLatency],
+    transfer_rate_bytes_per_sec: float,
+):
+    """Given all partitions in a graph, find the critical path among all partitions
+    and return its latency as the latency of the whole graph
+    """
+
+    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
+        """This function helps to recursively get the latency of a path of partitions"""
+        # Update latency by adding current partition's latency
+        latency_so_far_sec += partition_to_latency_mapping[
+            partition
+        ].overall_latency_sec
+
+        if partition.children:
+            max_latency_sec = 0.0
+            for child in partition.children:
+                # Calculate latency between
+                comm_latency_sec = get_comm_latency_between(
+                    partition, child, transfer_rate_bytes_per_sec
+                )
+                new_latency_sec = dfs_helper(
+                    child, latency_so_far_sec + comm_latency_sec
+                )
+                if new_latency_sec > max_latency_sec:
+                    max_latency_sec = new_latency_sec
+            return max_latency_sec
+        return latency_so_far_sec
+
+    def get_top_partitions(partitions: list[Partition]) -> list[Partition]:
+        """This function is to return all the partitions without parents
+        as the starting points of all the paths
+        """
+        # If a partition has no parents, then it is a top partition
+        top_partitions = [
+            partition for partition in partitions if len(partition.parents) == 0
+        ]
+        return top_partitions
+
+    top_partitions = get_top_partitions(partitions)
+    critical_path_latency_sec = 0.0
+    for partition in top_partitions:
+        latency_sec = dfs_helper(partition, 0.0)
+        if latency_sec > critical_path_latency_sec:
+            critical_path_latency_sec = latency_sec
+    return critical_path_latency_sec
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f763ad2ee2cfc1e3bd500f1da9877144aca1a3b2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py
@@ -0,0 +1,2817 @@
+# mypy: allow-untyped-decorators
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import functools
+import inspect
+import logging
+import operator
+import threading
+import typing
+import typing_extensions
+import weakref
+from collections import defaultdict, OrderedDict
+from collections.abc import Callable, Generator, Mapping, Sequence
+from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
+from dataclasses import dataclass
+from typing import (
+    Any,
+    Concatenate,
+    Optional,
+    overload,
+    Protocol,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
+from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack
+from weakref import WeakKeyDictionary
+
+import torch
+import torch._ops
+import torch.fx as fx
+import torch.fx.traceback as fx_traceback
+import torch.utils._pytree as pytree
+from torch import SymBool, SymInt, Tensor
+from torch._dispatch.python import enable_python_dispatcher
+from torch._library.fake_class_registry import FakeScriptObject
+from torch._library.opaque_object import is_opaque_type
+from torch._logging import trace_structured
+from torch._ops import HigherOrderOperator
+from torch._subclasses.fake_impls import fast_detach
+from torch._subclasses.fake_tensor import (
+    FakeTensor,
+    FakeTensorMode,
+    is_fake,
+    unset_fake_temporarily,
+)
+from torch._subclasses.meta_utils import is_sparse_any
+from torch.fx import GraphModule, Proxy, Tracer
+from torch.fx.graph_module import _assign_attr
+from torch.fx.node import (
+    _side_effectful_need_to_be_preserved_pre_dispatch,
+    Argument,
+    Target,
+)
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.nn import Module
+from torch.overrides import TorchFunctionMode
+from torch.utils._python_dispatch import (
+    _disable_infra_mode,
+    _push_mode,
+    _unset_infra_mode,
+    autograd_would_have_decomposed,
+    TorchDispatchMode,
+)
+from torch.utils._stats import count
+from torch.utils._thunk import Thunk
+from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary
+
+from ._backward_state import BackwardState
+from .sym_node import SymNode
+
+
+if TYPE_CHECKING:
+    import types
+    from collections.abc import MutableMapping
+
+    import sympy
+
+    from torch._ops import OpOverload
+    from torch.fx._symbolic_trace import PHBase
+    from torch.types import BoolLikeType, FloatLikeType, IntLikeType
+
+__all__ = [
+    "PythonKeyTracer",
+    "dispatch_trace",
+    "make_fx",
+    "DecompositionInterpreter",
+    "selective_decompose",
+    "py_sym_types",
+    "get_innermost_proxy_mode",
+    "get_proxy_mode",
+    "handle_sym_dispatch",
+    "maybe_enable_thunkify",
+    "maybe_disable_thunkify",
+]
+
+_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
+
+_AnyScriptObject = (torch.ScriptObject, FakeScriptObject)
+_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject]
+
+aten = torch.ops.aten
+prim = torch.ops.prim
+
+log = logging.getLogger(__name__)
+not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
+
+CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {}
+
+CONSTANT_NUMEL_LIMIT = 1
+
+T = TypeVar("T")
+U = TypeVar("U")
+_P = ParamSpec("_P")
+R = TypeVar("R")
+_Ts = TypeVarTuple("_Ts")
+
+null_ctx_type = type(nullcontext)
+# We currently convert all SymInt to proxies before we use them.
+# This could plausibly be handled at the Dynamo level.
+pytree.register_pytree_node(
+    torch.Size,
+    lambda xs: (list(xs), None),
+    lambda xs, _: tuple(xs),
+    # pyrefly: ignore [bad-argument-type]
+    flatten_with_keys_fn=lambda xs: (
+        [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
+        None,
+    ),
+    serialized_type_name="torch.Size",
+)
+# Ideally unflattening should not lose info, but we unflatten
+# torch.Size to tuple (see above). This is necessary because the
+# torch.Size constructor only accepts ints whereas our infra often
+# transforms them to non-ints, e.g. symint proxies. Anyway, losing
+# such info can cause pytree mapping or spec matching to fail, so
+# work around this problem using the following dict as needed.
+_pytree_subclasses_that_lose_info = {torch.Size: tuple}
+
+
+def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]:
+    """FX gets confused by varargs, de-confuse it"""
+    argnames = ",".join(f"arg{i}" for i in range(nargs))
+    return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
+
+
+@contextmanager
+def decompose(
+    decomposition_table: Optional[Mapping[OpOverload, Callable]],
+) -> Generator[Mapping[OpOverload, Callable], None, None]:
+    global CURRENT_DECOMPOSITION_TABLE
+    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
+    CURRENT_DECOMPOSITION_TABLE = decomposition_table or {}
+    try:
+        yield CURRENT_DECOMPOSITION_TABLE
+    finally:
+        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
+
+
+# ensure we cannot collide with other properties
+proxy_slot = object()
+
+
+class _NoDefault:
+    pass
+
+
+no_default = _NoDefault()
+
+from torch.types import py_sym_types, PySymType
+
+
+class _HasMeta(Protocol):
+    meta: dict[str, PySymType]
+
+
+def is_sym_node(node: _HasMeta) -> bool:
+    assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta"
+    return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)
+
+
+@overload  # type: ignore[no-overload-impl]
+def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
+
+
+@overload
+def set_proxy_slot(
+    obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy
+) -> None: ...
+
+
+@overload
+def set_proxy_slot(
+    obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType
+) -> None: ...
+
+
+class _DisableUpdateTensorTracker(threading.local):
+    value: bool = False
+
+
+_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker()
+
+
+_FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT: dict[int, torch.fx.Node] = {}
+
+
+def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool:
+    """
+    Returns current state of disabling update tensor tracker.
+    """
+    return _disable_update_tensor_tracker_tls.value
+
+
+@contextmanager
+def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]:
+    """
+    NOTE "Do not clobber inplace ops"
+    By default tensor_tracker is updated every time.
+    This leads to chaining every operation by the FakeTensor.
+    For example for mutable ops if we have several consecutive mutable operations:
+
+    def f(x, y, z):
+        x.copy_(y)
+        x.copy_(z)
+        return x
+
+    Default graph result:
+    def f_graph(x, y, z)
+        x_1 = x.copy_(y)
+        x_2 = x_1.copy_(z)
+        return x_2
+
+    This chaining simplifies the fx passes and helps to prevent the reordering.
+    But in some cases, we want those nodes to be disconnected.
+    E.g. in case of splitting joint graph into forward and backward.
+    If first inplace op happened in forward, second in backward,
+    we want them after split to be properly placed.
+
+    Enabling this context manager for copy_ will result in:
+    def f_graph_2(x, y, z):
+        x_1 = x.copy_(y)
+        x_2 = x.copy_(z)
+        return x
+
+    Results of copy_ x1 and x2 will have empty users in the graph.
+    The reason why this behavior is not enabled for all inplace ops is that
+    some fx passes (e.g. fx quantization) rely on chaining inplace ops like add_
+    in their fusions passes.
+    We could revisit enabling this logic for all inplace ops in future.
+    """
+    orig_value = _disable_update_tensor_tracker_tls.value
+    _disable_update_tensor_tracker_tls.value = True
+    try:
+        yield
+    finally:
+        _disable_update_tensor_tracker_tls.value = orig_value
+
+
+def set_proxy_slot(  # type: ignore[no-redef]
+    obj: Union[PySymType, _AnyScriptObjectType, Tensor],
+    tracer: _ProxyTracer,
+    proxy: object,
+) -> None:
+    log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy)
+    if isinstance(obj, Tensor):
+        # We DO want to clobber proxies whenever we run an inplace operation
+        # on a tensor, and it affects the metadata on the proxy.
+        assert isinstance(proxy, _ProxyTensor)
+        # see NOTE [Do not clobber inplace ops]
+        if not _is_proxy_tensor_update_tensor_tracker_disabled():
+            tracer.tensor_tracker[obj] = proxy
+    elif isinstance(obj, (_AnyScriptObject)):
+        # We DO want to clobber proxies, with a similar rationale as for tensors.
+        assert isinstance(proxy, Proxy)
+        tracer.script_object_tracker[obj] = proxy
+    else:
+        # NB: Never clobber pre-existing proxy.  Although the proxies
+        # are in principle equivalent, when we do graph partitioning
+        # we need there not to be spurious dependencies on tangent inputs.
+        # This works because primals get their SymInts set first, and
+        # THEN later we allocate tangent inputs.  Make sure if a SymInt
+        # is derivable from a primal that we use that.
+        assert isinstance(obj, py_sym_types), type(obj)
+        if obj not in tracer.symnode_tracker:
+            proxy = typing.cast(_PySymProxyType, proxy)
+            tracer.symnode_tracker[obj] = proxy
+
+            # WAR: python test/dynamo/test_subclasses.py
+            # TestNestedTensor.test_basic_autograd
+            #
+            # AOTAutograd doesn't pass the "outer sizes" as an actual argument
+            # to make_fx, but it is made use of internally in AOTAutograd's
+            # call to tensor unflatten.  Because the outer sizes isn't passed
+            # as an argument, it is therefore untracked.  However, it turns
+            # out you luck out, because *Dynamo* will manually add the outer
+            # sizes as an argument so you can fix up the proxy'ness.
+            #
+            # This is probably fixed in
+            # https://github.com/pytorch/pytorch/pull/125941/
+            import sympy
+
+            if isinstance(obj.node.expr, sympy.Symbol):
+                tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(
+                    proxy, obj
+                )
+
+
+def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
+    assert isinstance(obj, (Tensor, SymNode)), type(obj)
+    # pyrefly: ignore [no-matching-overload]
+    return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
+
+
+_PySymProxyType = Thunk[Proxy]
+
+
+@overload
+def get_proxy_slot(
+    obj: Tensor,
+    tracer: _ProxyTracer,
+) -> _ProxyTensor: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: Tensor,
+    tracer: _ProxyTracer,
+    default: U,
+) -> Union[_ProxyTensor, U]: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: Tensor,
+    tracer: _ProxyTracer,
+    default: U,
+    transform: Callable[[_ProxyTensor], R],
+) -> Union[R, U]: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: _AnyScriptObjectType,
+    tracer: _ProxyTracer,
+) -> Proxy: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: _AnyScriptObjectType,
+    tracer: _ProxyTracer,
+    default: U,
+) -> Union[Proxy, U]: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: _AnyScriptObjectType,
+    tracer: _ProxyTracer,
+    default: U,
+    transform: Callable[[Proxy], R],
+) -> Union[R, U]: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: PySymType,
+    tracer: _ProxyTracer,
+) -> _PySymProxyType: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: PySymType,
+    tracer: _ProxyTracer,
+    default: T,
+) -> Union[T, _PySymProxyType]: ...
+
+
+@overload
+def get_proxy_slot(
+    obj: PySymType,
+    tracer: _ProxyTracer,
+    default: U,
+    transform: Callable[[_PySymProxyType], R],
+) -> Union[R, U]: ...
+
+
+# the default argument is what to return if the slot is not set.
+# the transform argument is handy if you need to extract a subfield from
+# the successfully looked up result (but NOT the default.)
+def get_proxy_slot(
+    obj: Union[Tensor, _AnyScriptObjectType, PySymType],
+    tracer: _ProxyTracer,
+    default: object = no_default,
+    transform: Callable = lambda x: x,
+) -> object:
+    tracker: Any
+    if isinstance(obj, Tensor):
+        tracker = tracer.tensor_tracker
+    elif isinstance(obj, _AnyScriptObject):
+        tracker = tracer.script_object_tracker
+    else:
+        assert isinstance(obj, py_sym_types), type(obj)
+        tracker = tracer.symnode_tracker
+
+    # pyrefly: ignore [index-error]
+    # pyrefly: ignore [no-matching-overload, bad-argument-type]
+    value = tracker.get(obj)
+
+    if value is None and isinstance(obj, py_sym_types):
+        if obj.node.is_symbolic():
+            # Last ditch - we found a SymInt (SymBool, etc) we don't know
+            # about.
+            if (tmp := tracer.sympy_expr_tracker.get(obj.node.expr)) is not None:
+                value = tmp.proxy
+
+            else:
+                # Attempt to build it from first principles.
+                _build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
+                # pyrefly: ignore [no-matching-overload]
+                value = tracker.get(obj)
+
+    if value is None:
+        # We don't know this value - return the default.
+        if isinstance(default, _NoDefault):
+            raise RuntimeError(
+                f"{obj} ({type(obj)}, {id(obj)})is not tracked with proxy for {tracer}"
+            )
+        return default
+
+    res = transform(value)
+    return res
+
+
+@functools.cache
+def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:
+    """
+    Returns a dict converting sympy functions to python operators
+    (i.e. `sympy.Mul` -> `operator.mul`)
+    """
+    import torch.utils._sympy.interp
+
+    handlers = {}
+    for k, v in torch.utils._sympy.interp.handlers().items():
+        op = getattr(operator, v, None)
+        if op is not None:
+            handlers[k] = op
+    return handlers
+
+
+def _build_proxy_for_sym_expr(
+    tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
+) -> IntLikeType | FloatLikeType | BoolLikeType | None:
+    """
+    Decompose `expr` and look for the pieces as inputs. If `out` is provided
+    then that will be the resulting SymNode (and `out.expr` must be the same as
+    `expr`).
+
+    This function is used when the ProxyTorchDispatchMode sees a SymNode
+    that it hasn't seen before to try to associate it with traced inputs.
+
+    How can this happen?
+
+    First thing to remember is that although sympy.Exprs are interned (so
+    `sympy.Expr("s3*s4")` will always have the same `id` and will always compare
+    equal) SymNode does not (so doing `SymNode("s3")*SymNode("s4")` twice in a
+    row will give two unique SymNodes).
+
+    - On way for this to happen is if we turn off tracing to compute an
+      intermediate value and then USE that value with tracing turned on - for
+      example if we turn off tracing to do some FakeTensor propagation to
+      compute a size (dtensor does this) but then turn tracing back on and use
+      that computed size.
+
+    - Another way is if we compute a size in one graph and stash it somewhere
+      hidden (such as in some meta-data) and later use it in a different graph
+      (dtensor does this too). Since the size was computed in the first graph
+      and it's not an official input to the second graph it's not tracked
+      properly. This is often going to show up as it usually works in fullgraph
+      but a graph break causes a failure.
+
+    To handle this we decompose the sympy.Expr and look for the pieces as
+    inputs. But there are problems with this approach:
+
+    - We lose operation provanance: We end up figuring out where to get the
+      inputs - but those may not actually be correct. If we have "s1" coming in
+      from both tensor1 and tensor2 and we pick the wrong one we could end up
+      keeping a tensor alive longer than intended.
+
+    - There's no guarantee that those values are inputs to the graph: If we have
+      "s1*s2" computed in a graph #1 and used in graph #2 there's no guarantee
+      that the input that holds "s1" is actually an input on graph #2.
+
+    - The decomposition isn't guaranteed to be the same: Sympy can "simplify"
+      expressions so it's possible that our inputs are "s1*s2" and "s3" but we
+      decompose it into "s1" and "s2*s3" - which wouldn't be found.
+
+    Other ways we could handle this:
+
+    - Don't: Just require that all inputs are tracked properly. This is the
+      "correct" solution but harder because you need to track down each
+      potential problem one by one and fix them. And when it fails it's a lot of
+      work to figure out both why it's failing and the right way to fix it. This
+      is complicated by the fact that a stashed value could be incorrect but
+      work fine until we happen to get an graph break in the wrong place - so it
+      may be a while before the bug is found. (Maybe we need a "dynamo abuse
+      mode" where we run tests with as many graph breaks inserted as possible?)
+
+    - Track SymNode ops separately from proxy tracing: Right now SymNode
+      operations are tracked as part of the proxy tracing - so when we disable
+      proxy tracing we also disable SymNode tracing. But we don't have to do
+      that - we could instead always have SymNodes track where they came from
+      and just use that when needed. This solves the problem of tracing being
+      temporarily turned off but doesn't help if an input isn't present after a
+      graph break.
+
+    - Better decomposition: Right now the decomposition is pretty simple. We do
+      have a sat-solver available to us so we could theoretically do a better
+      job figuring out a "correct" decomposition. But that still relies on
+      having the inputs available at all - which isn't a guarantee.
+    """
+
+    if (value := tracer.sympy_expr_tracker.get(expr)) is not None:
+        assert not out
+        return value.value
+
+    if isinstance(expr, (int, float, bool)):
+        return expr
+    if expr.is_Integer:
+        return int(expr)
+    if expr.is_Float:
+        return float(expr)
+
+    args = []
+    for arg in expr.args:
+        if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None:
+            return None
+        args.append(arg_value)
+    args = tuple(args)
+
+    func: OpOverload | None = _sympy_handlers().get(expr.func)  # type: ignore[assignment]
+    if not func:
+        # Handler not found
+        return None
+
+    if out is None:
+        out = func(*args)
+    else:
+        _sym_register(tracer, func, args, out)
+    return out
+
+
+def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]:
+    # val.detach() will also eventually call fast_detach(),
+    # but this saves us a full trip into __torch_dispatch__
+    # (snapshot_fake is called a lot)
+    if isinstance(val, FakeTensor):
+        return fast_detach(val.fake_mode, val, include_real)
+    else:
+        return val.detach()
+
+
+_ExtractValType = Optional[
+    Union[
+        PySymType,
+        _AnyScriptObjectType,
+        BackwardState,
+        list["_ExtractValType"],
+        tuple["_ExtractValType", ...],
+        dict[str, "_ExtractValType"],
+        Tensor,
+        int,
+        float,
+        bool,
+    ]
+]
+
+
+def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType:
+    if is_fake(val):
+        return snapshot_fake(val, include_real=include_real)
+    elif isinstance(val, py_sym_types):
+        return val
+    elif isinstance(val, _AnyScriptObject):
+        return val
+    elif isinstance(val, BackwardState):
+        return val
+    elif isinstance(val, (list, tuple)):
+        return val.__class__([extract_val(x) for x in val])
+    elif isinstance(val, dict):
+        return {k: extract_val(v) for k, v in val.items()}
+    elif isinstance(val, Tensor):
+        if not val.is_sparse:
+            # NB: Kinda hacky, but we should try to get val as the metadata
+            # everywhere
+            # TODO: This doesn't properly track storages.  A more robust
+            # approach would be to maintain a per-trace FakeTensorMode and
+            # from_real_tensor to create fake values (don't forget to
+            # snapshot_fake)
+            from torch._guards import detect_fake_mode
+
+            fake_tensor_mode = detect_fake_mode(val)
+            if not fake_tensor_mode:
+                fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
+            with fake_tensor_mode:
+                return torch.empty_strided(
+                    val.shape, val.stride(), device=val.device, dtype=val.dtype
+                )
+        else:
+            return None
+    elif isinstance(val, (int, float, bool)):
+        return val
+    elif val is None:
+        return None
+
+    typing_extensions.assert_never(val)
+
+
+@contextmanager
+def _enable_thunkify(
+    tracer: _ProxyTracer, *, enable: bool = True
+) -> Generator[None, None, None]:
+    """
+    Enable thunkification inside the context manager.  Thunkification prevents
+    SymNode computation from directly being traced into an FX graph; instead,
+    the compute is only added to the graph if it is actually used.  This helps
+    us track SymNode compute when it is computed (since we need /something/
+    to put in the tracker) even if it is unlikely to be used.
+    """
+    old = tracer.enable_thunkify
+    tracer.enable_thunkify = enable
+    try:
+        yield
+    finally:
+        tracer.enable_thunkify = old
+
+
+@contextmanager
+def maybe_disable_thunkify() -> Generator[None, None, None]:
+    """Within a context, disable thunkification.  See :func:`maybe_enable_thunkify`
+    for more details.  This is helpful if you have a wrapper function which
+    you want to enable thunkification on, but in some segment on the inside (say,
+    the original user function), you want to disable thunkification as you know
+    it is not needed there.
+    """
+    proxy_mode = get_proxy_mode()
+    if proxy_mode is not None:
+        with _enable_thunkify(proxy_mode.tracer, enable=False):
+            yield
+    else:
+        yield
+
+
+@contextmanager
+def maybe_enable_thunkify() -> Generator[None, None, None]:
+    """Within this context manager, if you are doing make_fx tracing, we will thunkify
+    all SymNode compute and avoid tracing it into the graph unless it is actually needed.
+    You should prefer to avoid using this as much as possible, as lazy evaluation of
+    SymNode tracing can lead to long chains of thunks which will stack overflow
+    if you evaluate them.  However, this is currently sometimes necessary as there
+    are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error
+    due to insufficient tracing of SymNode computation.
+    """
+    proxy_mode = get_proxy_mode()
+    if proxy_mode is not None:
+        with _enable_thunkify(proxy_mode.tracer):
+            yield
+    else:
+        yield
+
+
+# Note [invariants for node meta 'val']
+# What invariants do we have for the 'val' set on the FX node?  It has accurate
+# metadata... but only for metadata that exists "below" all other subsystems
+# (most notably autograd, but also vmap, functorch transforms, etc).  This means
+# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad,
+# grad_fn, _base (_base actually may be set due to recursive call to
+# ADInplaceOrView, but you shouldn't rely on it.)
+def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
+    proxy.node.meta["val"] = extract_val(
+        val, include_real=(proxy.node.op == "placeholder")
+    )
+
+    with _enable_thunkify(proxy.tracer):  # type: ignore[arg-type]
+        # Best effort tensor_meta setting; prefer using val!
+        if is_fake(val):
+            proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
+        elif isinstance(val, Tensor) and not val.is_sparse:
+            proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
+    return proxy
+
+
+def thunkify(
+    tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs
+) -> Thunk[R]:
+    """
+    Delays computation of f until it's called again
+    Also caches the result
+    """
+    if tracer.enable_thunkify:
+        return Thunk(functools.partial(f, *args, **kwargs))
+    else:
+        r = f(*args, **kwargs)
+        return Thunk(lambda: r)
+
+
+def track_tensor(
+    tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer
+) -> None:
+    def try_set_proxy_slot(
+        outer_s: IntLikeType,
+        proxy_callable: Callable[Concatenate[PySymType, _P], Proxy],
+        *args: _P.args,
+        **kwargs: _P.kwargs,
+    ) -> None:
+        assert callable(proxy_callable)
+        if isinstance(outer_s, SymInt):
+            with _enable_thunkify(tracer):
+                set_proxy_slot(
+                    outer_s,
+                    tracer,
+                    thunkify(tracer, proxy_callable, outer_s, *args, **kwargs),
+                )
+
+    # The basic idea is that we need to associate each tensor/SymInt
+    # with a Proxy.  How do we setup this association?  We just store
+    # the proxy on the proxy slot of the object, keyed on the tracer
+    # (so that if we have multiple tracers at the same time, they
+    # don't clobber each other.)
+    for i, s in enumerate(tensor.shape):
+        try_set_proxy_slot(
+            s,
+            lambda x, i: set_meta(
+                tracer.create_proxy(
+                    "call_function", torch.ops.aten.sym_size.int, (proxy, i), {}
+                ),
+                x,
+            ),
+            i,
+        )
+
+    if not is_sparse_any(tensor):
+        for i, s in enumerate(tensor.stride()):
+            try_set_proxy_slot(
+                s,
+                lambda x, i: set_meta(
+                    tracer.create_proxy(
+                        "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {}
+                    ),
+                    x,
+                ),
+                i,
+            )
+
+    try_set_proxy_slot(
+        tensor.numel(),
+        lambda x: set_meta(
+            tracer.create_proxy(
+                "call_function", torch.ops.aten.sym_numel.default, (proxy,), {}
+            ),
+            x,
+        ),
+    )
+    if not is_sparse_any(tensor):
+        try_set_proxy_slot(
+            tensor.storage_offset(),
+            lambda x: set_meta(
+                tracer.create_proxy(
+                    "call_function",
+                    torch.ops.aten.sym_storage_offset.default,
+                    (proxy,),
+                    {},
+                ),
+                x,
+            ),
+        )
+    set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
+
+
+_NestedProxys = Union[
+    Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"]
+]
+_NestedTensors = Union[
+    Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"]
+]
+
+
+def track_tensor_tree(
+    inner_res: T,
+    proxy_res: _NestedProxys,
+    *,
+    constant: Optional[_NestedTensors],
+    tracer: _ProxyTracer,
+) -> T:
+    # NB: We call set_unbacked_bindings only on the *topmost* call to
+    # track_tensor_tree, not recursive calls.  This is because there must
+    # be only ONE unbacked_binding proxy call, and it should be the one
+    # where all of the unbacked SymInts actually first come into existence.
+    # If you call this again on the inner proxies for the tuple projections,
+    # you will have multiple unbacked_bindings for the same symbol, but
+    # they're not going to show up anywhere.
+    #
+    # I was briefly deceived into setting unbacked bindings recursively when
+    # working on https://github.com/pytorch/pytorch/pull/133585 because I
+    # observed that some extra unbacked bindings were needed to handle some
+    # higher order operator code.  But actually it looks like this was
+    # just an unrelated bug that needed to be fixed separately.
+    _set_unbacked_bindings(inner_res, proxy_res)
+
+    def wrap_with_proxy(
+        e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
+    ) -> None:
+        if isinstance(e, Tensor):
+            assert isinstance(proxy, Proxy)
+            assert constant is None or isinstance(constant, Tensor)
+            track_tensor(e, proxy, tracer=tracer, constant=constant)
+            set_meta(proxy, e)
+        elif isinstance(e, py_sym_types):
+            assert isinstance(proxy, Proxy)
+            # NB: eagerly set meta here, so that the numbering is in order
+            set_meta(proxy, e)
+            set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
+        elif isinstance(e, _AnyScriptObject):
+            assert isinstance(proxy, Proxy)
+            set_proxy_slot(e, tracer, proxy)
+            set_meta(proxy, e)
+        elif isinstance(e, (tuple, list)):
+            # example use case: allreduce_ returns ([tensor], work)
+            if isinstance(proxy, fx.Proxy):
+                set_meta(proxy, e)
+
+            def get_constant(
+                c: Optional[_NestedTensors], idx: int
+            ) -> Optional[_NestedTensors]:
+                if c is None:
+                    return None
+                else:
+                    assert isinstance(c, (list, tuple))
+                    return c[idx]
+
+            for idx, ee in enumerate(e):
+                # Use an indexer here - if proxy is a List then it will unwrap
+                # it. If it's a Proxy then it will proxy the getelem.
+                wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx))  # type: ignore[index]
+
+        elif isinstance(e, dict):
+            # example use case: triton_kernel_wrapper takes arguments as kwargs
+
+            # In theory we could support const-prop when proxy-tensor-tracing
+            # operators that returns dicts of tensors, but we have no use case
+            # for it today (since the only op we currently trace that can
+            # return a dict is triton_kernel_wrapper_functional/mutation,
+            # which does not participate in const-prop)
+            assert constant is None
+
+            if isinstance(proxy, fx.Proxy):
+                set_meta(proxy, e)
+
+            for key, val in e.items():
+                wrap_with_proxy(val, proxy[key], None)  # type: ignore[index]
+
+        elif isinstance(e, BackwardState):
+            assert isinstance(proxy, Proxy)
+            set_meta(proxy, e)
+            e.proxy = proxy
+        else:
+            # intentionally pass on primitives
+            pass
+
+    wrap_with_proxy(inner_res, proxy_res, constant)
+
+    return inner_res
+
+
+@dataclass
+class _ProxyTensor:
+    proxy: Proxy
+    constant: Optional[Tensor]
+
+
+def fetch_sym_proxy(
+    tracer: _ProxyTracer,
+) -> Callable[[PySymType], Union[bool, int, float, Proxy]]:
+    def inner(e: PySymType) -> Union[int, bool, float, Proxy]:
+        n = e.node
+        if n.constant is not None:
+            return n.constant
+        if e.node.expr.is_number:
+            if isinstance(e, SymBool):
+                return bool(e.node.expr)
+            elif isinstance(e, SymInt):
+                return int(e.node.expr)
+            return float(e.node.expr)
+        else:
+            assert isinstance(e, py_sym_types)
+            # NB: we REQUIRE all symints to be tracked
+            return get_proxy_slot(e, tracer).force()
+
+    return inner
+
+
+@overload
+def fetch_object_proxy(
+    tracer: _ProxyTracer, t: Tensor
+) -> Union[_ProxyTensor, Tensor]: ...
+
+
+@overload
+def fetch_object_proxy(
+    tracer: _ProxyTracer, t: _AnyScriptObjectType
+) -> Union[Proxy, _AnyScriptObjectType]: ...
+
+
+@overload
+def fetch_object_proxy(
+    tracer: _ProxyTracer, t: PySymType
+) -> Union[_PySymProxyType, PySymType]: ...
+
+
+def fetch_object_proxy(
+    tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
+) -> object:
+    return get_proxy_slot(t, tracer, t)
+
+
+HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor)
+
+
+def _maybe_record_pointwise_barrier(
+    func: object, proxy_mode: ProxyTorchDispatchMode
+) -> None:
+    """
+    Records operators whose tensor outputs or inputs are fp16/bf16 so downstream pointwise code can
+    emulate eager's rounding behavior when emulate_precision_casts is enabled.
+    """
+    if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts:
+        return
+
+    if not isinstance(func, torch._ops.OpOverload):
+        return
+
+    last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes)))
+    t = last_node.meta.get("val")
+    low_pr_fp = (torch.bfloat16, torch.float16)
+
+    output_low_precision = isinstance(t, torch.Tensor) and t.dtype in low_pr_fp
+
+    if not output_low_precision:
+        for input_node in last_node.all_input_nodes:
+            val = input_node.meta.get("val") if hasattr(input_node, "meta") else None
+            if isinstance(val, torch.Tensor) and val.dtype in low_pr_fp:
+                output_low_precision = True
+                break
+
+    if not output_low_precision:
+        return
+
+    last_node.meta["low_precision_pointwise_barrier"] = True
+
+
+def _fetch_proxies_and_all_constant_flag(
+    flat_args_kwargs: Union[list[object], tuple[object, ...]], tracer: _ProxyTracer
+) -> tuple[list[object], tuple[object, ...], bool]:
+    """
+    Given flat arguments, fetch the proxies and whether they are all constants.
+    This is later used in proxy_call or when someone is trying to stitch together
+    graph node in tf or td modes.
+    """
+    f_flat_args_kwargs = [
+        (
+            fetch_object_proxy(tracer, x)
+            if isinstance(x, (Tensor, _AnyScriptObject))
+            else x
+        )
+        for x in flat_args_kwargs
+    ]
+
+    # If there are SymInts, we also should not consider this constant.
+    # However, fake tensor handling of SymInts is sufficiently broken that
+    # I couldn't write a test for this case
+    all_constant = (
+        not any(
+            t.constant is None
+            for t in f_flat_args_kwargs
+            if isinstance(t, _ProxyTensor)
+        )
+        # TODO: maybe constant SymInts should also be allowed?  Not sure if
+        # this can happen
+        and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs)
+    )
+
+    proxy_flat_args_kwargs = [
+        e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs
+    ]
+
+    proxy_flat_args_kwargs = [
+        (fetch_sym_proxy(tracer)(e) if isinstance(e, py_sym_types) else e)
+        for e in proxy_flat_args_kwargs
+    ]
+
+    return f_flat_args_kwargs, tuple(proxy_flat_args_kwargs), all_constant
+
+
+def proxy_call(
+    proxy_mode: ProxyTorchDispatchMode,
+    func: OpOverload,
+    pre_dispatch: bool,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+) -> object:
+    unrecognized_types: list[type] = []
+    flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs))
+
+    def can_handle_tensor(x: Tensor) -> bool:
+        r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
+        if proxy_mode._allow_fake_constant:
+            r = r or type(x) is torch._subclasses.FakeTensor
+        if not r:
+            unrecognized_types.append(type(x))
+        return r
+
+    # If there are any tensor subclasses, we need to handle those tensor subclasses first
+    # TODO: we could use types to test this
+    if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)):
+        not_implemented_log.debug(
+            "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s",
+            unrecognized_types,
+        )
+        return NotImplemented
+
+    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
+    if r is not NotImplemented:
+        _maybe_record_pointwise_barrier(func, proxy_mode)
+        return r
+
+    # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
+    if (
+        not pre_dispatch
+        and func
+        not in [
+            torch.ops.aten.size.default,
+            torch.ops.aten.stride.default,
+            torch.ops.aten.storage_offset.default,
+        ]
+        and autograd_would_have_decomposed(func, flat_args_kwargs)
+    ):
+        with proxy_mode:
+            r = func.decompose(*args, **kwargs)
+            if r is not NotImplemented:
+                return r
+
+    if func is torch.ops.aten.is_nonzero.default:
+        with proxy_mode:
+            torch._check(
+                args[0].numel() == 1,  # type: ignore[attr-defined]
+                lambda: "Boolean value of Tensor with more than one value is ambiguous",
+            )
+            return (args[0] != 0).item()  # type: ignore[attr-defined]
+
+    tracer = proxy_mode.tracer
+    f_flat_args_kwargs, proxy_flat_args_kwargs, all_constant = (
+        _fetch_proxies_and_all_constant_flag(flat_args_kwargs, tracer)
+    )
+
+    if torch.Tag.data_dependent_output in func.tags:
+        # Check if all of the Tensor inputs are constants
+        if all_constant:
+            const_flat_args_kwargs = [
+                t.constant if isinstance(t, _ProxyTensor) else t
+                for t in f_flat_args_kwargs
+            ]
+            const_args, const_kwargs = pytree.tree_unflatten(
+                const_flat_args_kwargs, spec
+            )
+            with unset_fake_temporarily():
+                return func(*const_args, **const_kwargs)
+        # If any of the Tensor inputs are "real" (not FakeTensor), we may
+        # incorrectly burn in constants by allowing this access.  Raise
+        # an error in this case
+        if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only(
+            Tensor, lambda t: not is_fake(t), (args, kwargs)
+        ):
+            raise RuntimeError(
+                f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
+                "It's likely that this is caused by data-dependent control flow or similar.  "
+                "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' "
+                "in your make_fx call."
+            )
+
+    proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec)
+
+    # When we trace through a torch.tensor invocation, you never actually
+    # see a torch.ops.aten.tensor call. Instead, the way this function is
+    # implemented internally is that we allocate a plain tensor (this is
+    # *guaranteed* to be a plain tensor, we disable all modes when doing
+    # so), and then call at::lift_fresh on it (to give modes a chance to do
+    # their stuff).  Furthermore, the tensor argument to lift_fresh is guaranteed
+    # to be freshly allocated, so we want lift_fresh to be a no-op (directly
+    # returning the input argument).
+    #
+    # Here is the basic problem: when we trace this sequence of executions
+    # into an FX graph, what happens to this call sequence?  Traditionally,
+    # tensor constants get interned as buffers on the FX GraphModule.  But
+    # this is dangerous.  Consider:
+    #
+    #       x = torch.tensor(1)
+    #       x.add_(2)
+    #
+    # Naively, this traces into:
+    #
+    #       t = self._tensor_constant0  # initialized to torch.tensor(1)
+    #       x = torch.ops.aten.lift_fresh(t)
+    #       x.add_(2)
+    #
+    # If lift_fresh returns t directly, the subsequent add_ call will
+    # modify the tensor constant. Really, the problem is we've violated
+    # the invariant the argument to lift is fresh.  So what we should
+    # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
+    #
+    #       t = self._tensor_constant0  # initialized to torch.tensor(1)
+    #       x = torch.ops.aten.lift_fresh_copy(t)
+    #       x.add_(2)
+    #
+    # This is what the overload modification does.
+    if func is torch.ops.aten.lift_fresh.default:
+        func = torch.ops.aten.lift_fresh_copy.default
+
+    proxy_out = proxy_mode.tracer.create_proxy(
+        "call_function",
+        func,
+        proxy_args,
+        proxy_kwargs,
+        name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
+    )
+
+    with _enable_thunkify(proxy_mode.tracer):
+        out = func(*args, **kwargs)
+
+    # In some circumstances, we will be tracing in a situation where a tensor
+    # is *statically* known to be a constant (currently, this only happens if
+    # you run torch.tensor; deterministic factory functions like torch.arange
+    # don't get this treatment).  When the tensor in question is small, it's
+    # helpful to due constant propagation in case we call item() (in which
+    # case we can return the constant value that is known, rather than give
+    # an error.)  The logic here tests if constant propagation is possible
+    # (because all of the inputs are constant).  If so, we disable fake tensor
+    # mode (if it is on) and do true compute on the constant.
+    #
+    # It's worth highlighting that we're making a policy decision here.
+    # There is a potential that the tensor is actually quite large, and we
+    # don't actually want to run the compute.  The tensor being quite large
+    # is one of the reasons why factory functions don't get this treatment
+    # (since they can be quite large; if a parameter is initialized to a
+    # constant value it will be!)  Similarly, there is also a potential
+    # to run an operator that blows up the size of a small tensor; we don't
+    # protect against this case, but we could force, e.g., only single
+    # element constant computation by testing the numel of the result before
+    # propagating const-ness.  Similarly, we don't require the constant to
+    # live on CPU, but we could.
+    any_constant = any(
+        t.constant is not None
+        for t in f_flat_args_kwargs
+        if isinstance(t, _ProxyTensor)
+    )
+
+    constant = None
+
+    def tensor_numel_in_limit(t: Tensor) -> bool:
+        return t.numel() <= CONSTANT_NUMEL_LIMIT
+
+    # If this is a lift, the input tensor is guaranteed to be a
+    # constant, so we keep a copy of the original argument along so
+    # we can query it if we're asked to item() it at some later point
+    if (
+        func is torch.ops.aten.lift_fresh_copy.default
+        and out.numel() <= CONSTANT_NUMEL_LIMIT
+    ):
+        with unset_fake_temporarily():
+            assert isinstance(args[0], (Proxy, Tensor)), type(args[0])
+            constant = args[0].clone()
+    elif (
+        torch.Tag.nondeterministic_seeded not in func.tags
+        and all_constant
+        and any_constant
+        and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out)
+    ):
+        # NB: do NOT include factories as constants
+        with unset_fake_temporarily():
+            const_flat_args_kwargs = [
+                t.constant if isinstance(t, _ProxyTensor) else t
+                for t in f_flat_args_kwargs
+            ]
+            const_args, const_kwargs = pytree.tree_unflatten(
+                const_flat_args_kwargs, spec
+            )
+            constant = func(*const_args, **const_kwargs)
+    else:
+        constant = None
+
+    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
+    _maybe_record_pointwise_barrier(func, proxy_mode)
+    return out
+
+
+class _SymNodeDict:
+    """
+    Wrapper around a dictionary that will hash SymInts with their nodes
+    """
+
+    def __init__(self) -> None:
+        self.sym_node_dict: dict[PySymType, _PySymProxyType] = {}
+
+    def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None:
+        self.sym_node_dict[key.node] = value
+
+    def __getitem__(self, key: PySymType) -> _PySymProxyType:
+        return self.sym_node_dict[key.node]
+
+    def __contains__(self, key: PySymType) -> bool:
+        return key.node in self.sym_node_dict
+
+    def get(
+        self, key: PySymType, default: Optional[_PySymProxyType] = None
+    ) -> _PySymProxyType:
+        # dict.get()'s annotation doesn't accept `None` when the value type
+        # isn't Optional.
+        return self.sym_node_dict.get(key.node, default)  # type: ignore[arg-type, return-value]
+
+    def __iter__(self) -> Any:
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        return len(self.sym_node_dict)
+
+
+@dataclass
+class _SympyExprTrackerValue:
+    proxy: _PySymProxyType
+    value: PySymType
+
+
+class PythonKeyTracer(Tracer):
+    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
+    symnode_tracker: _SymNodeDict
+    sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue]
+    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
+    torch_fn_counts: dict[OpOverload, int]
+    enable_thunkify: bool = False
+
+    def __init__(self) -> None:
+        super().__init__(autowrap_modules=())  # type: ignore[arg-type]
+        self.tensor_tracker = WeakTensorKeyDictionary()
+        self.symnode_tracker = _SymNodeDict()
+        self.script_object_tracker = WeakIdKeyDictionary(
+            dict=None, ref_type=_WeakHashRef
+        )
+        self.sympy_expr_tracker = {}
+
+        # Stores the torch function that was called during tracing
+        self.torch_fn_metadata = None
+        # Stores the counts for every torch function called. This is to help
+        # distinguish between different calls to the same torch function.
+        self.torch_fn_counts = {}
+        self.enable_thunkify = False
+
+    # In general, we don't want to make modules leaves. In principle, users of
+    # this tracer might want to override this in order to turn a couple specific
+    # modules into leaves in the traced graph.
+    def call_module(
+        self,
+        m: Module,
+        forward: Callable[..., Any],
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> Any:
+        return forward(*args, **kwargs)
+
+    # We don't want to turn getattr calls into proxies. So we just return the actual value.
+    def getattr(
+        self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
+    ) -> object:
+        return attr_val
+
+    def create_arg(self, a: object) -> fx.node.Node:
+        if isinstance(a, torch.nn.Parameter):
+            for n, p in self.root.named_parameters():
+                if a is p:
+                    return self.create_node("get_attr", n, (), {})
+
+            qualname = self.get_fresh_qualname("_param_constant")
+            setattr(self.root, qualname, a)
+
+            return self.create_node("get_attr", qualname, (), {})
+        elif isinstance(a, py_sym_types):
+            assert a.node.constant is not None
+            return a.node.constant
+        return super().create_arg(a)  # type: ignore[return-value]
+
+    @overload
+    def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ...
+
+    @overload
+    def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ...
+
+    @overload
+    def unwrap_proxy(
+        self, e: _AnyScriptObjectType
+    ) -> Union[Proxy, _AnyScriptObjectType]: ...
+
+    def unwrap_proxy(self, e: T) -> object:
+        if isinstance(e, Tensor):
+            return get_proxy_slot(e, self, e, lambda x: x.proxy)  # type: ignore[attr-defined]
+        elif isinstance(e, py_sym_types):
+            return get_proxy_slot(e, self, e, lambda e: e.force())
+        elif isinstance(e, _AnyScriptObject):
+            return get_proxy_slot(e, self, e)
+        else:
+            return e
+
+    def create_node(
+        self,
+        kind: str,
+        target: Target,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+    ) -> torch.fx.Node:
+        node = super().create_node(kind, target, args, kwargs, name, type_expr)  # type: ignore[arg-type]
+
+        if node.op in ["placeholder", "output"] and "stack_trace" in node.meta:
+            del node.meta["stack_trace"]
+
+        if kind == "get_attr":
+            assert isinstance(target, str)
+            attr = getattr(self.root, target)
+            if isinstance(attr, torch.Tensor):
+                with disable_proxy_modes_tracing():
+                    node.meta["val"] = extract_val(attr)
+
+        def map_fn(v: Any) -> Optional[_ExtractValType]:
+            if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
+                return None
+            val = v.meta["val"]
+            # other subclasses like FunctionalTensor error on `extract_val`
+            # "Attempting to use FunctionalTensor on its own." just store FakeTensors for now
+            if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor):
+                return None
+            return extract_val(v.meta["val"])
+
+        if _should_save_eager_input_vals(target, (args, kwargs)):
+            # NOTE "eager_input_vals"
+            # We save the original (args, kwargs) FakeTensor values for nodes
+            # that have exact stride requirements. This is useful downstream.
+            # We use this information inside Inductor to ensure that inputs to
+            # stride-sensitive operators have the correct strides.
+            arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn)  # type: ignore[misc, arg-type]
+            node.meta["eager_input_vals"] = (arg_inp, kwarg_inp)
+
+        return node
+
+
+def _should_save_eager_input_vals(
+    target: Any,
+    args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None,
+) -> bool:
+    from torch._higher_order_ops.invoke_subgraph import InvokeSubgraphHOP
+
+    if not callable(target):
+        return False
+    if isinstance(
+        target,
+        (
+            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional,
+            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+            InvokeSubgraphHOP,
+        ),
+    ):
+        return True
+    if args_kwargs is not None and (
+        target is torch.ops.higher_order.auto_functionalized
+        or target is torch.ops.higher_order.auto_functionalized_v2
+    ):
+        args = args_kwargs[0]
+        assert isinstance(
+            args[0], (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
+        )
+        return _should_save_eager_input_vals(args[0], None)
+    if target is torch.ops.higher_order.with_effects:
+        # TODO: inductor lowering for with_effects needs to be updated to propagate
+        # the arg_kwarg_vals
+        return False
+    if isinstance(target, torch._ops.HigherOrderOperator):
+        if pytree.tree_any(_should_save_eager_input_vals, args_kwargs):
+            raise RuntimeError(
+                f"NYI: The HOP {target} has an input that is an OpOverload that "
+                f"needs exact strides. We probably need special logic to "
+                f"propagate the FakeTensor vals. Please file an issue."
+            )
+    if isinstance(target, torch._ops.OpOverload):
+        from torch._library.utils import get_layout_constraint_tag
+
+        return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
+    return False
+
+
+def _make_temp_remove_mode_context_manager(
+    mode_ty: type[TorchFunctionMode],
+) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
+    @contextmanager
+    def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
+        from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
+
+        temp_elements = []
+        removed_mode = None
+
+        while _len_torch_function_stack() > 0:
+            mode = _pop_mode()
+            if isinstance(mode, mode_ty):
+                removed_mode = mode
+                break
+            else:
+                temp_elements.append(mode)
+
+        for mode in reversed(temp_elements):
+            _push_mode(mode)
+
+        try:
+            yield removed_mode
+
+        finally:
+            if removed_mode is not None:
+                count = len(temp_elements)
+                while count > 0:
+                    mode = _pop_mode()
+                    count -= 1
+
+                temp_elements.append(removed_mode)
+
+                for mode in reversed(temp_elements):
+                    _push_mode(mode)
+
+    return context_manager_fn
+
+
+@torch._disable_dynamo
+def dispatch_trace(
+    root: Union[Module, Callable],
+    tracer: Tracer,
+    concrete_args: Optional[tuple[Any, ...]] = None,
+) -> GraphModule:
+    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
+
+    # NB: be careful not to DCE .item() calls
+    def impure_pred(n: fx.Node) -> bool:
+        from .symbolic_shapes import is_accessor_node
+
+        # Always defer to the built-in notion of impure
+        if n.is_impure():
+            return True
+
+        # Accessors always OK to DCE
+        if is_accessor_node(n):
+            return False
+
+        # If the operator in question takes SymInt args to SymInt output,
+        # we assume it's pure and OK to DCE
+        if (
+            isinstance(n.meta.get("val"), py_sym_types)
+            and
+            # NB: constant args ok
+            all(
+                isinstance(a.meta.get("val"), py_sym_types)
+                for a in n.args
+                if isinstance(a, fx.Node)
+            )
+        ):
+            return False
+
+        # No idea, just assume it's not OK
+        return True
+
+    graph.eliminate_dead_code(impure_pred)
+    from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
+
+    dedupe_symints(graph)
+    name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
+    return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name)
+
+
+def wrap_key(
+    f: Callable[[Unpack[_Ts]], R],
+    tensors: tuple[Unpack[_Ts]],
+    tracer: _ProxyTracer,
+    pre_dispatch: bool,
+) -> Callable[_P, R]:
+    flat_tensors, _tensors_spec = pytree.tree_flatten(tensors)
+
+    @functools.wraps(f)
+    def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
+        nonlocal tensors
+
+        flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
+        assert len(flat_proxies) == len(flat_tensors)
+        with disable_proxy_modes_tracing() as m:
+            assert isinstance(m, ProxyTorchDispatchMode)
+            track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
+
+        if getattr(tracer, "proxy_module_inputs", False):
+            tensors = [  # type: ignore[assignment, var-annotated]
+                p if isinstance(t, torch.nn.Module) else t
+                for t, p in zip(tensors, proxies)  # type: ignore[arg-type]
+            ]
+
+        def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
+            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]
+
+        out = f(*tensors)  # type:ignore[call-arg]
+        out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
+        out = pytree.tree_map_only(
+            _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out
+        )
+
+        def get_sym_proxy_slot(t: PySymType) -> Proxy:
+            return get_proxy_slot(t, tracer).force()
+
+        out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out)
+        return out
+
+    return wrapped
+
+
+# TODO: Make downstream users of this work with OperatorBase
+ORIGINAL_ATEN: Optional[object] = None
+
+
+@contextmanager
+def set_original_aten_op(
+    func: OpOverload | torch._ops.HigherOrderOperator,
+) -> Generator[None, None, None]:
+    global ORIGINAL_ATEN
+    if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
+        ORIGINAL_ATEN = func
+        fx_traceback.current_meta["original_aten"] = func
+        try:
+            yield
+        finally:
+            ORIGINAL_ATEN = None
+            fx_traceback.current_meta["original_aten"] = None
+    else:
+        yield
+
+
+class TorchFunctionMetadataMode(TorchFunctionMode):
+    def __init__(self, tracer: _ProxyTracer) -> None:
+        self.tracer = tracer
+
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
+        kwargs: Optional[dict[str, object]] = None,
+    ) -> object:
+        kwargs = kwargs or {}
+        # pyrefly: ignore [bad-assignment]
+        self.tracer.torch_fn_metadata = func
+        self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
+        return func(*args, **kwargs)
+
+
+_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
+    TorchFunctionMetadataMode
+)
+
+
+# This mode is **only** used for pre_dispatch tracing.
+# In particular, we need to make sure that autograd/autocast API's
+# that do not desugar into dispatcher operators stay in the graph.
+class PreDispatchTorchFunctionMode(TorchFunctionMode):
+    def __init__(self, tracer: _ProxyTracer) -> None:
+        self.tracer = tracer
+        # The input to torch.amp.autocast_mode._exit_autocast graph node should be the
+        # enter_autocast node. So we have to save the enter autocast node here, and assign it
+        # to the exit_autocast call_function node.
+        self.enter_autocast_nodes: list[torch.fx.Node] = []
+
+    def __torch_function__(
+        self,
+        func: Union[OpOverload, Callable],
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
+        kwargs: Optional[dict[str, object]] = None,
+    ) -> object:
+        kwargs = kwargs or {}
+        if func in _side_effectful_need_to_be_preserved_pre_dispatch:
+            # It's for passing the export verifier which needs to verify the meta['val']
+            # TODO(tmanlaibaatar): we should systematically couple it with export verifier,
+            # instead of hardcoding it here.
+            # T203648563
+            if func is torch.amp.autocast_mode._exit_autocast:
+                enter_node = self.enter_autocast_nodes.pop()
+                args = (enter_node,)
+            node = self.tracer.create_node("call_function", func, args, {})  # type: ignore[arg-type]
+            if func is torch.amp.autocast_mode._enter_autocast:
+                self.enter_autocast_nodes.append(node)
+            if func in [
+                torch._C._set_grad_enabled,
+                torch.amp.autocast_mode._enter_autocast,
+                torch.amp.autocast_mode._exit_autocast,
+            ]:
+                node.meta["val"] = None
+            # For autocast, the python APIs run so we don't have to run them again
+            # here.
+            if func is torch._C._set_grad_enabled:
+                # pyrefly: ignore [bad-argument-type]
+                func(*args, **kwargs)
+            return node
+
+        # We need more complicated handling here because the inputs
+        # to these functions are sometimes tensors or symints where
+        # we need to fetch the proxies properly.
+        if func in [
+            torch._functorch.predispatch._add_batch_dim,
+            torch._functorch.predispatch._remove_batch_dim,
+            torch._functorch.predispatch._vmap_increment_nesting,
+            torch._functorch.predispatch._vmap_decrement_nesting,
+            torch._functorch.vmap.lazy_load_decompositions,
+        ]:
+            _, proxies, _ = _fetch_proxies_and_all_constant_flag(args, self.tracer)
+            out_proxy = self.tracer.create_proxy(
+                "call_function",
+                func,
+                proxies,
+                {},
+            )
+            res = func(*args, **kwargs)
+            track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer)
+            return res
+        return func(*args, **kwargs)
+
+
+_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
+    PreDispatchTorchFunctionMode
+)
+
+
+class ProxyTorchDispatchMode(TorchDispatchMode):
+    # Ensure this is read-only; this exists only for legacy reasons
+    @property
+    def enable_tracing(self) -> bool:
+        return True
+
+    def __init__(
+        self,
+        tracer: _ProxyTracer,
+        tracing_mode: str,
+        pre_dispatch: bool = False,
+        _allow_fake_constant: bool = False,
+        _error_on_data_dependent_ops: bool = True,
+    ) -> None:
+        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
+        super().__init__(dk)
+        self.tracer = tracer
+        self.tracing_mode = tracing_mode
+        self.pre_dispatch = pre_dispatch
+        self._allow_fake_constant = _allow_fake_constant
+        self._error_on_data_dependent_ops = _error_on_data_dependent_ops
+        # Indicates to our torch_dispatch dispatching infra that
+        # this is an "infra" mode with lower dispatching precedence.
+        self._mode_key = torch._C._TorchDispatchModeKey.PROXY
+        # Every time we enter a mode, we maintain a stack telling us what the previous
+        # ProxyTorchDispatchMode state was (if there was any).
+        # This lets us properly reset the state on exit.
+        self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
+        self.decomp_layers: int = 0
+        from torch._inductor import config
+
+        self.emulate_precision_casts: bool = config.emulate_precision_casts
+
+    @count
+    def __torch_dispatch__(
+        self,
+        func: OpOverload,
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
+        kwargs: Optional[dict[str, object]] = None,
+    ) -> object:
+        with set_original_aten_op(func):
+            kwargs = kwargs or {}
+
+            if func == prim.device.default:
+                return func(*args, **kwargs)
+
+            return proxy_call(self, func, self.pre_dispatch, args, kwargs)
+
+    def __enter__(self) -> Self:
+        # Stash and store the previous proxy mode (there may or may not be one)
+        maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
+        self.enter_stack.append(maybe_prev_proxy_mode)
+        return super().__enter__()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
+        b = super().__exit__(exc_type, exc_value, traceback)
+
+        # Re-enable the previous proxy mode, if there was one.
+        mb_previous_proxy_mode = self.enter_stack.pop()
+        if mb_previous_proxy_mode is not None:
+            _push_mode(mb_previous_proxy_mode)
+
+        return b
+
+    @classmethod
+    def is_infra_mode(cls) -> bool:
+        return True
+
+    def __sym_dispatch__(
+        self,
+        func: OpOverload,
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> object:
+        # Peephole optimize multiply by one
+        # NB: be careful not to trigger guards here!
+        if func is operator.mul:
+            if isinstance(args[1], int) and args[1] == 1:
+                return args[0]
+            elif isinstance(args[0], int) and args[0] == 1:
+                return args[1]
+
+        # For speed, we assume there are no nested data structures
+        # (otherwise we could use tree_map)
+        # We also assume there are no keyword arguments.
+        assert not kwargs
+        out = func(*args, **kwargs)
+        _sym_register(self.tracer, func, args, out)
+        return out
+
+
+def _sym_register(
+    tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: object
+) -> None:
+    # If func returned a constant, we don't need to trace; we have
+    # determined that the result is constant (no matter if the inputs
+    # were symbolic) and it is no longer necessary to trace the
+    # computation.  This could occur if func triggered some guards.
+    if isinstance(out, py_sym_types):
+        p_out_thunk = thunkify(
+            tracer, _compute_proxy, tracer, func=func, args=args, out=out
+        )
+        set_proxy_slot(out, tracer, p_out_thunk)
+
+
+def _compute_proxy(
+    tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: PySymType
+) -> Proxy:
+    # Handle torch.sym_sum
+    n_args: tuple[object, ...]
+    if len(args) == 1 and isinstance(args[0], (list, tuple)):
+        n_args = (
+            tuple(
+                (
+                    get_proxy_slot(a, tracer).force().node
+                    if isinstance(a, py_sym_types)
+                    else a
+                )
+                for a in args[0]
+            ),
+        )
+    else:
+        n_args = tuple(
+            (
+                get_proxy_slot(a, tracer).force().node
+                if isinstance(a, py_sym_types)
+                else a
+            )
+            for a in args
+        )
+
+    # func doesn't have a __torch_function__ that Proxy can interpose, so
+    # we gotta do it manually
+    n_out = tracer.create_node("call_function", func, n_args, {})  # type: ignore[arg-type]
+    p_out = fx.Proxy(n_out, tracer)
+    set_meta(p_out, out)
+    return p_out
+
+
+class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
+    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
+    symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
+    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
+    sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue]
+    torch_fn_metadata: Optional[OpOverload]
+    torch_fn_counts: dict[OpOverload, int]
+    enable_thunkify: bool = False
+
+    def __init__(self, graph: fx.graph.Graph) -> None:
+        super().__init__(graph)
+        self.symnode_tracker = weakref.WeakKeyDictionary()
+        self.tensor_tracker = WeakTensorKeyDictionary()
+        self.sympy_expr_tracker = {}
+        self.script_object_tracker = WeakIdKeyDictionary(
+            dict=None, ref_type=_WeakHashRef
+        )
+        # Stores the torch function that was called during tracing
+        self.torch_fn_metadata = None
+        # Stores the counts for every torch function called. This is to help
+        # distinguish between different calls to the same torch function.
+        self.torch_fn_counts = {}
+
+
+# TODO: I'm not sure what the point of this class is; you can just
+# make_fx through a regular Interpreter
+class DecompositionInterpreter(fx.Interpreter):
+    def __init__(
+        self,
+        module: fx.GraphModule,
+        new_graph: fx.Graph,
+        decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
+        **kwargs: object,
+    ) -> None:
+        super().__init__(module, **kwargs)  # type: ignore[arg-type]
+        self.new_graph = new_graph
+        self.tracer = _GraphAppendingTracerEx(self.new_graph)
+        # Blegh
+        self.decomposition_table = decomposition_table or {}
+        self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
+
+    # pyrefly: ignore [bad-override]
+    def placeholder(
+        self,
+        target: str,  # type: ignore[override]
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> object:
+        out = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
+        proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
+        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
+        # TODO handle case where the first character of target is '*'
+        return out
+
+    # pyrefly: ignore [bad-override]
+    def get_attr(
+        self,
+        target: str,  # type: ignore[override]
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> object:
+        out = super().get_attr(target, args, kwargs)  # type: ignore[arg-type]
+        proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
+        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
+        return out
+
+    # call_function, call_method, call_module get traced automatically by the outer mode.
+
+    # pyrefly: ignore [bad-override]
+    def output(
+        self,
+        target: str,  # type: ignore[override]
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> object:
+        out = super().output(target, args, kwargs)  # type: ignore[arg-type]
+
+        def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
+            return x.proxy.node
+
+        def unwrap(e: Tensor) -> Union[Tensor, fx.Node]:
+            return get_proxy_slot(e, self.tracer, e, get_proxy_node)
+
+        self.new_graph.output(pytree.tree_map(unwrap, out))
+        return out
+
+    def run(self, *args: object, **kwargs: object) -> object:
+        # Should enter the mode at least once for being able to restore it later
+        # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
+        with decompose(self.decomposition_table), self.mode:
+            return super().run(*args, **kwargs)  # type: ignore[arg-type]
+
+
+class _SelectiveDecomposeInterpreter(fx.Interpreter):
+    def __init__(
+        self,
+        module: fx.GraphModule,
+        should_decompose: Callable[[fx.Node], bool],
+        decomposition_table: Mapping[OpOverload, Callable],
+        **kwargs: object,
+    ) -> None:
+        """
+        For all nodes in `module`, selectively decompose if is `should_decompose`,
+        following the given `decomposition_table`.
+        """
+        super().__init__(module, **kwargs)  # type: ignore[arg-type]
+        self.should_decompose = should_decompose
+        self.decomposition_table = decomposition_table
+
+    @staticmethod
+    def recursive_wrap(
+        gm: fx.GraphModule,
+        should_decompose: Callable[[fx.Node], bool],
+        decomposition_table: Mapping[OpOverload, Callable],
+        **kwargs: object,
+    ) -> _SelectiveDecomposeInterpreter:
+        """
+        Recursively wrap gm and its sub graph modules. Specifically, HOP takes
+        sub graph module as args. We may not want to decompose all nodes within
+        these sub graph modules. So we also need to wrap these sub graph modules.
+        As a result:
+        - if should_decompose(hop) is True, we decompose all nodes within the hop.
+        - if should_decompose(hop) is False, we check each node within the hop
+            and decide whether decompose or not.
+        """
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and isinstance(
+                node.target, HigherOrderOperator
+            ):
+                new_args = []
+                for arg in node.args:
+                    if isinstance(arg, fx.GraphModule):
+                        new_arg = _SelectiveDecomposeInterpreter.recursive_wrap(
+                            arg, should_decompose, decomposition_table, **kwargs
+                        )
+                    else:
+                        new_arg = arg
+                    new_args.append(new_arg)
+                node.args = tuple(new_args)
+
+        return _SelectiveDecomposeInterpreter(
+            gm, should_decompose, decomposition_table, **kwargs
+        )
+
+    def run_node(self, n):
+        if self.should_decompose(n):
+            with decompose(self.decomposition_table):
+                result = super().run_node(n)
+        else:
+            result = super().run_node(n)
+        return result
+
+
+def selective_decompose(
+    joint_gm: fx.GraphModule,
+    *args,
+    decomposition,
+    should_decompose,
+    trace_joint_graph: bool,
+) -> fx.GraphModule:
+    """Retrace a joint graph module and selectively apply decomposition."""
+
+    if trace_joint_graph:
+        # the arg name, primals and tangents, are important.
+        # make_fx keeps the name in the traced graph and partitioner later relies
+        # on the name to partition joint graph correctly.
+        def wrap_fn(primals: list[Any], tangents: list[Any]):
+            return _SelectiveDecomposeInterpreter.recursive_wrap(
+                joint_gm, should_decompose, decomposition
+            ).run(*args)
+    else:
+
+        def wrap_fn(*args):
+            return _SelectiveDecomposeInterpreter.recursive_wrap(
+                joint_gm, should_decompose, decomposition
+            ).run(*args)
+
+    return make_fx(wrap_fn, decomposition_table={})(*args)
+
+
+def wrapper_and_args_for_make_fx(
+    func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
+) -> tuple[Callable[[list[object]], R], list[object]]:
+    # make_fx doesn't support kwargs, so we need to do this flattening
+    # and then unflatten the args before calling func
+    flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+    def wrapped(flat_args: list[object]) -> R:
+        fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
+        return func(*fn_args, **fn_kwargs)
+
+    return wrapped, flat_args
+
+
+@contextmanager
+def disable_autocast_cache() -> Generator[None, None, None]:
+    old_value = torch.is_autocast_cache_enabled()
+    torch.set_autocast_cache_enabled(False)
+    try:
+        yield
+    finally:
+        torch.set_autocast_cache_enabled(old_value)
+
+
+class _ModuleNotInstalledAsSubmoduleError(NameError):
+    pass
+
+
+# Base class for inline _ModuleStackTracer.__init__.AttrProxy
+class _AttrProxy:
+    def reset_proxy_mapping(self, base: Module, path: str) -> None:
+        pass
+
+
+class _ModuleStackTracer(PythonKeyTracer):
+    r"""Customized version of PythonKeyTracer that retains module stack
+    information in node.meta["nn_module_stack"].
+
+    FX symbolic trace actually does this already, but it relies on `self.root`
+    being the actual module being traced. Since make_fx traces a lambda of our
+    creation, things don't work properly.
+
+    So for this version we hold onto a reference to the original module
+    (scope_root) and use that to match the path. Also when we see,
+            A
+           / \
+          B   C
+           \ /
+            D
+    we want to record the path as A.B.D by recording only one path.
+    See Note [Preserving the nn module stack metadata during export non-strict mode]  # noqa: W605
+    """
+
+    def __init__(self, scope_root: GraphModule) -> None:
+        super().__init__()
+        self.record_stack_traces = True
+        self._record_forward_stack_traces_only = True
+        self.scope_root = scope_root
+        self.enable_attr_proxy = False
+        self.submodule_paths = {}
+        for name, m in self.scope_root.named_modules(remove_duplicate=False):
+            if m in self.submodule_paths:
+                log.info(
+                    "Shared module found between %s and %s, AttrProxy is enabled.",
+                    self.submodule_paths[m],
+                    name,
+                )
+                self.enable_attr_proxy = True
+            else:
+                self.submodule_paths[m] = name
+
+        self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
+        self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary()
+        self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary()
+        self.counter = 0
+
+        self.module_id_cache = defaultdict(list)
+        for name, mod in self.scope_root.named_modules(remove_duplicate=False):
+            self.module_id_cache[id(mod)].append(name)
+
+        # Build a wrapper around _AttrProxy to provide the tracer. We can't
+        # store it on _AttrProxy itself beceause we mimic the underlying class
+        # (including its attributes).
+        tracer = self
+
+        class AttrProxy(_AttrProxy):
+            def __init__(self, base: Union[Module, _AttrProxy], path: str) -> None:
+                if isinstance(base, _AttrProxy):
+                    base = base.get_base()  # type: ignore[attr-defined]
+
+                assert isinstance(base, Module)
+                # Class is modified to be a subclass of torch.nn.Module
+                # Warning: We blow away our own attributes here to mimic the base class
+                # - so don't expect `self.x` to do anything useful.
+                # pyrefly: ignore [no-matching-overload]
+                # pyrefly: ignore [bad-override]
+                self.__class__ = type(
+                    base.__class__.__name__,
+                    (self.__class__, base.__class__),
+                    {},
+                )
+                self.__dict__ = base.__dict__
+                self.__class__.__module__ = base.__class__.__module__
+                self.__class__.__qualname__ = base.__class__.__qualname__
+
+                # This overwrites any existing paths if `base` is an AttrProxy
+                tracer.proxy_paths[self] = path
+                tracer.proxy_modules[self] = base
+
+            def __getattr__(self, name: str) -> AttrProxy:
+                assert isinstance(self, Module)
+                # Calling into torch.nn.Module.__getattr__ with super(),
+                # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py.
+                # which then calls into _ModuleStackTracer.getattr
+                attr_val = super().__getattr__(name)  # type: ignore[misc]
+                if not isinstance(attr_val, Module):
+                    return attr_val
+
+                # pyrefly: ignore [index-error]
+                return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)
+
+            def get_base(self) -> Module:
+                return tracer.proxy_modules[self]
+
+            def __getitem__(self, idx: Union[int, slice]) -> AttrProxy:
+                if isinstance(idx, slice):
+                    if isinstance(self, torch.nn.Sequential):
+                        # Copied from nn/modules/container.py
+                        res = torch.nn.Sequential(
+                            OrderedDict(list(self._modules.items())[idx])
+                        )
+                        # pyrefly: ignore [index-error]
+                        return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
+                    elif isinstance(self, torch.nn.ModuleList):
+                        # Copied from nn/modules/container.py
+                        res = torch.nn.ModuleList(list(self._modules.values())[idx])
+                        # pyrefly: ignore [index-error]
+                        return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
+
+                return super().__getitem__(idx)  # type: ignore[misc]
+
+            @property
+            def _modules(self) -> dict[str, AttrProxy]:
+                assert "_modules" in self.__dict__
+                submodules = self.__dict__["_modules"]
+                assert isinstance(submodules, dict)
+                return {
+                    key: (
+                        AttrProxy(value, tracer.proxy_paths[self] + "." + str(key))  # type: ignore[misc]
+                        if value is not None
+                        else value
+                    )
+                    for key, value in submodules.items()
+                }
+
+        self.proxy_type = AttrProxy
+
+    def path_of_module(self, mod: Module) -> str:
+        """
+        Use tracked access path during tracing instead of the default BFS behavior.
+        Still use all the possible module paths to verify the result.
+        """
+        if mod is self.scope_root:
+            return ""
+
+        if isinstance(mod, _AttrProxy):
+            return self.proxy_paths[mod]
+
+        try:
+            return Tracer.path_of_module(self, mod)
+        except NameError as e:
+            raise _ModuleNotInstalledAsSubmoduleError from e
+
+    def getattr(
+        self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
+    ) -> object:
+        if (
+            not isinstance(attr_val, Module)
+            or isinstance(attr_val, fx.GraphModule)
+            or not self.enable_attr_proxy
+        ):
+            return super().getattr(attr, attr_val, parameter_proxy_cache)
+        if isinstance(attr_val, _AttrProxy):
+            return attr_val
+
+        # See NOTE [caching AttrProxy].
+        if attr_val not in self.attr_proxy_map:
+            self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr)
+        else:
+            self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
+        return self.attr_proxy_map[attr_val]
+
+    def trace(  # type: ignore[override]
+        self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]]
+    ) -> fx.Graph:
+        res = super().trace(root, concrete_args)
+
+        # NOTE [export non-strict fake tensor leak detection]
+        # In non-strict export, we don't have dynamo's side effect
+        # tracking logic which makes some cases hard to detect.
+        # In general, our detecting strategy is:
+        #  (1) We instrument fake tensor creation to log all the fake tensors created during export.
+        #  (2) We dump the proxy to fake tensor map from make_fx tracer (_FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT))
+        #  (3) Filter out fake tensors that are logged during (1):
+        #      (1) Associated with TrackedFake (input tracking thing in symbolic_shapes)
+        #      (2) Associated with gm.meta
+        #  (4) Do ID match with the proxies
+
+        global _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT
+        _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT.clear()
+
+        for key, val in self.tensor_tracker.items():
+            _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT[id(key)] = val.proxy.node
+
+        # Since we are making _AttrProxy mimic the original
+        # submodule, when someone registers a module directly
+        # to the tracer while tracing, the proxy object gets registered
+        # first. So we need to replace the proxy modules with the real ones
+        # This can happen during HOO tracing
+        proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = []
+        for name, module in self.root.named_modules():
+            if module in self.proxy_modules:
+                proxy_module_names_to_be_replaced.append((name, module))
+
+        def _delete_proxy_attr(obj: Module, target: str) -> bool:
+            # Copied from fx/graph_module.py
+            # Customized it for proxy type
+            atoms = target.split(".")
+            path, target_submod = atoms[:-1], atoms[-1]
+            assert isinstance(obj, Module)
+            mod = obj
+
+            # Get the parent module
+            for item in path:
+                if not hasattr(mod, item):
+                    return False
+
+                mod = getattr(mod, item)
+
+                if not isinstance(mod, (_AttrProxy, Module)):
+                    return False
+
+            if not hasattr(mod, target_submod):
+                return False
+
+            # At least the leaf module should be proxy type.
+            if not isinstance(getattr(mod, target_submod), _AttrProxy):
+                return False
+
+            delattr(mod, target_submod)
+            return True
+
+        for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced:
+            _delete_proxy_attr(self.root, proxy_module_name)
+            actual_module = self.proxy_modules[proxy_module]
+            _assign_attr(actual_module, self.root, proxy_module_name)
+
+        return res
+
+    def call_module(
+        self,
+        m: Module,
+        forward: Callable,
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> None:
+        """PythonKeyTracer overrides call_module to avoid the scope handling,
+        but we actually want it.
+        """
+        from torch._dynamo import OptimizedModule
+
+        # FIXME (tmanlaibaatar)
+        # When we call torch.compile inside HOO, we will end up
+        # invoking a module that is not registered on the root. For
+        # now, we just inline them. But once we start supporting
+        # mark_strict in export, we do need to properly handle this.
+        # Right now, it doesn't matter because current non-strict
+        # use cases don't need to work with HOO.
+        if isinstance(m, (OptimizedModule, GraphModule)):
+            return forward(*args, **kwargs)
+
+        try:
+            return Tracer.call_module(self, m, forward, args, kwargs)
+        except _ModuleNotInstalledAsSubmoduleError:
+            log.debug(
+                "Unable to find the path of the module %s. "
+                "This might be because the module was not properly registered "
+                "as a submodule, which is not good practice. We will trace "
+                "through the module without recording stack information.",
+                str(m),
+            )
+            return forward(*args, **kwargs)
+
+    def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool:
+        return False
+
+    def create_node(self, *args: object, **kwargs: object) -> fx.node.Node:
+        """
+        Create node and add on metadata.
+        Add nn_module_stack here instead of TracerBase,
+        since calls to make_fx() might not want to record module stack metadata.
+        Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
+        Add stack_trace by filtering out forward() stack frames.
+        """
+        node = super().create_node(*args, **kwargs)  # type: ignore[arg-type]
+
+        # nn_module_stack
+        if node.op not in ["placeholder", "output"]:
+            if node.meta.get("nn_module_stack") is None:
+                node.meta["nn_module_stack"] = self.module_stack.copy()
+            # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]]
+            for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items():
+                if isinstance(mod_cls, type):
+                    node.meta["nn_module_stack"][key] = (
+                        fqn,
+                        mod_cls.__module__ + "." + mod_cls.__qualname__,
+                    )
+
+        # torch_fn
+        if (
+            node.op == "call_function"
+            and self.torch_fn_metadata is not None
+            and "torch_fn" not in node.meta
+        ):
+            node.meta["torch_fn"] = (
+                f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}",
+                f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}",
+            )
+
+        return node
+
+
+class _MakefxTracer:
+    def __init__(
+        self,
+        decomposition_table: Optional[Mapping[OpOverload, Callable]],
+        tracing_mode: str,
+        _allow_non_fake_inputs: bool,
+        pre_dispatch: bool,
+        record_module_stack: bool,
+        _allow_fake_constant: bool,
+        _error_on_data_dependent_ops: bool,
+        record_stack_traces: bool = False,
+        parent_tracer: Optional[_MakefxTracer] = None,
+        proxy_module_inputs: bool = False,
+    ) -> None:
+        # Configurations that are used to initialize the context managers and their states.
+        # Should not modify them during tracing.
+        self.decomposition_table: dict[OpOverload, Callable] = dict(
+            decomposition_table or {}
+        )
+        self.decomposition_table.setdefault(
+            torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel
+        )
+        self.tracing_mode: str = tracing_mode
+        self._allow_non_fake_inputs: bool = _allow_non_fake_inputs
+        self.pre_dispatch: bool = pre_dispatch
+        self.record_module_stack: bool = record_module_stack
+        self._allow_fake_constant: bool = _allow_fake_constant
+        self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops
+
+        # All context managers and their states should be initialized before tracing based on the inputs
+        # and configurations. After tracing, their states should be cleaned except for shape_env.
+        # Remember to specify how to initialize it from user inputs and from parent tracer whenever
+        # adding new modes in _MakefxTracer.
+        self.fake_tensor_mode: Optional[FakeTensorMode] = None
+        self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
+        self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = (
+            nullcontext()
+        )
+        self.fx_tracer: Optional[PythonKeyTracer] = None
+        self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
+        self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
+            nullcontext()
+        )
+        self.record_stack_traces = record_stack_traces
+        self.parent_tracer: Optional[_MakefxTracer] = parent_tracer
+        self.proxy_module_inputs = proxy_module_inputs
+
+    def _checkpoint_modes(self) -> list[Any]:
+        return [
+            self.fake_tensor_mode,
+            self.proxy_mode,
+            self.proxy_function_mode,
+            self.fx_tracer,
+            self.python_dispatcher_mode,
+            self.torch_fn_metadata_mode,
+        ]
+
+    def _restore_modes(
+        self,
+        prev_fake_tensor_mode: Optional[FakeTensorMode],
+        prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode],
+        prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode],
+        prev_fx_tracer: Optional[PythonKeyTracer],
+        prev_python_dispatcher_mode: Union[nullcontext, Any],
+        prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode],
+    ) -> None:
+        self.fake_tensor_mode = prev_fake_tensor_mode
+        self.proxy_mode = prev_proxy_mode
+        self.proxy_function_mode = prev_proxy_function_mode
+        self.fx_tracer = prev_fx_tracer
+        self.python_dispatcher_mode = prev_python_dispatcher_mode
+        self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode
+
+    @contextmanager
+    def _init_modes_from_inputs(
+        self, f: Callable, args: tuple[object, ...]
+    ) -> Generator[None, None, None]:
+        prev_modes = self._checkpoint_modes()
+        try:
+            # Avoid importing sympy at a module level
+            from .symbolic_shapes import ShapeEnv
+
+            if hasattr(f, "_orig_mod") and self.record_module_stack:
+                scope_root = f._orig_mod
+                # _ModuleStackTracer always try to preserve stack trace
+                # in forward functions
+                self.fx_tracer = _ModuleStackTracer(scope_root)
+            else:
+                self.fx_tracer = PythonKeyTracer()
+                self.fx_tracer.record_stack_traces = self.record_stack_traces
+                if self.record_stack_traces:
+                    self.fx_tracer._record_forward_stack_traces_only = True
+
+            if self.tracing_mode == "fake":
+                import torch._dynamo
+
+                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
+                if fake_tensor_mode is None:
+                    import torch._functorch.config as _config
+
+                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
+                        fake_tensor_mode = FakeTensorMode(
+                            allow_fallback_kernels=True,
+                            allow_non_fake_inputs=self._allow_non_fake_inputs,
+                            shape_env=ShapeEnv(),
+                            static_shapes=True,
+                        )
+                self.fake_tensor_mode = fake_tensor_mode
+            elif self.tracing_mode == "symbolic":
+                import torch._dynamo
+
+                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
+                if fake_tensor_mode is None:
+                    shape_env = ShapeEnv()
+                    import torch._functorch.config as _config
+
+                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
+                        fake_tensor_mode = FakeTensorMode(
+                            allow_fallback_kernels=False,
+                            allow_non_fake_inputs=self._allow_non_fake_inputs,
+                            shape_env=shape_env,
+                        )
+                assert fake_tensor_mode.shape_env is not None, (
+                    "shape_env should be set if tracing with 'symbolic'"
+                )
+                self.fake_tensor_mode = fake_tensor_mode
+            else:
+                if not self.tracing_mode == "real":
+                    raise AssertionError(
+                        f"Unexpected tracing type: {self.tracing_mode}"
+                    )
+
+            self._construct_modes_with_fx_tracer(self.fx_tracer)
+            yield
+        finally:
+            self._restore_modes(*prev_modes)
+
+    def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None:
+        self.proxy_mode = ProxyTorchDispatchMode(
+            fx_tracer,
+            self.tracing_mode,
+            pre_dispatch=self.pre_dispatch,
+            _allow_fake_constant=self._allow_fake_constant,
+            _error_on_data_dependent_ops=self._error_on_data_dependent_ops,
+        )
+
+        if self.pre_dispatch:
+            self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)
+
+        # pre-autograd tracing uses per-dispatch-key modes,
+        # which requires the python dispatcher
+        if self.tracing_mode == "symbolic" or self.pre_dispatch:
+            self.python_dispatcher_mode = enable_python_dispatcher()
+
+        self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
+        fx_tracer.proxy_module_inputs = self.proxy_module_inputs  # type: ignore[union-attr]
+
+    @contextmanager
+    def _init_modes_from_parent(
+        self, parent_tracer: _MakefxTracer
+    ) -> Generator[None, None, None]:
+        # By default, subtracer creates new modes based on parent tracer's config.
+        # However, there are cases where we want to share the same modes with parent tracer
+        # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same.
+        prev_modes = self._checkpoint_modes()
+        try:
+            self.fake_tensor_mode = parent_tracer.fake_tensor_mode
+
+            def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer:
+                if type(parent_tracer) is PythonKeyTracer:
+                    return PythonKeyTracer()
+                elif type(parent_tracer) is _ModuleStackTracer:
+                    return _ModuleStackTracer(parent_tracer.scope_root)
+                else:
+                    raise RuntimeError(
+                        f"Unexpected tracer type: {type(parent_tracer)}."
+                    )
+
+            assert parent_tracer.fx_tracer is not None
+            self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer)
+            self._construct_modes_with_fx_tracer(self.fx_tracer)
+            yield
+        finally:
+            self._restore_modes(*prev_modes)
+
+    def _trace_inner(self, f: Callable, *args: object) -> GraphModule:
+        # TODO: We need to explicitly import torch._dynamo before calling dispatch_trace,
+        # because dispatch_trace will introduce the lazy import of torch._dynamo,
+        # and some contexts set before calling dispatch_trace will cause problems with the import of torch._dynamo,
+        # such as some torch API(torch.ones and so on) in populate_builtin_to_tensor_fn_map() will be affected
+        # by the context set before dispatch_trace.
+        import torch._dynamo
+
+        phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args)
+
+        def _wrap_fake(args: T) -> T:
+            arg_count = 0
+
+            def inner_wrap_fake(x: object) -> object:
+                nonlocal arg_count
+                # TODO: it would be nice to line these up with the names
+                # FX will choose for the placeholders, but we don't
+                # actually know what the names will be at this point yet
+                # NB: the Source here is actually meaningless
+                from torch._dynamo.source import ConstantSource
+
+                assert self.fake_tensor_mode is not None
+                source = ConstantSource(f"input{arg_count}")
+                if isinstance(x, Tensor):
+                    arg_count += 1
+                    return self.fake_tensor_mode.from_tensor(x, source=source)
+                # NB: don't match on bools
+                elif type(x) is int and self.tracing_mode == "symbolic":
+                    assert self.fake_tensor_mode.shape_env is not None, (
+                        "shape_env should be set if tracing with 'symbolic'"
+                    )
+                    return self.fake_tensor_mode.shape_env.create_symintnode(
+                        self.fake_tensor_mode.shape_env.create_symbol(
+                            x, source, positive=None
+                        ),
+                        hint=x,
+                        source=source,
+                    )
+                elif isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
+                    return torch._library.fake_class_registry.maybe_to_fake_obj(
+                        self.fake_tensor_mode, x
+                    )
+
+                assert not isinstance(x, FakeScriptObject), (
+                    f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
+                )
+                return x
+
+            wrap_fn_map = {
+                "real": lambda x: x,
+                "fake": inner_wrap_fake,
+                "symbolic": inner_wrap_fake,
+            }
+            return pytree.tree_map(wrap_fn_map[self.tracing_mode], args)
+
+        def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]:
+            if (
+                not hasattr(inspect.unwrap(f), "__code__")
+                or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS
+            ):
+                # FX doesn't support varargs, so we gotta fake up a wrapper
+                # TODO: Would be nice to fix this at the source...
+                return fake_signature(f, len(phs))
+            return f
+
+        args = _wrap_fake(args)
+        func = _wrap_func(f, phs)
+        # We disable the autocast cache as the autocast cache causes type conversions on parameters to
+        # check a cache, which introduces untracked tensors into the graph
+        #
+        # We also disable tracing by any other tensor proxy-based tracers except the current. The
+        # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
+        # thus irrelevant to any external functional trace.
+        proxy_mode: ProxyTorchDispatchMode = typing.cast(
+            ProxyTorchDispatchMode, self.proxy_mode
+        )
+        with ExitStack() as stack:
+            stack.enter_context(decompose(self.decomposition_table))
+            if self.fake_tensor_mode:
+                stack.enter_context(self.fake_tensor_mode)
+            stack.enter_context(self.python_dispatcher_mode)
+            stack.enter_context(self.proxy_function_mode)
+            stack.enter_context(self.torch_fn_metadata_mode)
+            stack.enter_context(proxy_mode)
+            stack.enter_context(disable_autocast_cache())
+            stack.enter_context(_set_make_fx_tracer(self))
+
+            assert self.fx_tracer is not None
+            try:
+                t = dispatch_trace(
+                    wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
+                    tracer=self.fx_tracer,
+                    concrete_args=tuple(phs),
+                )
+            except Exception:
+                trace_structured(
+                    "artifact",
+                    metadata_fn=lambda: {
+                        "name": "make_fx_fail_partial",
+                        "encoding": "string",
+                    },
+                    payload_fn=lambda: self.fx_tracer.graph.python_code(  # type: ignore[union-attr]
+                        root_module="self",
+                        verbose=True,
+                        include_stride=True,
+                        include_device=True,
+                    ).src,
+                )
+                raise
+
+        if (
+            self.is_hop_subgraph_tracer()
+            and (fake_mode := torch._guards.detect_fake_mode(args))
+            and fake_mode.shape_env is not None
+        ):
+            from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
+
+            insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
+            t.recompile()
+        # TODO: kind of a bad way to do it, should maybe figure out a better way
+        if self.tracing_mode == "symbolic":
+            assert self.fake_tensor_mode is not None
+            t.shape_env = self.fake_tensor_mode.shape_env  # type: ignore[assignment]
+        return t
+
+    def trace(self, f: Callable, *args: object) -> fx.GraphModule:
+        with self._init_modes_from_inputs(f, args):
+            return self._trace_inner(f, *args)
+
+    def is_hop_subgraph_tracer(self) -> bool:
+        return self.parent_tracer is not None
+
+    def trace_subgraph(self, f: Callable, *args: object) -> GraphModule:
+        # Create a new tracer based on parent's config
+        sub_tracer = _MakefxTracer(
+            self.decomposition_table,
+            "real",
+            self._allow_non_fake_inputs,
+            self.pre_dispatch,
+            self.record_module_stack,
+            self._allow_fake_constant,
+            self._error_on_data_dependent_ops,
+            parent_tracer=self,
+        )
+        with sub_tracer._init_modes_from_parent(self):
+            return sub_tracer._trace_inner(f, *args)
+
+
+_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None
+
+
+@contextmanager
+def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]:
+    global _CURRENT_MAKE_FX_TRACER
+    prev_tracer = _CURRENT_MAKE_FX_TRACER
+    try:
+        _CURRENT_MAKE_FX_TRACER = tracer
+        yield
+    finally:
+        _CURRENT_MAKE_FX_TRACER = prev_tracer
+
+
+def make_fx(
+    f: Callable,
+    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
+    tracing_mode: str = "real",
+    _allow_non_fake_inputs: bool = False,
+    *,
+    pre_dispatch: bool = False,
+    record_module_stack: bool = False,
+    _allow_fake_constant: bool = False,
+    _error_on_data_dependent_ops: bool = True,
+    record_stack_traces: bool = False,
+    proxy_module_inputs: bool = False,
+) -> Callable[..., GraphModule]:
+    """
+    Given a function f, return a new function which when executed with valid
+    arguments to f, returns an FX GraphModule representing the set of operations that
+    were executed during the course of execution.
+
+    If record_stack_traces is True, the stack trace will be preserved on node.meta["stack_trace"]
+    """
+
+    assert tracing_mode in ["real", "fake", "symbolic"]
+
+    from torch._inductor import config
+
+    make_fx_tracer = _MakefxTracer(
+        decomposition_table,
+        tracing_mode,
+        _allow_non_fake_inputs,
+        pre_dispatch,
+        record_module_stack,
+        _allow_fake_constant,
+        _error_on_data_dependent_ops,
+        record_stack_traces=record_stack_traces
+        or config.trace.provenance_tracking_level == 1,
+        proxy_module_inputs=proxy_module_inputs,
+    )
+
+    @functools.wraps(f)
+    def wrapped(*args: object) -> GraphModule:
+        return make_fx_tracer.trace(f, *args)
+
+    return wrapped
+
+
+def get_torch_dispatch_modes() -> list[TorchDispatchMode]:
+    return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+
+
+# TODO: this is a legacy name, there is only ever one proxy mode as it's an
+# infra mode
+def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
+    return get_proxy_mode()
+
+
+def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
+    """
+    Current the currently active proxy tracing mode, or None if
+    we are not currently tracing.  This includes pre-dispatch proxy
+    tracing.
+    """
+    pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(
+        torch._C._TorchDispatchModeKey.PROXY
+    )
+    mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+    assert pre_dispatch_mode is None or mode is None, (
+        f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
+    )
+    return pre_dispatch_mode or mode
+
+
+def handle_sym_dispatch(
+    func: Callable[_P, R],
+    args: _P.args,  # type: ignore[valid-type]  # not allowed to use _P.args here
+    kwargs: _P.kwargs,  # type: ignore[valid-type]  # not allowed to use _P.kwargs here
+) -> R:
+    """
+    Call into the currently active proxy tracing mode to do a
+    SymInt/SymFloat/SymBool dispatch trace on a function that operates on
+    these arguments.
+    """
+    mode = get_proxy_mode()
+    assert mode
+    # Have to do it manually, because we're not doing the normal torch
+    # dispatch machinery which disables it for us
+    with disable_proxy_modes_tracing():
+        # TODO: properly compute types
+        types: list[type] = []
+        return mode.__sym_dispatch__(func, types, args, kwargs)  # type: ignore[arg-type, return-value]
+
+
+@contextmanager
+def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]:
+    return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
+
+
+def maybe_handle_decomp(
+    proxy_mode: ProxyTorchDispatchMode,
+    op: OpOverload,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+) -> object:
+    from torch._inductor.compiler_bisector import CompilerBisector
+
+    if op in CURRENT_DECOMPOSITION_TABLE:
+        if CompilerBisector.disable_subsystem(
+            "aot_eager_decomp_partition", "decomposition", lambda: repr(op)
+        ):
+            return NotImplemented
+
+        with proxy_mode:
+            proxy_mode.decomp_layers += 1
+            out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
+            proxy_mode.decomp_layers -= 1
+            return out
+
+    return NotImplemented
+
+
+def get_isolated_graphmodule(
+    func: Callable,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
+    tracing_mode: str = "real",
+    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
+) -> GraphModule:
+    """A helper function used to get the GraphModule for the given func.
+
+    It's expected to be used in the ProxyTensor tracing context.
+    It detaches the args and kwargs from the current tracer so that the trace of
+    the current graph module can be created without any side-effects.
+    """
+    wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
+
+    with disable_proxy_modes_tracing():
+        gm = make_fx(
+            wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode
+        )(all_args)
+    return gm
+
+
+def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
+    """A helper function for setting up unbacked_bindings on the destination FX graph."""
+    from .symbolic_shapes import compute_unbacked_bindings
+
+    # Can't use detect_fake_mode here,
+    #
+    # python test/distributed/_tensor/test_dtensor_compile.py -k
+    # test_tp_compile_fullgraph_is_seq_parallel_False
+    #
+    # will fail.  Very strange, it probably isn't right for them to be using
+    # two fake modes there...
+    fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
+    if fake_mode and fake_mode.shape_env:
+        if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
+            assert isinstance(out_proxy, Proxy), out_proxy
+            out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/recording.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/recording.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ec092898cd69d74362acbe57a029b09d9b23bee
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/recording.py
@@ -0,0 +1,530 @@
+# mypy: allow-untyped-defs
+import functools
+import inspect
+import itertools
+import logging
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.utils._pytree as pytree
+
+
+log = logging.getLogger(__name__)
+trace_shape_events_log = torch._logging.getArtifactLogger(
+    __name__, "trace_shape_events"
+)
+
+
+__all__ = [
+    "ShapeEnvEvent",
+    "record_shapeenv_event",
+    "replay_shape_env_events",
+    "FakeTensorMeta",
+    "shape_env_check_state_equal",
+    "NotEqualError",
+]
+
+# [Note: Recording ShapeEnv Events]
+# =================================
+#
+# What is a ShapeEnv event?
+# -------------------------
+# We consider a ShapeEnv event every function call (ShapeEnv method or
+# independent function) that modifies the state of the ShapeEnv instance.
+# Such calls are recorded alongside their positional and keyword arguments,
+# so that it may be replayed over a different ShapeEnv instance.
+#
+# See [Note: ShapeEnv State Equality] for what is considered the state
+# of a ShapeEnv instance.
+#
+# What is it for?
+# ---------------
+# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
+# arbitrary state in time.
+#
+# Being able to arbitrarily replay events like so is useful, mainly for
+# translation validation bisection. i.e. if a ValidationException has been
+# raised, find the earliest point in time where the translation validation
+# fails.
+#
+# Besides that, it also allows us to inspect the given instance and,
+# for example, check the guards that would actually be issued at that point.
+#
+# What kind of arguments can be stored in an event?
+# -------------------------------------------------
+# There's no specific rule for what cannot be used as an argument.
+# That said, pay special attention to the following cases:
+#
+#   1. Tensor inputs: there are some tests that check whether the inputs
+#      were garbage collected after execution. These will fail if there's
+#      an event that is holding a reference to those inputs.
+#
+#   2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
+#      will be automatically replaced by the new given ShapeEnv instance.
+#
+#   3. SymTypes arguments: they also hold references to ShapeEnv. So,
+#      whenever we see them, we create a new instance, replacing the
+#      ShapeEnv reference.
+#
+#   4. FX nodes: specifically, FX nodes from the FX graph for symbolic
+#      shapes. That argument must be replaced when replaying the event at
+#      ShapeEnvEvent.run, since it has to reference a node from the given
+#      instance, and not from the recorded instance.
+
+
+# Event class for reconstructing ShapeEnv at arbitrary time.
+#
+# Represents a method call that mutates ShapeEnv in a way that affects the
+# issued guards, when ShapeEnv.produce_guards is called.
+@dataclass
+class ShapeEnvEvent:
+    # ShapeEnv method.
+    f: Callable
+
+    # Arguments and keyword arguments called with.
+    args: Optional[list[Any]] = None
+    kwargs: Optional[dict[str, Any]] = None
+
+    # List of tracked_fakes at the time the method was called.
+    tracked_fakes: Optional[list[Any]] = None
+
+    # Name of the captured event.
+    # Used for special handling of particular methods.
+    name: Optional[str] = None
+
+    # Replay itself, but using shape_env as self.
+    def run(self, shape_env=None) -> Any:
+        from torch.fx.experimental.symbolic_shapes import (
+            is_symbolic,
+            ShapeEnv,
+            SymTypes,
+        )
+
+        # Special handling for the constructor event.
+        if self.f is ShapeEnv:
+            assert shape_env is None and self.args is None and self.kwargs is not None
+            return ShapeEnv(**self.kwargs)
+
+        assert shape_env is not None
+        args = list(self.args or [])
+        kwargs = dict(self.kwargs or {})
+
+        # Replace any argument of type ShapeEnv by the given one.
+        args, kwargs = pytree.tree_map_only(
+            ShapeEnv, lambda _: shape_env, (args, kwargs)
+        )
+
+        # Replace any argument of type SymTypes by a new instance,
+        # replacing its ShapeEnv reference.
+        args, kwargs = pytree.tree_map_only(
+            lambda x: isinstance(x, SymTypes) and is_symbolic(x),
+            lambda a: type(a)(a.node.with_shape_env(shape_env)),
+            (args, kwargs),
+        )
+
+        # Converts FX nodes using the mapping argument.
+        def maybe_convert_node(x: Any) -> Any:
+            if not isinstance(x, torch.fx.Node):
+                # Don't do anything to x if it's not an FX node.
+                return x
+
+            # If, at some point, we created an FX node, it means that translation validation is on.
+            # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
+            # we are tracking node names at shape_env.name_to_node.
+            assert hasattr(shape_env, "name_to_node")
+            name_to_node = shape_env.name_to_node  # type: ignore[attr-defined]
+            assert x.name in name_to_node
+            return name_to_node[x.name]
+
+        # Replaces the value of an specific argument by the result of fn.
+        def replacearg(index: int, key: str, fn: Callable):
+            if index < len(args):
+                args[index] = fn(args[index])
+            if key in kwargs:
+                kwargs[key] = fn(kwargs[key])
+
+        if self.is_create_fx_call_function():
+            # ShapeEnv.create_fx_call_function:
+            # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
+            # They must be replaced, since a "call_function" FX node with this tuple as argument
+            # will be added to the FX graph of the new shape_env.
+            replacearg(
+                index=2,
+                key="args",
+                fn=lambda args: tuple(maybe_convert_node(a) for a in args),
+            )
+        if self.is_evaluate_expr() or self.is_defer_runtime_assert():
+            # ShapeEnv.evaluate_expr and ShapeEnv.guard_or_defer_runtime_assert:
+            # "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
+            # They must be replaced, since it will be part of a "call_function" FX node for
+            # torch._assert, which will be added to the FX graph of the new shape_env.
+            replacearg(index=3, key="fx_node", fn=maybe_convert_node)
+
+        # Actually call the method with the converted arguments.
+        return self.f(*args, **kwargs)
+
+    def __str__(self) -> str:
+        name = self.name if self.name is not None else self.f.__name__
+        return f"event: {name} ({self.args}, {self.kwargs})"
+
+    def is_create_fx_call_function(self) -> bool:
+        return self.name == "_create_fx_call_function"
+
+    def is_evaluate_expr(self) -> bool:
+        return self.name == "evaluate_expr"
+
+    def is_defer_runtime_assert(self) -> bool:
+        return self.name == "guard_or_defer_runtime_assert"
+
+
+NEST = 0
+
+
+# Extracts a ShapeEnv instance inside args and kwargs.
+# Specifically, it looks for:
+#   1. ShapeEnv arguments
+#   2. SymInt, SymFloat, or SymBool arguments
+# If we find more than one object of any of the above types, we
+# also check that the ShapeEnv instance is the same for all of them.
+def _extract_shape_env_and_assert_equal(args, kwargs):
+    from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
+
+    def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
+        if old is not None:
+            assert old is new, "call with different ShapeEnv"
+        return new
+
+    shape_env = None
+    for val in itertools.chain(args, kwargs.values()):
+        if isinstance(val, ShapeEnv):
+            shape_env = assert_equal(shape_env, val)
+        if isinstance(val, SymTypes) and is_symbolic(val):
+            shape_env = assert_equal(shape_env, val.node.shape_env)
+
+    return shape_env
+
+
+# Decorator for recording the given function as a replayable event.
+#
+# This decorator should be used at every function that mutates the state of
+# ShapeEnv in some way that affects the resulting issued guards (i.e. when
+# ShapeEnv.produce_guards is called).
+#
+# save_tracked_fakes: saves a snapshot of the TrackedFake list.
+# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
+#
+# name: the name of the function being recorded. Normally (and by default) this
+# is taken from the decorated function but can be set if you need to override
+# it.
+#
+# When to save the list of TrackedFake?
+# =====================================
+# We should save the list of TrackedFake whenever the translation validation
+# bisection may actually stop and call the produce_guards method at the moment
+# right after the recorded function was played. In other words, since the
+# bisection bisects through torch._assert calls, we should save in all methods
+# that adds a torch._assert call to the symbolic shapes FX graph.
+#
+# At the moment, there are 2 methods that save the list:
+#   - ShapeEnv.evaluate_expr
+#   - ShapeEnv.guard_or_defer_runtime_assert
+def record_shapeenv_event(
+    *, save_tracked_fakes: bool = False, name: Optional[str] = None
+) -> Callable:
+    def decorator(fn: Callable) -> Callable:
+        assert callable(fn)
+        args = inspect.getfullargspec(fn).args
+        assert args and args[0] == "self", (
+            "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
+            "code so that it calls into a method on ShapeEnv"
+        )
+        nonlocal name
+        if name is None:
+            name = fn.__name__
+
+        @functools.wraps(fn)
+        def wrapper(*args, **kwargs):
+            from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+            assert isinstance(args[0], ShapeEnv)
+
+            global NEST
+
+            trace_shape_events_log.debug(
+                "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
+            )
+            NEST += 1
+
+            def retlog(r):
+                trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
+                return r
+
+            shape_env = args[0]
+
+            try:
+                if not shape_env.should_record_events or shape_env.is_recording:  # type: ignore[has-type]
+                    # If ShapeEnv is already recording an event, call the wrapped
+                    # function directly.
+                    #
+                    # NB: here, we skip the check of whether all ShapeEnv instances
+                    # are equal, in favor of a faster dispatch.
+                    return retlog(fn(*args, **kwargs))
+
+                # Retrieve an instance of ShapeEnv.
+                # Assumption: the collection of args and kwargs may not reference
+                # different ShapeEnv instances.
+                self = _extract_shape_env_and_assert_equal(args, kwargs)
+
+                # If we are calling this function without any ShapeEnv instance
+                # alive in its arguments, we don't record and call the original.
+                if self is None:
+                    return retlog(fn(*args, **kwargs))
+
+                # Otherwise, start recording and call the function.
+                with self._recording():
+                    # Take a snapshot of the current tracked_fakes.
+                    tracked_fakes = (
+                        self._snapshot_tracked_fakes() if save_tracked_fakes else None
+                    )
+                    # Record the event for 'fn'.
+                    event = ShapeEnvEvent(
+                        fn,
+                        list(args),
+                        kwargs,
+                        tracked_fakes,
+                        name=name,
+                    )
+                    # Play the event on this ShapeEnv.
+                    # NB: It's important to put the event first, because running
+                    # the event can trigger internal events that must be ordered
+                    # after this event.  However, if an exception happens, we do
+                    # NOT want to have the event in the list, so pop it off from
+                    # the record if an error happened
+                    self.events.append(event)
+                    try:
+                        return retlog(event.run(self))
+                    except Exception:
+                        self.events.pop()
+                        raise
+
+            except Exception:
+                if not shape_env.should_record_events or shape_env.is_recording:
+                    # If ShapeEnv is disabled or already recording an event, re-raise the exception without logging.
+                    raise
+                log.error(  # noqa: G201
+                    "failed while running %s(*%s, **%s)",
+                    name,
+                    args[1:],
+                    kwargs,
+                    exc_info=log.isEnabledFor(logging.INFO),
+                )
+                raise
+
+            finally:
+                NEST -= 1
+
+        return wrapper
+
+    return decorator
+
+
+# Replays the ShapeEnvEvents list.
+# It assumes the first event is the constructor call.
+#
+# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv.
+def replay_shape_env_events(events):
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    constructor_event = events[0]
+    assert constructor_event.f == ShapeEnv
+
+    # Constructs the new ShapeEnv.
+    shape_env = constructor_event.run()
+
+    for event in events[1:]:
+        try:
+            # Actually replays each event.
+            # We need to call create_mapping_fn every time, since the node list might
+            # change after each event is replayed.
+            event.run(shape_env)
+        except Exception:
+            log.error("failed when running event: %s", event)
+            raise
+
+    return shape_env
+
+
+# FakeTensor metadata.
+# This is to be used in place of FakeTensor placeholders when calling
+# ShapeEnv.produce_guards.
+@dataclass
+class FakeTensorMeta:
+    tensor_size: tuple[Union[int, torch.SymInt], ...]
+    tensor_stride: tuple[Union[int, torch.SymInt], ...]
+    tensor_storage_offset: Union[int, torch.SymInt]
+    is_nested: bool
+
+    def size(self) -> tuple[Union[int, torch.SymInt], ...]:
+        return self.tensor_size
+
+    def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
+        return self.tensor_stride
+
+    def storage_offset(self) -> Union[int, torch.SymInt]:
+        return self.tensor_storage_offset
+
+    def dim(self) -> int:
+        return len(self.tensor_size)
+
+    @staticmethod
+    def from_fake(fake) -> "FakeTensorMeta":
+        return FakeTensorMeta(
+            fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
+        )
+
+
+# [Note: ShapeEnv State Equality]
+# ===============================
+#
+# What is considered ShapeEnv state?
+# ----------------------------------
+# We consider to be the state of a ShapeEnv instance everything that
+# is not in the inline tuple inside remove_nonstate_variables function.
+# That is: the fields within ShapeEnv that modify the flow of execution
+# of the program.
+#
+# So, for example: the replacements field might influence on how an
+# expression is simplified. That, in turn, may result in a guard being
+# statically known (i.e. not added).
+#
+# On the other hand, var_to_stack serves only changes what is printed
+# in the screen, i.e. used only for debugging purposes. Therefore, we
+# should not consider it when comparing states.
+#
+# What to do on NotEqualError?
+# ----------------------------
+# Here are a few possible causes for getting a NotEqualError raised:
+#
+#   1. New field that does not belong in the ShapeEnv state.
+#      For example: log field of type ShapeEnvLoggerAdapter. Different
+#      ShapeEnv instances will always have different ShapeEnvLoggerAdapter
+#      instances, i.e. equality comparison would fail.
+#      Solution: add it to the inlined tuple inside remove_nonstate_variables
+#      function inside check_equal method.
+#
+#   2. New field that is not directly comparable across instances.
+#      For example: guards field of type List[ShapeGuard]. More specifically,
+#      the ShapeGuard type holds an expression and a stack information
+#      for debugging purposes. When replaying the even on a new ShapeEnv
+#      instance, the stack would be different, which would trigger this error.
+#      Solution: add a special case to the map_value function inside
+#      check_equal function.
+#
+#   3. Mutation of ShapeEnv on some not recorded function.
+#      If a mutation of the state of ShapeEnv happens inside a function
+#      that is not recorded (or that no caller in the stack is recorded),
+#      then, the replayed ShapeEnv won't catch that.
+#      Solution: decorate the function with record_shape_env_event.
+
+
+# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
+# returned by ShapeEnv.produce_guards.
+def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
+    # Collect and remove variables that don't necessarily represent the state
+    # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
+    # instance itself.
+    env1_vars = vars(env1).copy()
+    env2_vars = vars(env2).copy()
+
+    for v in non_state_variable_names:
+        if v in env1_vars:
+            env1_vars.pop(v)
+        if v in env2_vars:
+            env2_vars.pop(v)
+
+    # Function for transforming the mismatched values into string.
+    # Needed, since dict and set entries order might not be the same every time.
+    def value_to_str(value: Any) -> str:
+        if isinstance(value, dict):
+            return (
+                "{"
+                + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
+                + "}"
+            )
+        if isinstance(value, set):
+            return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
+        return str(value)
+
+    # Compares env1_vars with env2_vars.
+    # Here, we allow the value of each field to be mapped, so that we appropriately
+    # compare the two values.
+    def compare_vars(
+        map_value: Callable[[str, Any], Any],
+    ) -> list[tuple[str, str, str]]:
+        env1_set, env2_set = set(env1_vars), set(env2_vars)
+
+        # First, compare the set of keys in each vars dictionary.
+        if env1_set != env2_set:
+            raise NotEqualError(
+                "field set mismatch:",
+                [
+                    (
+                        "found unique fields:",
+                        str(sorted(env1_set - env2_set)),
+                        str(sorted(env2_set - env1_set)),
+                    ),
+                ],
+            )
+
+        # Then, sort the keys, and compare the mapped values of each key.
+        sorted_keys = list(env1_set)
+        sorted_keys.sort()
+
+        mapped_dict = [
+            (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
+            for k in sorted_keys
+        ]
+
+        # Return a list of tuples representing the fields that did not match
+        # alongside their respective mapped values.
+        return [
+            (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
+            for k, val1, val2 in mapped_dict
+            if val1 != val2
+        ]
+
+    # Accumulate the mismatching fields.
+    errors = compare_vars(map_value)
+
+    if len(errors) > 0:
+        raise NotEqualError("field values don't match:", errors)
+
+
+class NotEqualError(Exception):
+    def __init__(
+        self,
+        msg: str,
+        mismatched: list[tuple[str, str, str]],
+    ) -> None:
+        details = "\n".join(
+            [
+                "\n".join(
+                    [
+                        f"==> {inner_msg}",
+                        f"  >  Left: {str1}",
+                        f"  > Right: {str2}",
+                    ]
+                )
+                for inner_msg, str1, str2 in mismatched
+            ]
+        )
+
+        super().__init__(
+            f"""\
+ShapeEnv not equal: {msg}
+
+{details}
+"""
+        )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e92163a2139caab2fd2a690d810f52073e75644
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py
@@ -0,0 +1,16 @@
+class Equality:
+    def __init__(self, lhs: object, rhs: object):
+        self.lhs = lhs
+        self.rhs = rhs
+
+    def __str__(self) -> str:
+        return f"{self.lhs} = {self.rhs}"
+
+    def __repr__(self) -> str:
+        return f"{self.lhs} = {self.rhs}"
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Equality):
+            return self.lhs == other.lhs and self.rhs == other.rhs
+        else:
+            return False
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cc902599aeb0a36d8253b0cf8cbece3f6e5ac68
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py
@@ -0,0 +1,144 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import ast
+import copy
+import functools
+import inspect
+import textwrap
+from collections.abc import Callable
+from types import FunctionType
+from typing import Any, cast, Optional, Union
+
+import torch
+from torch._sources import normalize_source_lines
+from torch.fx._symbolic_trace import Tracer
+from torch.fx.graph import Graph
+
+
+class AST_Rewriter(ast.NodeTransformer):
+    """
+    Take a FunctionType object representing a `forward` method, then
+    perform an AST rewrite to swap out nodes that are not symbolically
+    traceable with a callsite to the FX alternative.
+
+    To support swapping out an AST node, define a new `visit` method on
+    that node. For more details, see:
+    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
+    """
+
+    # This function checks for new keys added in the globals dict. TorchDynamo
+    # can insert new keys in the global dict and upset the check. Therefore, put
+    # a disable here. This function is an optimization pass and not really
+    # suitable for dynamo tracing anyways.
+    @torch._dynamo.disable
+    def rewrite(self, fn: FunctionType):
+        # Normalize the source lines
+        sourcelines, _ = inspect.getsourcelines(fn)
+        sourcelines = normalize_source_lines(sourcelines)
+        source = "".join(sourcelines)
+        normalized_str = textwrap.dedent(source)
+
+        # Rewrite the original AST
+        source_ast = ast.parse(normalized_str)
+        dest_ast = ast.fix_missing_locations(self.visit(source_ast))
+
+        # Pull out the compiled function from the newly-created Module
+        code = compile(dest_ast, "", "exec")
+        globals_dict = copy.copy(fn.__globals__)
+        keys_before = set(globals_dict.keys())
+        exec(code, globals_dict)
+        new_keys = list(set(globals_dict.keys()) - keys_before)
+        assert len(new_keys) == 1
+        fn_compiled = globals_dict[new_keys[0]]
+
+        # return the compiled function with the original globals
+        def change_func_globals(f, globals):
+            """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
+            # __globals__ is a private member of the function class
+            # so we have to copy the function, f, all of its member, except f.__globals__
+            g = FunctionType(
+                f.__code__,
+                globals,
+                name=f.__name__,
+                argdefs=f.__defaults__,
+                closure=f.__closure__,
+            )
+            g = functools.update_wrapper(g, f)
+            g.__kwdefaults__ = copy.copy(f.__kwdefaults__)  # type:ignore[attr-defined]
+            return g
+
+        # Return the correct FunctionType object
+        return change_func_globals(fn_compiled, globals=fn.__globals__)
+
+    def visit_Assert(self, node):
+        """
+        Swap out the Assert node (Python's `assert`) with a callsite to the
+        symbolically-traceable torch._assert function
+        """
+        # Create the Call node
+        n = ast.parse("torch._assert()", mode="eval")
+        assert isinstance(n, ast.Expression)
+        call_node = n.body
+        assert isinstance(call_node, ast.Call)
+        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
+        call_node.args = [node.test, msg]
+
+        # Ensure that the new node conforms to the Python AST grammar
+        expr_wrapper = ast.Expr(value=call_node)
+
+        # Return the new Call node to signify that we want to use it as
+        # a replacement for the original _assert node
+        return ast.copy_location(expr_wrapper, node)
+
+    def visit_AnnAssign(self, node):
+        """
+        Swap out Python's AnnAssign with an Assign node where the annotation function is called.
+        Example:
+             Original:
+             y: Tensor_Type(1,2,3, Dyn) = f2(x)
+            Output:
+             y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
+        """
+        return ast.Assign(
+            targets=[node.target],
+            value=ast.Call(
+                func=ast.Name(id="annotate", ctx=ast.Load()),
+                args=[node.value, node.annotation],
+                keywords=[],
+            ),
+        )
+
+
+class RewritingTracer(Tracer):
+    def trace(
+        self,
+        root: Union[torch.nn.Module, Callable],
+        concrete_args: Optional[dict[str, Any]] = None,
+    ) -> Graph:
+        return super().trace(_rewrite(root), concrete_args)
+
+
+def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
+    if isinstance(fn, torch.nn.Module):
+        # Rewrite this module's `forward` as well as the `forward`s of
+        # all of this module's recursive descendents. Return the new,
+        # rewritten module hierarchy.
+        def rewrite_module(m: torch.nn.Module):
+            class RewrittenModule(torch.nn.Module):
+                def __init__(self, orig):
+                    super().__init__()
+                    for k, v in orig.__dict__.items():
+                        if isinstance(v, torch.nn.Module):
+                            self.__dict__[k] = copy.copy(rewrite_module(v))
+                        else:
+                            self.__dict__[k] = copy.copy(v)
+
+            RewrittenModule.forward = AST_Rewriter().rewrite(
+                cast(FunctionType, m.forward)
+            )
+            return RewrittenModule(m)
+
+        return rewrite_module(fn)
+    else:
+        # Rewrite this single free function
+        return AST_Rewriter().rewrite(cast(FunctionType, fn))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1b2f1680d64a1ff928a8519dd4d93d61a861a54
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py
@@ -0,0 +1,145 @@
+# mypy: allow-untyped-defs
+import inspect
+from typing import Any, Optional
+
+import torch
+import torch.fx
+from torch._jit_internal import boolean_dispatched
+from torch.fx import Transformer
+from torch.fx.node import Argument, Target
+from torch.fx.operator_schemas import _torchscript_type_to_python_type
+
+
+class AnnotateTypesWithSchema(Transformer):
+    """
+    Use Python function signatures to annotate types for `Nodes` within an FX graph.
+    This pulls out Python function signatures for:
+
+        1. Standard `torch.nn` Module calls
+        2. `torch.nn.functional` calls
+        3. Attribute fetches via `get_attr`
+
+    Example usage:
+
+        m = torchvision.models.resnet18()
+
+        traced = torch.fx.symbolic_trace(m)
+
+        traced = AnnotateTypesWithSchema(traced).transform()
+
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        annotate_functionals: bool = True,
+        annotate_modules: bool = True,
+        annotate_get_attrs: bool = True,
+    ):
+        super().__init__(module)
+        self.annotate_functionals = annotate_functionals
+        self.annotate_modules = annotate_modules
+        self.annotate_get_attrs = annotate_get_attrs
+
+    def call_function(
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+    ):
+        python_ret_type = None
+        if self.annotate_functionals and target.__module__ == "torch.nn.functional":
+            target_for_analysis = target
+            if target in boolean_dispatched:
+                # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
+                # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
+                # branches of the dispatch have exactly the same signature. If they do, use the `true`
+                # branch signature for analysis. Otherwise, leave this un-normalized
+                assert not isinstance(target, str)
+                dispatched = boolean_dispatched[target]
+                if_true, if_false = dispatched["if_true"], dispatched["if_false"]
+                # TODO: can we emit the union of these? What are the implications on TorchScript
+                # compilation?
+                if (
+                    inspect.signature(if_true).return_annotation
+                    != inspect.signature(if_false).return_annotation
+                ):
+                    return super().call_function(target, args, kwargs)
+                target_for_analysis = if_true
+
+            python_ret_type = self._extract_python_return_type(target_for_analysis)
+
+        return_proxy = super().call_function(target, args, kwargs)
+        return_proxy.node.type = (
+            return_proxy.node.type if return_proxy.node.type else python_ret_type
+        )
+        return return_proxy
+
+    def call_module(
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+    ):
+        python_ret_type = None
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+        if self.annotate_modules and hasattr(submod.__class__, "__name__"):
+            classname = submod.__class__.__name__
+            if getattr(torch.nn, classname, None) == submod.__class__:
+                python_ret_type = self._extract_python_return_type(submod.forward)
+        return_proxy = super().call_module(target, args, kwargs)
+        return_proxy.node.type = (
+            return_proxy.node.type if return_proxy.node.type else python_ret_type
+        )
+        return return_proxy
+
+    def get_attr(
+        self,
+        target: torch.fx.node.Target,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Any],
+    ):
+        attr_proxy = super().get_attr(target, args, kwargs)
+
+        if self.annotate_get_attrs:
+            module_itr = self.module
+            assert isinstance(target, str)
+            atoms = target.split(".")
+            for i, atom in enumerate(atoms):
+                if not hasattr(module_itr, atom):
+                    raise RuntimeError(
+                        f"Node referenced nonextent target {'.'.join(atoms[:i])}!"
+                    )
+                module_itr = getattr(module_itr, atom)
+
+            maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
+            if maybe_inferred_ts_type.success():
+                python_type = _torchscript_type_to_python_type(
+                    maybe_inferred_ts_type.type()
+                )
+                attr_proxy.node.type = (
+                    python_type if not attr_proxy.node.type else attr_proxy.node.type
+                )
+
+        return attr_proxy
+
+    def _extract_python_return_type(self, target: Target) -> Optional[Any]:
+        """
+        Given a Python call target, try to extract the Python return annotation
+        if it is available, otherwise return None
+
+        Args:
+
+            target (Callable): Python callable to get return annotation for
+
+        Returns:
+
+            Optional[Any]: Return annotation from the `target`, or None if it was
+                not available.
+        """
+        assert callable(target)
+        try:
+            sig = inspect.signature(target)
+        except (ValueError, TypeError):
+            return None
+
+        return (
+            sig.return_annotation
+            if sig.return_annotation is not inspect.Signature.empty
+            else None
+        )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..96b44b0aebd4d34eb9dc00fa3bfc0e133fe609bc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py
@@ -0,0 +1,1896 @@
+# mypy: allow-untyped-defs
+
+from __future__ import annotations
+
+
+"""
+This file does three things:
+- Contains the definition of SymNode
+- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time
+- Does not depend on sympy at import time
+
+As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
+to avoid having to load SymPy at import time, as doing so is *very* slow.
+"""
+
+
+import builtins
+import functools
+import inspect
+import itertools
+import logging
+import math
+import operator
+import sys
+from functools import lru_cache, update_wrapper
+from typing import Optional, TYPE_CHECKING, Union
+
+import torch
+import torch._logging.structured as structured
+
+# NB: The sym_* functions are used via getattr() and must be imported here.
+from torch import (  # noqa: F401
+    sym_float,
+    sym_ite,
+    sym_max,
+    sym_min,
+    sym_not,
+    SymBool,
+    SymFloat,
+    SymInt,
+)
+from torch._logging import dtrace_structured
+
+
+if TYPE_CHECKING:
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+log = logging.getLogger(__name__)
+sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
+
+
+__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"]
+
+
+from torch.types import py_sym_types as SymTypes
+
+
+def _to_symtype(t):
+    if t is bool:
+        return SymBool
+    if t is int:
+        return SymInt
+    if t is float:
+        return SymFloat
+    return t
+
+
+# TODO: An incomplete list
+# 1. Set variables to be equal when we do equality
+# 2. Specialize on 0/1 when we do subtraction
+class SymNode:
+    """
+    This is a type erased SymInt/SymFloat which we use to do actual operations.
+    End users don't touch this.  Magic methods are NOT defined on this object.
+    """
+
+    # Note [optimized_summation]: indicates that SymNode is an Add expression of the form
+    # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations
+    # for common patterns see _optimized_add.
+
+    # The unfortunate reason we have this here is because sympy sets  __slots__ = () for add expression,
+    # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as
+    # a weak dictionary key either! So instead, we attach the attribute here to the SymNode.
+    _optimized_summation: bool = False
+
+    def __init__(
+        self,
+        expr,
+        shape_env,
+        pytype,
+        hint: Optional[Union[int, float, bool]],
+        constant=None,
+        fx_node=None,
+        optimized_summation=False,
+    ):
+        self._expr = expr
+        self.shape_env = shape_env
+        self.pytype = pytype
+        self._optimized_summation = optimized_summation
+
+        # What's the difference between hint and constant?
+        #
+        # - A constant is known to be invariant across invocations of the model;
+        #   it will always be this value.  We only really know this when we
+        #   encounter an honest-to-goodness literal (when wrapping it into
+        #   a SymNode, we set constant.)  Most of the time, constant is None
+        #
+        # - A hint is a *particular* value from the particular run we are
+        #   tracing, but it may vary the next time around.  It's useful to
+        #   keep this around, as if we need a concrete value from a SymNode,
+        #   we will return the hint and guard on the expression that produced
+        #   it giving the same hint next time around.  The hint is not
+        #   guaranteed to be set either: if you have an unbacked SymNode,
+        #   there won't be any hint; it was the result of some tensor-dependent
+        #   computation, but we don't know what it actually is because we
+        #   haven't actually run the tensor computation.
+        #
+        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
+        # in hopes that we've learned enough about the unbacked symints to
+        # discharge the hint; otherwise, you're likely to just error out.
+        #
+        # (A previous version of this system had some optimizations to only
+        # recompute when it was possible we had learned enough about the
+        # unbacked symint that a hint was now possible, but as we added more
+        # potential refinements to unbacked symints this got harder to keep
+        # in sync, so we've deleted it for now.)
+
+        def compute_hint():
+            from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
+
+            # This occasionally gets exercised by, e.g.,
+            # convert_shape_to_symint.  It's just a nicety so you don't HAVE
+            # to have a correct hint on hand when making a SymNode.
+            # Don't attempt to compute for unbacked, this can be quite
+            # expensive.
+            if has_free_unbacked_symbols(self.expr):
+                return None
+            hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
+            if hint is not None:
+                hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
+            return hint
+
+        if hint is not None:
+            assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
+                "Cannot create SymNode of type "
+                f"{pytype} with incompatible hint of type {type(hint)}"
+            )
+            if self.shape_env and self.shape_env._translation_validation_enabled:
+                # This is technically not TV, but this assert is expensive so
+                # let's only do it when we're already doing expensive things
+                computed_hint = compute_hint()
+                assert hint == computed_hint, (
+                    f"{hint} != {computed_hint} (for {self.expr})"
+                )
+        else:
+            hint = compute_hint()
+        self._hint = hint
+        self.constant: Optional[Union[int, float, bool]] = constant
+
+        # Record the FX node of the current node if we are doing translation
+        # validation. They will be used for building the input assertions for
+        # the translation validation problem.
+        tx_validation_en = (
+            self.shape_env and self.shape_env._translation_validation_enabled
+        )
+        self.fx_node = tx_validation_en and fx_node
+
+    def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
+        return SymNode(
+            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
+        )
+
+    def _value_eq(self, other: SymNode) -> bool:
+        # Purposely don't include the shape_env in the eq.
+        return (
+            self._expr == other._expr
+            and self.pytype == other.pytype
+            and self._hint == other._hint
+            and self.constant == other.constant
+            and self.fx_node == other.fx_node
+        )
+
+    def _value_hash(self) -> int:
+        # Purposely don't include the shape_env in the hash.
+        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
+
+    @property
+    def expr(self):
+        return self.shape_env.replace(self._expr)
+
+    @property
+    def hint(self):
+        return self._hint
+
+    def has_hint(self):
+        return self._hint is not None
+
+    def require_hint(self, fallback=None):
+        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+
+        if self._hint is None:
+            if fallback is not None:
+                # Say we have some expr like 2*u0 + s0
+                # The hint will be None, since the expr contains at least 1 unbacked.
+                # We will:
+                # - replace every backed free symbol with its corresponding hint
+                # - replace every unbacked free symbol with the fallback
+                # - regenerate the expression with those symbol replacements
+                # Note: this is not really complete either, since right now
+                # this logic does not take into account any value ranges
+                # for the unbacked symints, we may need to beef it up at some point.
+                unbacked_symbols = free_unbacked_symbols(self.expr)
+                replacements = {
+                    s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s]
+                    for s in self.expr.free_symbols
+                }
+                return self.expr.xreplace(replacements)
+            # NB: we expect this to raise
+            return self.shape_env.size_hint(self.expr)
+        return self._hint
+
+    def maybe_as_int(self):
+        if self.expr.is_number:
+            return int(self.expr)
+        else:
+            return None
+
+    # NB: This does conversions, not sure if this is good or not
+    def maybe_as_float(self):
+        import sympy
+
+        if isinstance(self.expr, sympy.Float):
+            return float(self.expr)
+        else:
+            return None
+
+    def maybe_as_bool(self):
+        import sympy
+
+        if self.expr is sympy.true:
+            return True
+        elif self.expr is sympy.false:
+            return False
+        else:
+            return None
+
+    def is_int(self):
+        return self.pytype is int
+
+    def is_float(self):
+        return self.pytype is float
+
+    def is_bool(self):
+        return self.pytype is bool
+
+    def is_nested_int(self):
+        # Unbacked SymInts cannot be nested int today
+        return (
+            self._hint is not None
+            and isinstance(self._hint, SymInt)
+            and self._hint.node.is_nested_int()
+        )
+
+    def wrap_int(self, num):
+        assert type(num) is int
+        import sympy
+
+        return SymNode(
+            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
+        )
+
+    def wrap_float(self, num):
+        assert type(num) is float
+        import sympy
+
+        return SymNode(
+            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
+        )
+
+    def wrap_bool(self, num):
+        assert type(num) is bool
+        import sympy
+
+        return SymNode(
+            sympy.true if num else sympy.false,
+            self.shape_env,
+            bool,
+            num,
+            constant=num,
+            fx_node=num,
+        )
+
+    def clone(self):
+        return self
+
+    def str(self):
+        return f"{self.expr}"
+
+    def __str__(self):
+        return self.str()
+
+    def __repr__(self):
+        rep = [
+            f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
+        ]
+        if self._hint is not None:
+            rep.append(f"hint={self._hint}")
+        if self.constant is not None:
+            rep.append(f"constant={self.constant}")
+        if self.fx_node is not None:
+            rep.append(f"fx_node={self.fx_node}")
+        return ", ".join(rep) + ")"
+
+    def _graph_repr(self) -> builtins.str:
+        # Representation used by GraphModule to create a pythonic version of a graph
+        return self.str()
+
+    # These methods call the metaprogrammed methods, they're hand written
+    # here so we get good stack traces
+    def abs(self) -> SymNode:
+        return self._abs()  # type: ignore[attr-defined]
+
+    def pos(self) -> SymNode:
+        return self._pos()  # type: ignore[attr-defined]
+
+    def round(self, ndigits=None) -> SymNode:
+        return self._round(ndigits)  # type: ignore[attr-defined]
+
+    def trunc(self) -> SymNode:
+        return self._trunc()  # type: ignore[attr-defined]
+
+    def add(self, other) -> SymNode:
+        return self._add(other)  # type: ignore[attr-defined]
+
+    def sub(self, other) -> SymNode:
+        return self._sub(other)  # type: ignore[attr-defined]
+
+    def mul(self, other) -> SymNode:
+        return self._mul(other)  # type: ignore[attr-defined]
+
+    def mod(self, other) -> SymNode:
+        return self._mod(other)  # type: ignore[attr-defined]
+
+    def float_pow(self, other) -> SymNode:
+        return self._float_pow(other)  # type: ignore[attr-defined]
+
+    def pow_by_natural(self, other) -> SymNode:
+        return self._pow_by_natural(other)  # type: ignore[attr-defined]
+
+    def and_(self, other) -> SymNode:
+        return self._and_(other)  # type: ignore[attr-defined]
+
+    def or_(self, other) -> SymNode:
+        return self._or_(other)  # type: ignore[attr-defined]
+
+    def float_truediv(self, other) -> SymNode:
+        return self._float_truediv(other)  # type: ignore[attr-defined]
+
+    def int_truediv(self, other) -> SymNode:
+        return self._int_truediv(other)  # type: ignore[attr-defined]
+
+    def int_floordiv(self, other) -> SymNode:
+        return self._int_floordiv(other)  # type: ignore[attr-defined]
+
+    def lshift(self, other) -> SymNode:
+        return self._lshift(other)  # type: ignore[attr-defined]
+
+    def rshift(self, other) -> SymNode:
+        return self._rshift(other)  # type: ignore[attr-defined]
+
+    def sym_not(self) -> SymNode:  # noqa: F811
+        return self._sym_not()  # type: ignore[attr-defined]
+
+    def eq(self, other) -> SymNode:
+        return self._eq(other)  # type: ignore[attr-defined]
+
+    def ne(self, other) -> SymNode:
+        return self._ne(other)  # type: ignore[attr-defined]
+
+    def gt(self, other) -> SymNode:
+        return self._gt(other)  # type: ignore[attr-defined]
+
+    def lt(self, other) -> SymNode:
+        return self._lt(other)  # type: ignore[attr-defined]
+
+    def le(self, other) -> SymNode:
+        return self._le(other)  # type: ignore[attr-defined]
+
+    def ge(self, other) -> SymNode:
+        return self._ge(other)  # type: ignore[attr-defined]
+
+    def floor(self) -> SymNode:
+        return self._floor()  # type: ignore[attr-defined]
+
+    def is_integer(self) -> SymNode:
+        return self._is_integer()  # type: ignore[attr-defined]
+
+    def sym_float(self) -> SymNode:  # noqa: F811
+        return self._sym_float()  # type: ignore[attr-defined]
+
+    def sym_int(self) -> SymNode:
+        return self._sym_int()  # type: ignore[attr-defined]
+
+    def ceil(self) -> SymNode:
+        return self._ceil()  # type: ignore[attr-defined]
+
+    def neg(self) -> SymNode:
+        return self._neg()  # type: ignore[attr-defined]
+
+    def sym_min(self, other) -> SymNode:  # noqa: F811
+        return self._sym_min(other)  # type: ignore[attr-defined]
+
+    def sym_max(self, other) -> SymNode:  # noqa: F811
+        return self._sym_max(other)  # type: ignore[attr-defined]
+
+    def sym_ite(self, then_val, else_val) -> SymNode:
+        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]
+
+    def is_contiguous(self, sizes, strides) -> SymNode:
+        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
+        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
+        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
+        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
+        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
+        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]
+
+    # Make C++ happy
+    def sym_or(self, other):
+        return self.or_(other)
+
+    def sym_and(self, other):
+        return self.and_(other)
+
+    # Integer bitwise ops
+    def bitwise_and(self, other):
+        return self._bitwise_and(other)  # type: ignore[attr-defined]
+
+    def bitwise_or(self, other):
+        return self._bitwise_or(other)  # type: ignore[attr-defined]
+
+    def bitwise_xor(self, other):
+        return self._bitwise_xor(other)  # type: ignore[attr-defined]
+
+    # There is no int_truediv available from C++
+    def truediv(self, other):
+        return self.float_truediv(other)
+
+    def floordiv(self, other) -> SymNode:
+        return self.int_floordiv(other)
+
+    # We didn't bind integer pow in C++
+    def pow(self, other):
+        return self.float_pow(other)
+
+    def is_non_overlapping_and_dense(self, sizes, strides):
+        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
+            to_node(self, 1)
+        )  # type: ignore[attr-defined]
+
+    def int_(self):
+        return self.guard_int("", 0)  # NB: uses Python backtrace
+
+    # This one is currently done by hand, but if we add other variadic
+    # functions consider factoring it out to be metaprogrammed too.  Note that
+    # some load bearing logic is directly in torch.sym_sum
+
+    def sym_sum(self, args) -> SymNode:
+        import sympy
+
+        # Inner impl
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
+        if get_proxy_mode():
+            return to_node(
+                self,
+                handle_sym_dispatch(
+                    torch.sym_sum,
+                    (tuple(wrap_node(a) for a in args),),
+                    {},
+                ),
+            )
+        exprs = [a.expr for a in args]
+        out = sympy.Add(*exprs)
+
+        size_hints = []
+        out_hint = None
+        for a in args:
+            if a.hint is None:
+                break
+            size_hints.append(a.hint)
+        else:
+            out_hint = sum(size_hints)
+
+        fx_node, _ = self.shape_env._create_fx_call_function(
+            torch.sym_sum, (tuple(a.fx_node for a in args),)
+        )
+
+        # NB: Only for integers!
+        return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)
+
+    def evaluate(self, size_oblivious=False):
+        return self.shape_env.evaluate_sym_node(self, size_oblivious)
+
+    # You can manually trigger a guard with this function
+    def guard_int(self, file, line):
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.evaluate()
+        try:
+            return int(r)
+        except Exception:
+            log.warning("Failed to convert to int: %s", r)
+            raise
+
+    def guard_float(self, file, line):
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.evaluate()
+        try:
+            return float(r)
+        except Exception:
+            log.warning("Failed to convert to float: %s", r)
+            raise
+
+    def guard_bool(self, file, line):
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.evaluate()
+        try:
+            return bool(r)
+        except Exception:
+            log.warning("Failed to convert to bool: %s", r)
+            raise
+
+    def expect_true(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+
+        if (
+            self.has_hint()
+            and not free_unbacked_symbols(self.expr)
+            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
+        ):
+            # OK to generate guards
+            return self.guard_bool(file, line)
+        # Generate a deferred runtime assert (this might actually end up doing
+        # a regular guard if we can!)
+        # TODO: file/line here is very important, because the assert has been
+        # deferred so you can't backtrace easily
+        return self.shape_env.guard_or_defer_runtime_assert(
+            self.expr, f"{file}:{line}", fx_node=self.fx_node
+        )
+
+    def statically_known_true(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import statically_known_true
+
+        assert self.is_bool()
+        return statically_known_true(SymBool(self))
+
+    def guard_size_oblivious(self, file, line):
+        """
+        Like guard_bool, but if we encounter unbacked symbols, if those symbols
+        are size-like, we will treat them as >= 2 for the purposes of the analysis.
+
+        This CHANGES the runtime semantics, but all size-oblivious sites have been
+        audited to ensure that the runtime semantics don't change in a material way.
+        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
+        an unbacked one size, or a tensor reporting as non-contiguous even if it's
+        contiguous if it would have been reported contiguous due to being empty.
+        """
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.evaluate(size_oblivious=True)
+        try:
+            return bool(r)
+        except Exception:
+            log.warning("Failed to convert to bool: %s", r)
+            raise
+
+    def guard_or_false(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import guard_or_false
+
+        assert self.is_bool()
+        return guard_or_false(SymBool(self))
+
+    def guard_or_true(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import guard_or_true
+
+        assert self.is_bool()
+        return guard_or_true(SymBool(self))
+
+    def bool_(self):
+        return self.guard_bool("", 0)
+
+    def is_symbolic(self):
+        return True
+
+    def nested_int(self):
+        return None
+
+    def is_constant(self):
+        return False
+
+
+class _DynamicScalar:
+    def __new__(cls, *args):
+        if cls is _DynamicScalar:
+            raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.")
+        return super().__new__(cls, *args)
+
+
+class DynamicInt(_DynamicScalar, int):
+    """
+    User API for marking dynamic integers in `torch.compile`.
+    Intended to be compatible with both compile and eager mode.
+
+    Example usage::
+
+        fn = torch.compile(f)
+        x = DynamicInt(4)
+        fn(x)  # compiles x as a dynamic integer input; returns f(4)
+    """
+
+    def __new__(cls, val):
+        assert isinstance(val, int)
+        obj = super().__new__(cls, int(val))
+        return obj
+
+    def __repr__(self):
+        return f"DynamicInt({self.real})"
+
+    def __floordiv__(self, other):  # // was casting to int without these overrides?
+        return DynamicInt(self.real // other)
+
+    def __rfloordiv__(self, other):
+        return DynamicInt(other // self.real)
+
+
+# TODO: this probably needs the sizes-strides eval functions
+METHOD_TO_OPERATOR = {
+    "pos": operator.pos,
+    "abs": operator.abs,
+    "add": operator.add,
+    "and": operator.and_,
+    "bitwise_and": operator.and_,
+    "ceil": math.ceil,
+    "eq": operator.eq,
+    "floor": math.floor,
+    "trunc": math.trunc,
+    "int_floordiv": operator.floordiv,
+    "ge": operator.ge,
+    "gt": operator.gt,
+    "is_integer": lambda x: x.is_integer(),
+    "le": operator.le,
+    "lshift": operator.lshift,
+    "lt": operator.lt,
+    "mod": operator.mod,
+    "mul": operator.mul,
+    "ne": operator.ne,
+    "neg": operator.neg,
+    "or": operator.or_,
+    "bitwise_or": operator.or_,
+    "bitwise_xor": operator.xor,
+    "float_pow": operator.pow,
+    "pow_by_natural": operator.pow,
+    "round": builtins.round,
+    "rshift": operator.rshift,
+    "sub": operator.sub,
+    "sym_float": sym_float,
+    "sym_ite": sym_ite,
+    "sym_max": sym_max,
+    "sym_min": sym_min,
+    "sym_not": sym_not,
+    "float_truediv": operator.truediv,
+    "int_truediv": operator.truediv,
+}
+
+unary_magic_methods = {
+    "abs",
+    "sym_float",
+    "sym_int",
+    "ceil",
+    "floor",
+    "neg",
+    "sym_not",
+    "pos",
+    "trunc",
+}
+
+
+# Adding math ops: sqrt, cos, sin, ...
+def _get_sym_node_fn(name):
+    def fn(self):
+        return getattr(self, f"_sym_{name}")()
+
+    return fn
+
+
+math_op_names = (
+    "sqrt",
+    "cos",
+    "cosh",
+    "sin",
+    "sinh",
+    "tan",
+    "tanh",
+    "asin",
+    "acos",
+    "atan",
+    "log2",
+)
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    priv_sym_name = f"_{sym_name}"
+    setattr(SymNode, sym_name, _get_sym_node_fn(name))
+    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
+    unary_magic_methods.add(sym_name)
+    __all__.append(sym_name)
+
+
+# Unary methods that are not magic methods
+unary_nonmagic_methods = {
+    "is_integer",
+}
+
+unary_methods = unary_magic_methods | unary_nonmagic_methods
+
+# Most methods are only registered on SymInt and SymFloat
+# Some methods are only be registered on SymBool
+only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
+# Methods that implicitly convert SymBool into SymInt
+bool_becomes_int_magic_methods = {"add", "sub", "mul"}
+# Methods that are also on SymBool, in addition to on SymInt and SymFloat
+also_bool_magic_methods = {"eq"}
+bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
+
+# Methods that are only for float
+only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
+
+
+magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
+# remap necessary because an op name can have a bitwise and boolean implementation
+bitwise_ops = {"bitwise_and": "and", "bitwise_or": "or", "bitwise_xor": "xor"}
+
+
+always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
+
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    always_float_magic_methods.add(sym_name)
+
+
+always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
+always_bool_magic_methods = {
+    "eq",
+    "ne",
+    "gt",
+    "lt",
+    "le",
+    "ge",
+    "and",
+    "or",
+    "sym_not",
+    "is_non_overlapping_and_dense",
+    "is_integer",
+}
+
+# Methods that have a `__foo__` as well as `__rfoo__`
+
+
+def _sympy_float_truediv(a, b):
+    from torch.utils._sympy.functions import FloatTrueDiv
+
+    return FloatTrueDiv(a, b)
+
+
+def _sympy_int_truediv(a, b):
+    from torch.utils._sympy.functions import IntTrueDiv
+
+    return IntTrueDiv(a, b)
+
+
+def _sympy_floordiv(a, b):
+    from torch.utils._sympy.functions import FloorDiv
+
+    return FloorDiv(a, b)
+
+
+def _sympy_mod(a, b):
+    from torch.utils._sympy.functions import Mod, PythonMod
+
+    if a.is_nonnegative and b.is_nonnegative:
+        return Mod(a, b)
+    else:
+        return PythonMod(a, b)
+
+
+def _sympy_pow_by_natural(a, b):
+    from torch.utils._sympy.functions import PowByNatural
+
+    return PowByNatural(a, b)
+
+
+def _sympy_float_pow(a, b):
+    from torch.utils._sympy.functions import FloatPow
+
+    return FloatPow(a, b)
+
+
+def _sympy_and(a, b):
+    import sympy
+
+    return sympy.And(a, b)
+
+
+def _sympy_or(a, b):
+    import sympy
+
+    return sympy.Or(a, b)
+
+
+def _sympy_lshift(a, b):
+    from torch.utils._sympy.functions import LShift
+
+    return LShift(a, b)
+
+
+def _sympy_rshift(a, b):
+    from torch.utils._sympy.functions import RShift
+
+    return RShift(a, b)
+
+
+def _binary_search_insert_arg(ordered_args, new_arg):
+    """
+    If new_arg is found in ordered_args None is returned, else the new
+    ordered_args with new_arg inserted
+    """
+    if len(ordered_args) == 0:
+        return [new_arg]
+
+    from sympy.core.basic import _args_sortkey as sort_key, Basic
+
+    # Fast path when new_arg > ordered_args[-1].
+    if sort_key(ordered_args[-1]) < sort_key(new_arg):
+        return ordered_args + [new_arg]
+
+    # Fast path when new_arg < ordered_args[0].
+    if sort_key(ordered_args[0]) > sort_key(new_arg):
+        return [new_arg] + ordered_args
+
+    low, high = 0, len(ordered_args) - 1
+
+    while low <= high:
+        mid = (low + high) // 2
+        compare_result = Basic.compare(ordered_args[mid], new_arg)
+        if compare_result == 0:
+            return None
+        elif compare_result < 0:
+            low = mid + 1
+        else:
+            high = mid - 1
+
+    ordered_args.insert(low, new_arg)
+    return ordered_args
+
+
+def _optimized_add(
+    lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
+):
+    """
+    Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea
+    is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols,
+    and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following.
+    1. Avoid running other optimizations when the Add is constructed.
+    2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n)
+    (comparing terms is expensive and shows in the profiles).
+    The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols,
+    (2) the result sympy expression.
+    """
+    import sympy
+    from sympy.core.basic import _args_sortkey as sortkey
+
+    def make_optimized(ordered_args):
+        assert ordered_args is not None
+        result = sympy.Add(*ordered_args, evaluate=False)
+        return (True, result)
+
+    from torch.utils._sympy.functions import _is_symbols_binary_summation
+
+    lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
+    rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)
+
+    if lhs_is_optimized_summation and rhs_is_optimized_summation:
+        # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3)
+        if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
+            return make_optimized(lhs._args + rhs._args)
+        #  (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3)
+        if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
+            return make_optimized(rhs._args + lhs._args)
+
+        #  (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
+        if len(lhs._args) <= 2 and len(rhs._args) <= 2:
+            new_args = list(lhs._args)
+            for a in rhs._args:
+                new_args = _binary_search_insert_arg(new_args, a)
+                if new_args is None:
+                    break
+            # None means an element already exists.
+            if new_args is not None:
+                return make_optimized(new_args)
+
+    # (a0+a2) + a1 => (a0+a1+a2)
+    if lhs_is_optimized_summation and rhs.is_symbol:
+        new_args = _binary_search_insert_arg(list(lhs._args), rhs)
+        # None means an element already exists.
+        if new_args is not None:
+            return make_optimized(new_args)
+
+    # a1 + (a0+a2)=> (a0+a1+a2)
+    if rhs_is_optimized_summation and lhs.is_symbol:
+        new_args = _binary_search_insert_arg(list(rhs._args), lhs)
+        # None means an element already exists.
+        if new_args is not None:
+            return make_optimized(new_args)
+
+    result = sympy.Add(lhs, rhs)
+    return (_is_symbols_binary_summation(result), result)
+
+
+def _bitwise_and(a, b):
+    from torch.utils._sympy.functions import BitwiseFn_bitwise_and
+
+    return BitwiseFn_bitwise_and(a, b)
+
+
+def _bitwise_or(a, b):
+    from torch.utils._sympy.functions import BitwiseFn_bitwise_or
+
+    return BitwiseFn_bitwise_or(a, b)
+
+
+def _bitwise_xor(a, b):
+    from torch.utils._sympy.functions import BitwiseFn_bitwise_xor
+
+    return BitwiseFn_bitwise_xor(a, b)
+
+
+reflectable_magic_methods = {
+    "add": operator.add,
+    "sub": operator.sub,
+    "mul": operator.mul,
+    "mod": _sympy_mod,
+    "pow_by_natural": _sympy_pow_by_natural,
+    "float_pow": _sympy_float_pow,
+    "and": _sympy_and,
+    "bitwise_and": _bitwise_and,
+    "or": _sympy_or,
+    "bitwise_or": _bitwise_or,
+    "bitwise_xor": _bitwise_xor,
+    "float_truediv": _sympy_float_truediv,
+    "int_truediv": _sympy_int_truediv,
+    "int_floordiv": _sympy_floordiv,
+    "lshift": _sympy_lshift,
+    "rshift": _sympy_rshift,
+}
+
+
+def _floor_ceil_helper(a, fn):
+    import sympy
+
+    if isinstance(a, sympy.Mul):
+        aa = a.args
+        if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
+            coef = sympy.Integer(aa[0])
+            if aa[0] == coef:  # structural equality test
+                return coef * aa[1]
+    if (
+        isinstance(a, sympy.Float)
+        and a == sympy.Integer(a)
+        or isinstance(a, sympy.Integer)
+    ):
+        return sympy.Integer(a)
+    return fn(a)
+
+
+def _sympy_floor(a):
+    from torch.utils._sympy.functions import FloorToInt
+
+    return FloorToInt(a)
+
+
+# NB: this is Python trunc semantics which returns an int.  Do NOT use this to
+# represent torch.trunc (which is float to float)
+def _sympy_trunc(a):
+    from torch.utils._sympy.functions import TruncToInt
+
+    return TruncToInt(a)
+
+
+def _sympy_ceil(a):
+    from torch.utils._sympy.functions import CeilToInt
+
+    return CeilToInt(a)
+
+
+def _sympy_eq(a, b):
+    import sympy
+
+    return sympy.Eq(a, b)
+
+
+def _sympy_ne(a, b):
+    import sympy
+
+    return sympy.Ne(a, b)
+
+
+def _sympy_gt(a, b):
+    import sympy
+
+    return sympy.Gt(a, b)
+
+
+def _sympy_lt(a, b):
+    import sympy
+
+    return sympy.Lt(a, b)
+
+
+def _sympy_le(a, b):
+    import sympy
+
+    return sympy.Le(a, b)
+
+
+def _sympy_ge(a, b):
+    import sympy
+
+    return sympy.Ge(a, b)
+
+
+def _sympy_min(a, b):
+    from torch.utils._sympy.functions import Min
+
+    return Min(a, b)
+
+
+def _sympy_max(a, b):
+    from torch.utils._sympy.functions import Max
+
+    return Max(a, b)
+
+
+def _sympy_ite(a, t, f):
+    import sympy
+
+    return sympy.Piecewise((t, a), (f, True))
+
+
+current_module = sys.modules[__name__]
+
+
+def _get_sym_math_fn(name):
+    def fn(a):
+        import torch.utils._sympy.functions
+
+        return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)
+
+    return fn
+
+
+for name in math_op_names:
+    priv_sympy_name = f"_sympy_{name}"
+    fn = _get_sym_math_fn(name)
+    fn.__qualname__ = fn.__name__ = priv_sympy_name
+    setattr(current_module, priv_sympy_name, fn)
+
+del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
+
+
+def _sympy_abs(a):
+    import sympy
+
+    return sympy.Abs(a)
+
+
+def _sympy_round(number, ndigits=None):
+    from torch.utils._sympy.functions import RoundDecimal, RoundToInt
+
+    if ndigits is None:
+        return RoundToInt(number)
+    else:
+        return RoundDecimal(number, ndigits)
+
+
+def _sympy_sym_float(a):
+    from torch.utils._sympy.functions import ToFloat
+
+    # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
+    # reports that it is an integer
+    return ToFloat(a)
+
+
+def _sympy_is_integer(a):
+    import sympy
+
+    from torch.utils._sympy.functions import ToFloat
+
+    return sympy.Eq(ToFloat(sympy.floor(a)), a)
+
+
+magic_methods = {
+    **reflectable_magic_methods,
+    "sym_not": operator.invert,
+    "pos": operator.pos,
+    "eq": _sympy_eq,
+    "ne": _sympy_ne,
+    "gt": _sympy_gt,
+    "lt": _sympy_lt,
+    "le": _sympy_le,
+    "ge": _sympy_ge,
+    "floor": _sympy_floor,
+    "trunc": _sympy_trunc,
+    "sym_float": _sympy_sym_float,
+    "ceil": _sympy_ceil,
+    "neg": operator.neg,
+    "sym_min": _sympy_min,
+    "sym_max": _sympy_max,
+    "sym_ite": _sympy_ite,
+    "abs": _sympy_abs,
+    "round": _sympy_round,
+    "is_integer": _sympy_is_integer,
+}
+
+
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
+
+del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
+
+
+def sympy_is_contiguous(sizes, strides):
+    dim = len(sizes)
+    return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
+
+
+def sympy_is_contiguous_generic(sizes, strides, dim_order):
+    import sympy
+
+    dim = len(sizes)
+
+    if len(dim_order) != dim:
+        return sympy.false
+
+    is_contiguous = sympy.true
+    z = sympy.S.One
+    # Contiguous if the strides make sense (or the dim is size 1)
+    for d in dim_order:
+        is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z)
+        z *= sizes[d]
+    # OR if any size is zero
+    for d in range(dim):
+        is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero)
+    return is_contiguous
+
+
+# NB: There is a TODO in C++ to allow omitting the batch dim.  If that
+# happens you will need to refactor this
+
+
+def sympy_is_channels_last_contiguous_2d(sizes, strides):
+    return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0])
+
+
+def sympy_is_channels_last_contiguous_3d(sizes, strides):
+    return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0])
+
+
+def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
+    import sympy
+
+    from torch.utils._sympy.functions import Max
+
+    dim = len(sizes)
+
+    if dim != len(dim_order):
+        return sympy.false
+
+    m = sympy.S.Zero
+    r = sympy.true
+
+    # special case for trivial C dimension. default to NCHW
+    r &= sympy.Ne(strides[1], 0)
+
+    for d in dim_order:
+        r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
+        # Fallback to NCHW as default layout for ambiguous cases
+        # This is the flaw of implicit memory_format from strides.
+        # N111 tensor with identical strides for size 1 dimension;
+        # Two cases could lead us here:
+        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
+        # b. N11W contiguous Tensor sliced on the W-dimension.
+        # ([N,1,1,1]@[W,W,W,W])
+        if d == 0:
+            r &= sympy.Ne(m, strides[1])
+        # This is necessary to:
+        # 1. distinguish the memory_format of N1H1;
+        #     [H, 1, 1, 1] channels_last stride
+        #     [H, H, 1, 1] contiguous stride
+        # 2. permutation of 1C1W:
+        #     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
+        #     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
+        #     channels_last
+        m = strides[d] * Max(sizes[d], 1)
+
+    return r
+
+
+def sympy_is_channels_last_strides_2d(sizes, strides):
+    return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
+
+
+def sympy_is_channels_last_strides_3d(sizes, strides):
+    return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
+
+
+def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
+    from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
+
+    return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
+
+
+sizes_strides_methods = {
+    # TODO: These could also be done with indicators, maybe it is better
+    # for reasoning to do it that way
+    "is_contiguous": sympy_is_contiguous,
+    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
+    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
+    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
+    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
+    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
+}
+
+
+def to_node(self, num):
+    if isinstance(num, SymTypes):
+        return num.node
+    elif type(num) is bool:
+        return self.wrap_bool(num)
+    elif type(num) is int:
+        return self.wrap_int(num)
+    elif type(num) is float:
+        return self.wrap_float(num)
+    else:
+        # NotImplemented is important so that Python tries the
+        # other magic method
+        return NotImplemented
+
+
+def wrap_node(x):
+    # TODO: let C++ also take advantage of this
+    if isinstance(x, SymNode) and x.constant is not None:
+        return x.constant
+    if x.is_int():
+        return SymInt(x)
+    elif x.is_float():
+        return SymFloat(x)
+    elif x.is_bool():
+        return SymBool(x)
+    else:
+        raise AssertionError(f"unrecognized return type {x}")
+
+
+def method_to_operator(method):
+    return METHOD_TO_OPERATOR[method]
+
+
+def _make_node_magic(method, func):
+    func = lru_cache(256)(func)
+
+    if method in magic_methods_on_operator_with_trailing_underscore:
+        method_attr = f"{method}_"
+    else:
+        method_attr = method
+
+    def uninteresting_files() -> set[str]:
+        import torch
+
+        mods = [
+            torch._dynamo.eval_frame,
+            torch._dynamo.utils,
+            torch.fx.experimental.sym_node,
+            torch,
+        ]
+        import torch._dynamo.guards
+
+        return (
+            {inspect.getfile(m) for m in mods}
+            | torch._dynamo.guards.uninteresting_files()
+            | {""}
+        )
+
+    def capture_provenance(fn):
+        @functools.wraps(fn)
+        def wrapper(self, other=None):
+            if other is None:
+                result = fn(self)
+            else:
+                result = fn(self, other)
+            if torch._logging._internal.GET_DTRACE_STRUCTURED:
+                if other is not None:
+                    arguments = [self, other]
+                else:
+                    arguments = [self]
+
+                def get_id(sym_node) -> Optional[int]:
+                    # We don't want to return an ID if the input is a constant
+                    import sympy
+
+                    if sym_node.constant is not None:
+                        return None
+                    elif id(sym_node) == id(result):
+                        return None
+                    elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
+                        return None
+                    elif sym_node.expr in (sympy.true, sympy.false):
+                        return None
+                    return id(sym_node)
+
+                dtrace_structured(
+                    "expression_created",
+                    metadata_fn=lambda: {
+                        "method": method,
+                        "result": str(result),
+                        "result_id": id(result),
+                        "arguments": [str(a) for a in arguments],
+                        "argument_ids": [
+                            get_id(i) for i in arguments if get_id(i) is not None
+                        ],
+                        "user_stack": structured.get_user_stack(3),
+                        "stack": structured.get_framework_stack(3),
+                    },
+                )
+
+            return result
+
+        return wrapper
+
+    @capture_provenance
+    def binary_magic_impl(self, other):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
+        op = method_to_operator(method)
+
+        out_hint = None
+        if self.hint is not None and other.hint is not None:
+            out_hint = op(self.hint, other.hint)
+
+        if get_proxy_mode():
+            return to_node(
+                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
+            )
+        assert isinstance(other, SymNode)
+        optimized_summation = False
+        try:
+            if method == "mod":
+                from torch.utils._sympy.functions import Mod, PythonMod
+
+                # Special handling for mod that requires access to the value
+                # ranges
+                shape_env = self.shape_env
+                if (
+                    self.expr.is_nonnegative
+                    or shape_env.bound_sympy(self.expr).lower >= 0
+                ) and (
+                    other.expr.is_nonnegative
+                    or shape_env.bound_sympy(other.expr).lower >= 0
+                ):
+                    out = Mod(self.expr, other.expr)
+                else:
+                    out = PythonMod(self.expr, other.expr)
+            elif method == "add":
+                # see Note [optimized_summation]
+                (optimized_summation, out) = _optimized_add(
+                    self.expr,
+                    other.expr,
+                    self._optimized_summation,
+                    other._optimized_summation,
+                )
+            else:
+                # TODO: consider constant prop here
+                out = func(self.expr, other.expr)
+        except Exception:
+            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
+            raise
+        sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
+        pytype: type
+        # This is not strictly correct. In Python, a**b may return complex when
+        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
+        # returns a float while both arguments are ints: 2**(-1). Also, max and
+        # min do not type promote. To avoid having data-dependent control flow
+        # here, we just set the type to float if one of the args is a float. In
+        # case of a type mismatch, we assume that it will be detected during
+        # evaluation.
+        if method in always_float_magic_methods:
+            pytype = float
+        elif method in always_bool_magic_methods:
+            pytype = bool
+        elif self.pytype is float or other.pytype is float:
+            pytype = float
+        else:
+            pytype = self.pytype
+
+        if (
+            pytype is not None
+            and out_hint is not None
+            and not isinstance(out_hint, SymTypes)
+        ):
+            out_hint = pytype(out_hint)
+
+        # Create a FX node that corresponds to the operation being applied to
+        # this node.
+        fx_node, _ = self.shape_env._create_fx_call_function(
+            op, (self.fx_node, other.fx_node)
+        )
+
+        result = SymNode(
+            out,
+            self.shape_env,
+            pytype,
+            out_hint,  # type: ignore[arg-type]
+            fx_node=fx_node,
+            optimized_summation=optimized_summation,  # see Note [optimized_summation]
+        )
+        return result
+
+    @capture_provenance
+    def unary_magic_impl(self):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
+        op = method_to_operator(method)
+        if get_proxy_mode():
+            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
+        # TODO: consider constant prop here
+        expr = self.expr
+        if method == "floor" or method == "ceiling":
+            expr = self.shape_env._simplify_floor_div(expr)
+
+        try:
+            out = func(expr)
+        except Exception:
+            log.warning("failed to eval %s(%s)", method, expr)
+            raise
+        sym_node_log.debug("%s %s -> %s", func, expr, out)
+        out_hint = None
+        if self.hint is not None:
+            out_hint = op(self.hint)
+        pytype: type
+        if method in always_int_magic_methods:
+            pytype = int
+        elif method in always_bool_magic_methods:
+            pytype = bool
+        elif method in always_float_magic_methods:
+            pytype = float
+        else:
+            pytype = self.pytype
+
+        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
+        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
+
+    if method in unary_methods:
+        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
+    elif method == "sym_ite":
+
+        def sym_ite_impl(pred_node, then_node, else_node):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
+
+            out_hint = then_node.hint if pred_node.hint else else_node.hint
+            if get_proxy_mode():
+                return to_node(
+                    pred_node,
+                    handle_sym_dispatch(
+                        sym_ite,
+                        (
+                            wrap_node(pred_node),
+                            wrap_node(then_node),
+                            wrap_node(else_node),
+                        ),
+                        {},
+                    ),
+                )
+
+            try:
+                out = func(pred_node.expr, then_node.expr, else_node.expr)
+            except Exception:
+                log.warning(
+                    "failed to eval %s(%s, %s, %s)",
+                    method,
+                    pred_node.expr,
+                    then_node.expr,
+                    else_node.expr,
+                )
+                raise
+
+            fx_node, _ = pred_node.shape_env._create_fx_call_function(
+                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
+            )
+            return SymNode(
+                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
+            )
+
+        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
+    elif method == "round":
+
+        def round_impl(self, ndigits=None):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
+
+            op = builtins.round
+            if get_proxy_mode():
+                return to_node(
+                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
+                )
+
+            expr = self.expr
+            try:
+                out = func(expr, ndigits)
+            except Exception:
+                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
+                raise
+
+            if ndigits is None:
+                pytype = int
+            else:
+                pytype = self.pytype
+
+            out_hint = None
+            if self.hint is not None:
+                out_hint = op(self.hint, ndigits)
+
+            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
+            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
+            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
+            # hack down below works, because all round function down the line all take ndigits=None as default in their
+            # signature.
+            # TODO: Remove the args construction below if a different sentinel is used by FX.
+            # ezyang(May 2024): LOL
+            args = [self.fx_node]
+            if ndigits is not None:
+                args.append(ndigits)
+            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
+            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
+
+        setattr(SymNode, f"_{method_attr}", round_impl)
+    else:
+        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
+
+
+def _make_node_sizes_strides(method, func):
+    # NB: don't LRU cache, lots of arguments
+
+    def sizes_strides_impl(self, sizes, strides):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
+        op = getattr(sys.modules[__name__], method)
+        if get_proxy_mode():
+            return to_node(
+                self,
+                handle_sym_dispatch(
+                    op,
+                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
+                    {},
+                ),
+            )
+        size_exprs = [s.expr for s in sizes]
+        stride_exprs = [s.expr for s in strides]
+        try:
+            out = func(size_exprs, stride_exprs)
+        except Exception:
+            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
+            raise
+        # bool is never expandable
+
+        size_hints = []
+        out_hint = None
+        for s in sizes:
+            if s.hint is None:
+                break
+            size_hints.append(s.hint)
+        else:
+            stride_hints = []
+            for s in strides:
+                if s.hint is None:
+                    break
+                stride_hints.append(s.hint)
+            else:
+                out_hint = op(size_hints, stride_hints)
+
+        # NB: This is the indicator function, not the actual bool!
+        pytype: type
+        if method.endswith("_indicator"):
+            pytype = int
+        else:
+            pytype = bool
+        return SymNode(out, self.shape_env, pytype, out_hint)
+
+    setattr(SymNode, f"_{method}", sizes_strides_impl)
+
+    # TODO: This is technically hotpath, but in the ideal end state
+    # guards on this will resolve at a higher level so you never
+    # spend time in this code
+    def sizes_strides_user(sizes, strides):
+        import sympy
+
+        from torch.fx.experimental.symbolic_shapes import (
+            eval_is_non_overlapping_and_dense,
+        )
+
+        for a in itertools.chain(sizes, strides):
+            if isinstance(a, SymInt):
+                return wrap_node(
+                    getattr(a.node, method)(
+                        [to_node(a.node, b) for b in sizes],
+                        [to_node(a.node, b) for b in strides],
+                    )
+                )
+        if method == "is_non_overlapping_and_dense_indicator":
+            return eval_is_non_overlapping_and_dense(sizes, strides)
+        else:
+            # TODO: this is an awful implementation
+            return bool(
+                func(
+                    [sympy.sympify(a) for a in sizes],
+                    [sympy.sympify(a) for a in strides],
+                )
+            )
+
+    # Skip for is_non_overlapping_and_dense_indicator
+    if not hasattr(sys.modules[__name__], method):
+        setattr(sys.modules[__name__], method, sizes_strides_user)
+
+
+for method, func in magic_methods.items():
+    _make_node_magic(method, func)
+
+for method, func in sizes_strides_methods.items():
+    _make_node_sizes_strides(method, func)
+
+
+def _make_user_magic(method, user_type):
+    # User magic takes care of wrapping the other operand into a node,
+    # so that our internal logic can assume everything is nodes
+    if method in magic_methods_on_operator_with_trailing_underscore:
+        method_attr = f"sym_{method}"
+    else:
+        method_attr = method
+
+    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
+        if isinstance(x, (int, float, bool)):
+            return x
+        if isinstance(x, SymInt):
+            return x.node.guard_int("", 0)
+        if isinstance(x, SymBool):
+            return x.node.guard_bool("", 0)
+        raise AssertionError("expect to be called with constant SymBools")
+
+    def is_constant(x):
+        if isinstance(x, (int, float, bool)):
+            return True
+        if isinstance(x, (SymInt, SymFloat, SymBool)):
+            return x.node.is_constant()
+        return False
+
+    # Promotion rules for binary operations.  NB: we preserve PYTHON semantics
+    #   - if args are same type, do nothing
+    #   - if one arg is float, promote other arg to float
+    #       - nb: this applies to floordiv, even though output is integral
+    #       (it's still float)
+    #   - pow is funny business
+    #       - if both ints
+    #       - trigger a guard on exponent >= 0
+    #           - if non-negative, output is int
+    #           - otherwise, output is float
+    #   - otherwise, promote other arg to float
+    #       - nb: complex is impossible to handle correctly lol, with
+    #       negative base and integral float need to diverge semantics and
+    #       just always return complex.  Neener neener pretend this problem
+    #       doesn't exist
+    #   - equality is pain: Python does the fancy thing where it unpacks the
+    #     mantissa from the float and then compares that against the int.
+    #     Which means it is able to tell that
+    #     9007199254740993 != 9007199254740992. (rather than if the LHS was
+    #     promoted to float, in which case it would have truncated to the RHS
+    #     and subsequently been equal).  We'll model this exactly by having
+    #     special mixed type equality operations.  Unfortunately, we need to
+    #     do this for all comparison operations (maybe I'll only implement
+    #     compare)
+    #   - sym_ite mumble mumble really shouldn't allow mixed but whatever
+
+    if method in bool_becomes_int_magic_methods:
+
+        def promote(x):
+            """Implements True+True=2, which works in python but not sympy"""
+            if isinstance(x, SymBool):
+                return SymInt(x.node.wrap_int(int(x)))
+            return x
+
+    else:
+
+        def promote(x):
+            return x
+
+    def promote2(self, other):
+        # TODO: Remove eq and other relations from this list.
+        # CPython has fancy implementations for these to get as much precision
+        # as possible instead of just promoting to float64 and praying, so we
+        # need to handle them specially too.
+        # Also, note that int_truediv doesn't go through this path: both
+        # arguments are "int" so there isn't any promotion
+        if method not in [
+            "add",
+            "sub",
+            "mul",
+            "mod",
+            "float_pow",
+            "float_truediv",
+            "int_floordiv",
+            "sym_min",
+            "sym_max",
+            # TODO: remove these
+            "eq",
+            "ne",
+            "gt",
+            "lt",
+            "le",
+            "ge",
+        ]:
+            return self, other
+        f_self = isinstance(self, (float, torch.SymFloat))
+        f_other = isinstance(other, (float, torch.SymFloat))
+        if f_self or f_other:
+            if not f_self:
+                self = torch.sym_float(self)
+            if not f_other:
+                other = torch.sym_float(other)
+        return self, other
+
+    # Before and after performing the operation, check if any operands are constant.
+    # If so, extract out the constant values first. If `self` itself is a
+    # constant, then "redispatch" by calling back into the operator. Sometimes
+    # this means that operations involving SymBool return plain bools.
+    # Alternatively, we could also rewrap into constant Symbool (i.e. by
+    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
+    # today for no particular reason.
+    def unary_magic_impl(self):
+        self = promote(self)
+        if is_constant(self):
+            return (method_to_operator(method))(get_constant(self))
+        return wrap_node(getattr(self.node, method_attr)())
+
+    def binary_magic_impl(self, other):
+        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
+            return NotImplemented
+        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
+        self = promote(self)
+        other = promote(other)
+        self, other = promote2(self, other)
+        if is_constant(self):
+            return (method_to_operator(method))(get_constant(self), other)
+        if is_constant(other):
+            other = get_constant(other)
+        other_node = to_node(self.node, other)
+        if other_node is NotImplemented:
+            return NotImplemented
+        ret = wrap_node(getattr(self.node, method_attr)(other_node))
+        return get_constant(ret) if is_constant(ret) else ret
+
+    def rbinary_magic_impl(self, other):
+        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
+            return NotImplemented
+        self = promote(self)
+        other = promote(other)
+        self, other = promote2(self, other)
+        if is_constant(self):
+            return (method_to_operator(method))(other, get_constant(self))
+        if is_constant(other):
+            other = get_constant(other)
+        other_node = to_node(self.node, other)
+        if other_node is NotImplemented:
+            return NotImplemented
+        ret = wrap_node(getattr(other_node, method_attr)(self.node))
+        return get_constant(ret) if is_constant(ret) else ret
+
+    def setattrs(user_type, attr, symnode_impl):
+        """
+        Registers the SymNode magic method on SymInt/Float/Bool,
+        and optionally registers a corresponding wrapped method on DynamicInt.
+        """
+
+        # SymInt/Float/Bool
+        setattr(user_type, attr, symnode_impl)
+
+        # DynamicInt impl
+        def dynamic_int_impl(*args):
+            args = [x.real if isinstance(x, DynamicInt) else x for x in args]
+            out = getattr(int, attr)(*args)
+            if isinstance(out, int) and not isinstance(out, bool):
+                return DynamicInt(out)
+            return out
+
+        if user_type is SymInt:
+            setattr(DynamicInt, attr, dynamic_int_impl)
+
+    if method in unary_magic_methods:
+        setattrs(user_type, f"__{method}__", unary_magic_impl)
+    elif method in unary_nonmagic_methods:
+        orig = getattr(user_type, method)
+        setattrs(user_type, method, update_wrapper(unary_magic_impl, orig))
+    elif method == "sym_ite":
+
+        def sym_ite_magic_impl(pred, then_val, else_val):
+            pred_node = pred.node
+            then_node = to_node(pred_node, then_val)
+            else_node = to_node(pred_node, else_val)
+            if then_node is NotImplemented or else_node is NotImplemented:
+                return NotImplemented
+            assert (
+                isinstance(then_node, SymNode)
+                and isinstance(else_node, SymNode)
+                and then_node.pytype == else_node.pytype
+            )
+            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
+            return get_constant(ret) if ret.node.is_constant() else ret
+
+        setattrs(user_type, f"__{method}__", sym_ite_magic_impl)
+    elif method == "round":
+
+        def round_magic_impl(self, ndigits=None):
+            if is_constant(self):
+                return builtins.round(get_constant(self), ndigits)
+
+            return wrap_node(getattr(self.node, method)(ndigits))
+
+        setattrs(user_type, f"__{method}__", round_magic_impl)
+    else:
+        method_name = method
+        if method in bitwise_ops:
+            method_name = bitwise_ops[method]
+        setattrs(user_type, f"__{method_name}__", binary_magic_impl)
+        if method in reflectable_magic_methods:
+            setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
+
+
+for method in magic_methods:  # type: ignore[assignment]
+    if method in only_bool_magic_methods:
+        _make_user_magic(method, SymBool)
+        continue
+    if method in only_float_magic_methods:
+        _make_user_magic(method, SymFloat)
+        continue
+    if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
+        _make_user_magic(method, SymBool)
+    _make_user_magic(method, SymInt)
+    if method not in bitwise_ops:
+        _make_user_magic(method, SymFloat)
+
+del method
+del func
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ffc77c23b08e0c35860783658c2c84f3ce0397
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py
@@ -0,0 +1,8121 @@
+from __future__ import annotations
+
+import sympy
+from sympy import S
+
+from torch._prims_common import BoolLike, FloatLike, IntLike
+
+
+"""
+``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
+our symbolic shapes reasoning system that is used heavily in torch.compile.  Although
+this is not generally considered public API, when writing framework code in PyTorch
+as well as extensions to PyTorch (e.g., in custom operator implementations), you may
+need to make use of these APIs to setup dynamic shapes support appropriately.
+"""
+
+import abc
+import atexit
+import collections
+import dis
+import functools
+import hashlib
+import inspect
+import itertools
+import logging
+import math
+import operator
+import os
+import re
+import sys
+import threading
+import traceback
+from collections import Counter, defaultdict
+from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
+from contextlib import _GeneratorContextManager, contextmanager
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from typing import (
+    Any,
+    cast,
+    Generic,
+    NamedTuple,
+    NoReturn,
+    Optional,
+    TYPE_CHECKING,
+    TypeAlias,
+    TypeGuard,
+    TypeVar,
+    Union,
+)
+from typing_extensions import deprecated, ParamSpec
+
+import torch
+import torch.fx
+import torch.fx.traceback as fx_traceback
+import torch.utils._pytree as pytree
+
+# NB: The sym_* functions are used via getattr() and must be imported here.
+from torch import SymBool, SymFloat, SymInt
+from torch._C._functorch import get_unwrapped, is_batchedtensor
+from torch._guards import ShapeGuard, SLoc, Source, TracingContext
+from torch._logging import dtrace_structured, LazyString, structured, trace_structured
+from torch._subclasses.meta_utils import is_sparse_any
+from torch._utils_internal import signpost_event
+from torch.fx.experimental import _config as config
+from torch.fx.experimental.recording import (
+    FakeTensorMeta,
+    record_shapeenv_event,
+    replay_shape_env_events,
+    shape_env_check_state_equal,
+    ShapeEnvEvent,
+)
+from torch.fx.experimental.sym_node import SymNode, SymTypes
+from torch.types import py_sym_types
+from torch.utils._ordered_set import OrderedSet
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+from torch.utils._sympy.functions import (
+    Application,
+    CeilToInt,
+    CleanDiv,
+    FloorDiv,
+    FloorToInt,
+    IntTrueDiv,
+    IsNonOverlappingAndDenseIndicator,
+    Max,
+    Mod,
+    PythonMod,
+    TruncToInt,
+)
+from torch.utils._sympy.numbers import int_oo
+from torch.utils._sympy.printers import CppPrinter, PythonPrinter
+from torch.utils._sympy.singleton_int import SingletonInt
+from torch.utils._sympy.solve import try_solve
+from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT
+from torch.utils._sympy.value_ranges import (
+    bound_sympy,
+    SymPyValueRangeAnalysis,
+    ValueRangeError,
+    ValueRanges,
+)
+from torch.utils._traceback import CapturedTraceback, format_frame
+
+
+if TYPE_CHECKING:
+    import types
+
+    from torch import Tensor
+    from torch._dynamo.source import TensorPropertySource
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch.types import BoolLikeType, FloatLikeType, IntLikeType
+
+
+InputList = list
+DimList = list
+
+log = logging.getLogger(__name__)
+
+
+class GuardOnDataDependentSymNode(RuntimeError):
+    cond: sympy.Basic
+
+    def __init__(self, cond: sympy.Basic, *args: Any) -> None:
+        super().__init__(*args)
+        self.cond = cond
+
+
+class PendingUnbackedSymbolNotFound(RuntimeError):
+    pass
+
+
+aten = torch._ops.ops.aten  # type: ignore[has-type]
+
+__all__ = [
+    "size_hint",
+    "guard_or_false",
+    "guard_or_true",
+    "has_symbolic_sizes_strides",
+    "create_contiguous",
+    "ShapeEnv",
+    "is_concrete_int",
+    "is_concrete_float",
+    "is_concrete_bool",
+    "has_static_value",
+    "guard_int",
+    "guard_float",
+    "guard_scalar",
+    "canonicalize_bool_expr",
+    "hint_int",
+    "SYMPY_INTERP",
+    "free_symbols",
+    "is_symbol_binding_fx_node",
+    "is_nested_int",
+    "SHAPEENV_EVENT_KEY",
+    "CURRENT_NODE_KEY",
+    "has_free_symbols",
+    "has_free_unbacked_symbols",
+    "sym_and",
+    "sym_eq",
+    "sym_or",
+    "SymbolicContext",
+    "StatelessSymbolicContext",
+    "StatefulSymbolicContext",
+    "SubclassSymbolicContext",
+    "SymIntSymbolicContext",
+    "TrackedFake",
+    "statically_known_true",
+    "statically_known_false",
+    "guard_size_oblivious",
+    "check_consistent",
+    "compute_unbacked_bindings",
+    "ConvertIntKey",
+    "rebind_unbacked",
+    "resolve_unbacked_bindings",
+    "is_accessor_node",
+    "ValueRangesSLoc",
+    "SymIntEqByExpr",
+    "Specialization",
+]
+
+# FX node metadata keys for symbolic shape FX graph.
+SHAPEENV_EVENT_KEY = "shapeenv_event"
+CURRENT_NODE_KEY = "current_node"
+
+
+def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None:
+    log.debug(
+        "lru_cache_stats %s: %s",
+        wrapped_f.__name__,  # type: ignore[attr-defined]
+        wrapped_f.cumulative_cache_info(),  # type: ignore[attr-defined]
+    )
+
+
+# Note about Sympy Expr/SympyBoolean/Basic typing: the Sympy hierarchy is
+#
+#   Basic
+#       Expr
+#       SympyBoolean
+#           Relational
+#
+# Notably, Expr and SympyBoolean are not related.  So use Basic when the
+# expression could denote int, float OR bool, and otherwise use the more
+# specific Expr for int/float and SympyBoolean for bool.
+#
+# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.
+# So make sure only type checker evaluates this alias.
+# Xref: https://www.internalfb.com/diff/D53324783
+SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean"
+
+
+_T = TypeVar("_T")
+_SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic)
+
+
+class SymIntEqByExpr:
+    """
+    This is a wrapper around SymInt which has alternative semantics for
+    equality and pickling.  Specifically, instead of erroring or guarding, we
+    instead will hash/compare equality based on the underlying sympy
+    expression; e.g., s0 and s1 will always compare as False.
+
+    NB: This does NOT do fancy analysis that maybe_evaluate_static does;
+    we can only reason through equalities that occur because to expressions
+    canonicalize to the same expression via regular simplification.
+    """
+
+    @staticmethod
+    def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr:
+        if isinstance(val, torch.SymInt):
+            return val.node.expr
+        else:
+            return sympy.Integer(val)
+
+    def __init__(self, val: Union[torch.SymInt, int]) -> None:
+        self.val: sympy.Expr = SymIntEqByExpr._extract(val)
+
+    def __repr__(self) -> str:
+        return repr(self.val)
+
+    def __eq__(self, other: object) -> bool:
+        assert isinstance(other, SymIntEqByExpr)
+        return self.val == other.val
+
+    def __hash__(self) -> int:
+        return hash(self.val)
+
+
+def _nested_int_aware_sort(
+    tup: tuple[IntLikeType, int],
+) -> tuple[int, IntLikeType, int]:
+    return (
+        # Order nested ints by their coefficients.
+        # 1 here to order nested ints after non-nested-ints.
+        (1, tup[0].node.nested_int_coeff(), tup[1])
+        if is_nested_int(tup[0])
+        else (0, *tup)
+    )
+
+
+def size_hint(x: int | torch.SymInt, *, allow_none: bool = False) -> int | None:
+    """Gets a size hint for a given expression from the underlying shapes we had.
+    Does not introduce a guard, so only use this when you can guarantee that
+    your code is still valid for arbitrary shapes (such as optimization decisions)
+    """
+    if isinstance(x, int):
+        return x
+    assert isinstance(x, torch.SymInt)
+    return x.node.shape_env.size_hint(x.node.expr, allow_none=allow_none)
+
+
+# Wrapper on lru_cache that reports statistics at process end
+def lru_cache(
+    maxsize: Optional[int],
+) -> Callable[[Callable[..., _T]], functools._lru_cache_wrapper[_T]]:
+    def inner(f: Callable[..., _T]) -> functools._lru_cache_wrapper[_T]:
+        wrapped_f = functools.lru_cache(maxsize)(f)
+        old_cache_clear = wrapped_f.cache_clear
+        prev_hits = 0
+        prev_misses = 0
+
+        # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info
+        # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not
+        # weakref'able on some versions of Python
+
+        def cumulative_cache_info() -> functools._CacheInfo:
+            cur = wrapped_f.cache_info()
+            return functools._CacheInfo(
+                prev_hits + cur.hits,
+                prev_misses + cur.misses,
+                cur.maxsize,
+                cur.currsize,
+            )
+
+        def new_cache_clear() -> None:
+            nonlocal prev_hits, prev_misses
+            cur = wrapped_f.cache_info()
+            prev_hits += cur.hits
+            prev_misses += cur.misses
+            old_cache_clear()
+
+        wrapped_f.cache_clear = new_cache_clear  # type: ignore[attr-defined, method-assign]
+        wrapped_f.cumulative_cache_info = cumulative_cache_info  # type: ignore[attr-defined, method-assign]
+        if log.isEnabledFor(logging.DEBUG):
+            atexit.register(log_lru_cache_stats, wrapped_f)  # type: ignore[arg-type]
+        return wrapped_f
+
+    return inner
+
+
+# These are modules that contain generic code for interacting with ShapeEnv
+# which are unlikely to identify a particular interesting guard statement
+@lru_cache(None)
+def uninteresting_files() -> set[str]:
+    import torch._compile
+    import torch._dynamo.eval_frame
+    import torch._inductor.sizevars
+    import torch._library.custom_ops
+    import torch._library.fake_impl
+    import torch._logging
+    import torch._subclasses.fake_tensor
+    import torch._subclasses.meta_utils
+    import torch.export._trace
+
+    mods = [
+        sys.modules[__name__],
+        torch.export._trace,
+        torch.fx.experimental.recording,
+        torch.fx.experimental.sym_node,
+        torch.fx.interpreter,
+        torch.fx._symbolic_trace,
+        torch,
+        torch._compile,
+        torch._dynamo.eval_frame,
+        torch._inductor.sizevars,
+        torch._library.custom_ops,
+        torch._library.fake_impl,
+        torch._subclasses.meta_utils,
+        torch._subclasses.fake_tensor,
+        torch._logging._internal,
+        torch._logging.structured,
+    ]
+    import torch._dynamo.guards
+
+    return (
+        {inspect.getfile(m) for m in mods}
+        | torch._dynamo.guards.uninteresting_files()
+        | {""}
+    )
+
+
+class ConstraintViolationError(RuntimeError):
+    pass
+
+
+def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool:
+    return elem._has_symbolic_sizes_strides
+
+
+Int: TypeAlias = Union[torch.SymInt, int]
+
+
+def create_contiguous(shape: Sequence[Int]) -> list[Int]:
+    strides: list[Int] = [1]
+    for dim in reversed(shape[:-1]):
+        strides.append(dim * strides[-1])  # type: ignore[operator]
+    return list(reversed(strides))
+
+
+def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
+    """
+    Retrieve the hint for an int (based on the underlying real values as observed
+    at runtime).  If no hint is available (e.g., because data dependent shapes),
+    if fallback is not None, use that instead (otherwise raise an error).
+    """
+    if isinstance(a, torch.SymInt):
+        return a.node.require_hint(fallback)
+    assert type(a) is int, a
+    return a
+
+
+Scalar: TypeAlias = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
+
+
+def has_hint(a: Scalar) -> bool:
+    if isinstance(a, SymTypes):
+        return a.node.has_hint()
+    return True
+
+
+def is_concrete_int(a: IntLikeType) -> bool:
+    """
+    Utility to check if underlying object
+    in SymInt is concrete value. Also returns
+    true if integer is passed in.
+
+    Args:
+        a (SymInt or int): Object to test if it int
+    """
+    assert isinstance(a, (SymInt, int))
+
+    if isinstance(a, int):
+        return True
+
+    if isinstance(a.node.expr, sympy.core.numbers.Integer):
+        return True
+
+    return False
+
+
+def is_concrete_float(a: FloatLikeType) -> bool:
+    r"""Utility to check if underlying object
+    in SymInt is concrete value. Also returns
+    true if integer is passed in.
+
+    Args:
+        a (SymInt or float): Object to test if it float
+    """
+    assert isinstance(a, (SymFloat, float))
+
+    if isinstance(a, float):
+        return True
+
+    if isinstance(a.node.expr, sympy.core.numbers.Float):
+        return True
+
+    return False
+
+
+def is_concrete_bool(a: BoolLikeType) -> bool:
+    """
+    Utility to check if underlying object
+    in SymBool is concrete value. Also returns
+    true if integer is passed in.
+
+    Args:
+        a (SymBool or bool): Object to test if it bool
+    """
+    assert isinstance(a, (SymBool, bool))
+
+    if isinstance(a, bool):
+        return True
+
+    if isinstance(
+        a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)
+    ):
+        return True
+
+    return False
+
+
+def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> bool:
+    """
+    User-code friendly utility to check if a value is static or dynamic.
+    Returns true if given a constant, or a symbolic expression with a fixed value.
+
+    Args:
+        a (Union[SymBool, SymFloat, SymInt, bool, float, int]): Object to test
+    """
+    assert isinstance(a, BoolLike + FloatLike + IntLike)
+    if (
+        isinstance(a, BoolLike)
+        and is_concrete_bool(a)  # type: ignore[arg-type]
+        or isinstance(a, FloatLike)
+        and is_concrete_float(a)  # type: ignore[arg-type]
+        or isinstance(a, IntLike)
+        and is_concrete_int(a)  # type: ignore[arg-type]
+    ):
+        return True
+
+    assert isinstance(a, py_sym_types)
+    return a.node.shape_env.bound_sympy(a.node.expr).is_singleton()  # type: ignore[union-attr]
+
+
+def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
+    """
+    Perform a guard on a symbolic boolean expression in a size oblivious way.
+    This is typically used when a non-oblivious test would result in a guard
+    on a data dependent value of which we don't know the value of at compile time.
+    When a guard is tested this way, we may diverge in behavior from how regular
+    PyTorch semantics would treat it.  For more information, see
+    https://github.com/pytorch/pytorch/pull/118579
+    """
+    if isinstance(expr, torch.SymBool):
+        return expr.node.guard_size_oblivious("", 0)
+    else:
+        assert isinstance(expr, bool), expr
+        return expr
+
+
+def check_consistent(new: _T, old: _T) -> None:
+    """
+    Test that two "meta" values (typically either Tensor or SymInt) have
+    the same values, e.g., after retracing.  If we don't understand the
+    quantities in question, we'll just skip the consistency check.
+    """
+    # TODO: do boolean equality test too, see
+    # https://github.com/pytorch/pytorch/issues/124110
+    scalar_types = (torch.SymInt, torch.SymFloat, int, float)
+
+    if isinstance(new, torch.Tensor):
+        assert isinstance(old, torch.Tensor)
+        torch._check(
+            old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)"
+        )
+        # Do this manually so that each individual test is irrefutable
+        # (TODO: should be a helper for this, maybe sym_eq?  That
+        # gives us a compound expression and I'm not sure it
+        # simplifies right now)
+        for i, j in zip(old.shape, new.shape):
+            torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
+    # NB: bool is subclass of int
+    elif isinstance(new, scalar_types) and not isinstance(new, bool):
+        assert isinstance(old, scalar_types) and not isinstance(old, bool), (
+            f"{old} != {new}"
+        )
+        torch._check(old == new, lambda: f"{old} != {new} (old != new)")
+
+
+def resolve_unbacked_bindings(
+    shape_env: Optional[ShapeEnv],
+    bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
+) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
+    """
+    When we do fake tensor prop, we oftentimes will allocate new unbacked symints.
+    We then run proxy tensor mode, which populates node.meta["unbacked_bindings"]
+    with these new symints. To ensure consistency we use PropagateUnbackedSymInts
+    to rename unbacked bindings to their old ones. But all of the node metas are
+    still using the old bindings from before the renaming. This function helps to
+    post facto apply any renamings discovered in the PropogateUnbackedSymInts pass.
+    """
+    if bindings is None:
+        return None
+    assert shape_env is not None
+    return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()}
+
+
+Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]]
+
+
+def rebind_unbacked(
+    shape_env: Optional[ShapeEnv], n: torch.fx.Node, result: Result
+) -> None:
+    """
+    Suppose we are retracing a pre-existing FX graph that previously had
+    fake tensor propagation (and therefore unbacked SymInts).  When we retrace,
+    we re-propagate fake tensors, which results in new unbacked SymInts.
+    When this happens, we need to tell the shape environment about the equivalence
+    of the old and new unbacked SymInts.  Pass us the old torch.fx.Node (which
+    has the old binding information) and the new result (which we can extract the
+    new unbacked SymInts out from).
+    """
+
+    # Inputs never need rebinding
+    if n.op == "placeholder":
+        return
+
+    if bindings := resolve_unbacked_bindings(
+        shape_env, n.meta.get("unbacked_bindings")
+    ):
+        assert shape_env is not None
+        for raw_u0, path in bindings.items():
+            u1 = pytree.key_get(result, path)
+
+            # Sometimes, things were previously unbacked bindings become constants.
+            # There are two situations this can happen.
+            #
+            # First, you might have a runtime assert that causes the
+            # constant-ification.  In this case, the /binding/ itself will
+            # still be an unbacked symbol (because we will only force it
+            # to be a constant later in fake tensor propagation).  In this
+            # case, u1 is a SymInt and we still do all our work as normal.
+            #
+            # But second, it might be that fake tensor propagation DIRECTLY
+            # converted the unbacked SymInt into a constant.  This happens
+            # more rarely, but we have identified two situations it can
+            # validly occur:
+            #
+            # - If you have a tensor_version operator, these are initially
+            #   allocated as unbacked SymInts, but after AOTAutograd they
+            #   get forced specialized to specific values.  In this case,
+            #   there is no reason to do runtime asserts on them, this is
+            #   just a hack to properly keep track of them to start.
+            #
+            # - If you have an item() call on a constant tensor, the result
+            #   of the item() call is constant and we do not need runtime
+            #   asserts on this symbol.  In
+            #   https://github.com/pytorch/pytorch/issues/140625 we have a
+            #   case where in the initial trace of the program we are unable
+            #   to determine that torch.tensor is constant, but then
+            #   subsequent passes cause torch.tensor to become a constant and
+            #   then the unbacked symbol goes poof.
+            #
+            # In all of these cases, it is no longer necessary to generate
+            # deferred runtime asserts, since other subsystems (e.g., the
+            # constant-ification pass) ensure that the quantity is now truly
+            # static and cannot change at runtime.  So it's OK to discard
+            # in these situations.
+            #
+            # There is one more hazard (re
+            # https://github.com/pytorch/pytorch/issues/141248), the problem
+            # is that you can end up with "dangling" unbacked symbols that
+            # exist in the ShapeEnv but are never bound anywhere.  You might
+            # like an invariant that unbacked symbols never get lost.  But
+            # we do not have this invariant, so do not try to enforce it.
+            if isinstance(u1, (int, float)):
+                log.info(
+                    "rebind_unbacked: discard %s %s %s -> %s",
+                    n.target,
+                    raw_u0,
+                    path,
+                    u1,
+                )
+                continue
+
+            # We only care about rebinding unbacked things
+            if u1.node.hint is not None:
+                continue
+
+            # unbacked symbols bindings might be replaced to other backed or
+            # unbacked replacements.
+            #
+            # Example:
+            #   u = x.item()
+            #   torch._check(u == 5)
+            #
+            # The safest approach is to retrieve raw_u1 from u1.node._expr
+            # and perform the rebinding on the original unbacked symbol,
+            # even if it’s no longer directly referenced.
+            #
+            # In other words, we should always rebind the original symbol
+            # before any replacements are applied.
+            #   u0 -> u0 == s1
+            raw_u1 = u1.node._expr
+
+            # TODO Do we still need this logic below?
+            # Simplify SymBool binding
+            if (
+                isinstance(raw_u1, sympy.Piecewise)
+                and len(raw_u1.args) == 2
+                and (
+                    raw_u1_args0 := cast(
+                        tuple[sympy.Basic, sympy.Basic], raw_u1.args[0]
+                    )
+                )
+                and raw_u1_args0[0] == 1
+                and isinstance(eq := raw_u1_args0[1], sympy.Eq)
+                and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol)
+                and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1))
+                and eq.rhs == 1
+                and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True)
+            ):
+                # This is what the pattern match above is testing
+                repacked = _sympy_cast_symbool_to_symint_guardless(
+                    sympy.Eq(new_raw_u1, 1)
+                )
+                assert repacked == raw_u1, f"{repacked} != {raw_u1}"
+                # Cancel the to_int(to_bool(x)). This is sound because x in
+                # [0, 1]
+
+                raw_u1 = new_raw_u1
+
+            if not isinstance(raw_u1, sympy.Symbol):
+                assert not raw_u1.free_symbols, (
+                    f"should have been constant, but got {raw_u1}"
+                )
+                continue
+
+            # The old and new could be the same if you improperly hit the memo
+            # while retracing.  Make sure you updated FakeTensorMode.epoch
+            assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster"
+            # Reuse the OLD symbol name
+            shape_env._rename_unbacked_to(raw_u1, raw_u0)
+
+
+# NB: You could try to expand this to cover more cases by simply
+# detecting whenever you have an int output, but this is a bit
+# dangerous in case someone adds a function that returns an int but is
+# mutating.  So manually whitelist for now.
+def is_accessor_node(node: torch.fx.Node) -> bool:
+    """
+    Helper function to determine if a node is trying to access
+    a symbolic integer such as size, stride, offset or item. Currently
+    primarily only used in a DCE pass to figure out purity.
+    """
+
+    # Dynamo only exercised condition
+    if (
+        node.op == "call_method"
+        and isinstance(node.args[0], torch.fx.Node)
+        and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
+        and node.target in ["size", "stride", "storage_offset", "item"]
+    ):
+        return True
+
+    if node.op == "call_function" and node.target in [
+        torch.ops.aten.sym_size,
+        torch.ops.aten.sym_size.default,
+        torch.ops.aten.sym_size.int,
+        torch.ops.aten.sym_stride,
+        torch.ops.aten.sym_stride.default,
+        torch.ops.aten.sym_stride.int,
+        torch.ops.aten.sym_storage_offset,
+        torch.ops.aten.sym_storage_offset.default,
+        torch.ops.aten.sym_numel.default,
+    ]:
+        return True
+
+    return False
+
+
+def canonicalize_bool_expr(expr: _T) -> _T:
+    """
+    Canonicalize a boolean expression by transforming it into a lt / le
+    inequality and moving all the non-constant terms to the rhs.
+    We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr
+    recursively
+    nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924
+
+    Args:
+        expr (sympy.Expr): Expression to canonicalize
+    """
+    # Canonicalise an inequality by transforming it into a lt / le
+    # inequality and moving all the non-constant terms to the rhs
+    # We canonicalise And / Ors / Not via cnf
+    # nb. Relational.canonical in sympy is broken
+    # https://github.com/sympy/sympy/issues/25924
+
+    if not isinstance(
+        expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)
+    ):
+        return expr
+
+    if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)):
+        expr = sympy.logic.boolalg.to_cnf(expr)
+    return _canonicalize_bool_expr_impl(expr)  # type: ignore[arg-type, return-value]
+
+
+def _sympy_from_args(
+    cls: type[Union[sympy.Add, sympy.Mul]],
+    args: list[sympy.Expr],
+    sort: bool = True,
+    is_commutative: Optional[bool] = None,
+) -> sympy.Expr:
+    """
+    Create a sympy expression from a list of arguments, optimizing for performance.
+
+    This function creates a sympy Add or Mul expression from a list of arguments
+    while avoiding expensive operations like flattening. It handles sorting the
+    arguments appropriately based on the expression type.
+
+    Args:
+        cls: The sympy class to create (Add or Mul)
+        args: List of sympy expressions to combine
+        sort: Whether to sort the arguments (default: True)
+        is_commutative: Whether the operation is commutative (default: None)
+
+    Returns:
+        A sympy expression of type cls combining all arguments
+
+    Raises:
+        ValueError: If cls is not sympy.Add or sympy.Mul
+    """
+
+    if not args:
+        return cls.identity  # type: ignore[union-attr]
+
+    # These args are already in canonical form, so we avoid calling
+    # Add(*args) to avoid expensive Add.flatten operation
+    if sort:
+        if cls is sympy.Add:
+            sort_fn = sympy.core.add._addsort
+        elif cls is sympy.Mul:
+            sort_fn = sympy.core.mul._mulsort
+        else:
+            raise ValueError(f"Unknown cls: {cls}")
+
+        # we don't support non commutative with sort
+        assert is_commutative is True
+        if args[0].is_Number:
+            rest = args[1:]
+            sort_fn(rest)
+            return cls._from_args([args[0]] + rest, is_commutative=is_commutative)  # type: ignore[attr-defined]
+        else:
+            args = args.copy()
+            sort_fn(args)
+            return cls._from_args(args, is_commutative=is_commutative)  # type: ignore[attr-defined]
+    else:
+        # if the args are already sorted, we create directly
+        return cls._from_args(args, is_commutative=is_commutative)  # type: ignore[attr-defined]
+
+
+def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
+    """
+    After canonicalization, we are guaranteed to have eliminated Ge/Gt relations
+    (rewriting them to Le/Lt, respectively).
+    """
+    if isinstance(expr, (sympy.And, sympy.Or)):
+        return type(expr)(*map(canonicalize_bool_expr, expr.args))
+
+    opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
+    t: Union[type[Any]]
+    if isinstance(expr, tuple(opposite.keys())):
+        rhs = expr.lhs - expr.rhs  # type: ignore[attr-defined]
+        t = opposite[type(expr)]  # type: ignore[index]
+    else:
+        assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
+        rhs = expr.rhs - expr.lhs
+        t = type(expr)
+
+    def is_neg(t: sympy.Expr) -> bool:
+        return (t.is_Number and t.is_negative) or (
+            isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative
+        )
+
+    lhs = S.Zero
+    rhs = _reduce_to_lowest_terms(rhs)
+    if isinstance(rhs, sympy.Add):
+        pos = []
+        neg = []
+        for term in rhs.args:
+            if is_neg(term):
+                neg.append(-term)
+            else:
+                pos.append(term)
+        # these are already sorted
+        rhs = _sympy_from_args(sympy.Add, pos, sort=False, is_commutative=True)
+        # the terms were changed, so needs a sorting
+        lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True)
+    elif is_neg(rhs):
+        # lhs == 0
+        lhs, rhs = -rhs, S.Zero
+    # We don't have to evaluate here because lhs, rhs came from a Boolean
+    # and it was already simplified
+    return t(lhs, rhs, evaluate=False)
+
+
+def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
+    """
+    Eliminates any integer factor from a given expression.
+    E.g., 6x + 4y reduces to 3x + 2y.
+
+    Useful when an expression is == or != to 0.
+    """
+
+    def integer_coefficient(x: sympy.Expr) -> int:
+        if x.is_Integer:
+            return abs(int(x))
+        elif x.is_Mul:
+            # If one of the args of a Mul is an Integer, it is the
+            # first arg. eg: args(2*x*3*y) == (6, x, y)
+            return abs(int(x.args[0])) if x.args[0].is_Integer else 1  # type: ignore[call-overload]
+        else:
+            return 1
+
+    def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr:
+        if x.is_Integer:
+            return x / factor
+        elif x.is_Mul:
+            if x.args[0] != factor:
+                args = [x.args[0] / sympy.Integer(factor), *x.args[1:]]
+            else:
+                # Mul._from_args require a canonical list of args
+                # so we remove the first arg (x.args[0] / factor) if it was 1
+                args = list(x.args[1:])
+            return _sympy_from_args(sympy.Mul, args, is_commutative=x.is_commutative)
+        else:
+            raise AssertionError(f"illegal arg to div_by_factor: {x}")
+
+    if expr.is_Add:
+        atoms = cast(Sequence[sympy.Expr], expr.args)
+        factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
+        if factor == 1:
+            return expr
+        # pyrefly: ignore [bad-argument-type]
+        atoms = [div_by_factor(x, factor) for x in atoms]
+        return _sympy_from_args(
+            sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative
+        )
+    elif expr.is_Integer:
+        return S.One
+    elif expr.is_Mul:
+        return div_by_factor(expr, integer_coefficient(expr))
+    return expr
+
+
+def is_nested_int(s: IntLikeType) -> TypeGuard[SymInt]:
+    return isinstance(s, torch.SymInt) and s.node.is_nested_int()
+
+
+IterateExprsAtom: TypeAlias = Union[
+    SymInt, SymFloat, SymBool, int, float, bool, sympy.Basic, torch.Tensor
+]
+IterateExprs: TypeAlias = Union[IterateExprsAtom, Sequence[IterateExprsAtom]]
+
+
+def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
+    """
+    Recursively iterate through a value and yield all sympy expressions contained within it.
+
+    This function traverses various data structures (tensors, lists, tuples, etc.) and extracts
+    any symbolic expressions they contain. It's used for operations like finding free symbols
+    in complex nested structures.
+
+    Args:
+        val: The value to extract sympy expressions from. Can be a symbolic type (SymInt, SymFloat, SymBool),
+             a sympy expression, a primitive type (int, float, bool), a container (tuple, list),
+             a sparse tensor, a regular tensor, None, or a torch.Generator.
+
+    Yields:
+        sympy.Basic: Each sympy expression found in the value.
+
+    Raises:
+        AssertionError: If the value is of an unsupported type.
+    """
+    # This is almost close enough to implement in terms of _iterate_nodes()
+    # except that it needs to handle `list[sympy.Basic]` which _iterate_nodes()
+    # can't handle.
+    if isinstance(val, SymTypes):
+        # This allow applies to the jagged layout NestedTensor case as
+        # nested ints are not symbolic
+        if is_symbolic(val):
+            yield val.node.expr
+    elif isinstance(val, SymNode):
+        yield val.expr
+    elif isinstance(val, sympy.Basic):
+        yield val
+    elif isinstance(val, (int, float, bool)):
+        pass
+    elif isinstance(val, (tuple, list)):
+        for s in val:
+            yield from _iterate_exprs(s)
+    elif is_sparse_any(val):
+        yield from _iterate_exprs(val.size())
+    elif isinstance(val, torch.Tensor):
+        yield from _iterate_exprs(val.size())
+        yield from _iterate_exprs(val.stride())
+        yield from _iterate_exprs(val.storage_offset())
+    elif val is None:
+        pass
+    # see Note: [Generator arguments in AOTDispatcher]
+    elif isinstance(val, torch.Generator):
+        pass
+    else:
+        raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
+
+
+def _iterate_nodes(val: Any) -> Iterator[SymNode]:
+    """
+    Recursively iterate through a value and yield all SymNodes contained
+    within it.
+    """
+    if isinstance(val, SymNode):
+        yield val
+    elif isinstance(val, py_sym_types):
+        # This allow applies to the jagged layout NestedTensor case as
+        # nested ints are not symbolic
+        if is_symbolic(val):
+            yield val.node
+    elif isinstance(val, (tuple, list, torch.Size)):
+        for s in val:
+            yield from _iterate_nodes(s)
+    elif isinstance(val, torch.Tensor):
+        yield from _iterate_nodes(val.size())
+        if not is_sparse_any(val):
+            yield from _iterate_nodes(val.stride())
+            yield from _iterate_nodes(val.storage_offset())
+
+
+def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]:
+    """
+    Recursively collect all free symbols from a value.
+
+    This function traverses various data structures (tensors, lists, tuples, etc.) and extracts
+    all sympy symbols contained within them. It's useful for finding all symbolic variables
+    that a complex nested structure depends on.
+
+    Args:
+        val: The value to extract symbols from. Can be a symbolic type (SymInt, SymFloat, SymBool),
+             a container (tuple, list), a tensor, or None.
+
+    Returns:
+        OrderedSet[sympy.Symbol]: An ordered set of all free symbols found in the value.
+    """
+    if val is None:
+        return OrderedSet()
+
+    itr = _iterate_exprs(val)
+
+    # we need at least 1 to call union, so we hand code the identity
+    try:
+        first_expr = next(itr)
+    except StopIteration:
+        return OrderedSet()
+
+    # TODO: Apparently, returning an OrderedSet here breaks
+    # python test/distributed/tensor/test_dtensor_compile.py TestDTensorCompile.test_dtensor_dynamic
+    return first_expr.free_symbols.union(*(e.free_symbols for e in itr))  # type: ignore[return-value]
+
+
+def has_free_symbols(val: IterateExprs) -> bool:
+    """Faster version of bool(free_symbols(val))"""
+    return not all((e.is_number or e.is_Boolean) for e in _iterate_exprs(val))
+
+
+def has_free_unbacked_symbols(x: IterateExprs) -> bool:
+    """Faster version of bool(free_unbacked_symbols(val))"""
+    from sympy.core.traversal import iterargs
+
+    for s in _iterate_exprs(x):
+        for arg in iterargs(s):
+            if arg.is_Symbol and symbol_is_type(
+                arg, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)
+            ):
+                return True
+    return False
+
+
+def free_unbacked_symbols(x: IterateExprs) -> OrderedSet[sympy.Symbol]:
+    """Like free_symbols, but filtered to only report unbacked symbols"""
+
+    # NB: keep synced with is_unbacked_symint
+    return OrderedSet(
+        s
+        for s in free_symbols(x)
+        if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))
+    )
+
+
+def _free_non_source_unbacked_symbols(
+    x: IterateExprs, unbacked_inputs: OrderedSet[sympy.Symbol]
+) -> OrderedSet[sympy.Symbol]:
+    """Unbacked symbols that are not inputs to the graph. These are symbols that originated from
+    data-dependent operations as opposed to mark_unbacked calls."""
+    unbacked_symbols = free_unbacked_symbols(x)
+    non_source_symbols = unbacked_symbols - unbacked_inputs
+    return non_source_symbols
+
+
+# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
+# setup!
+def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]:
+    """
+    Check if a given FX node is a symbol binding node.
+
+    A symbol binding node is one that has a SymInt value in its meta that contains
+    a sympy Symbol expression, and is either a placeholder node or contains unbacked symbols.
+
+    Args:
+        node (torch.fx.Node): The FX node to check
+
+    Returns:
+        Optional[sympy.Symbol]: The sympy Symbol if the node is a symbol binding node, None otherwise
+    """
+    if (
+        "val" in node.meta
+        and isinstance(node.meta["val"], torch.SymInt)
+        and isinstance(node.meta["val"].node.expr, sympy.Symbol)
+        and (
+            node.op == "placeholder"
+            or free_unbacked_symbols(node.meta["val"].node.expr)
+        )
+    ):
+        return node.meta["val"].node.expr
+    return None
+
+
+def find_symbol_binding_fx_nodes(
+    graph: torch.fx.Graph,
+) -> dict[sympy.Symbol, torch.fx.Node]:
+    """
+    Find all nodes in an FX graph that bind sympy Symbols.
+
+    This function scans through all nodes in the given FX graph and identifies
+    nodes that bind sympy Symbols (typically placeholder nodes with SymInt values).
+    When multiple nodes bind the same symbol, only the first occurrence is kept.
+
+    Args:
+        graph: The FX graph to search for symbol binding nodes
+
+    Returns:
+        A dictionary mapping from sympy Symbols to their binding FX nodes
+    """
+    r = {}
+    # NB: Prefer first occurrence of symbol
+    for node in graph.nodes:
+        if (s := is_symbol_binding_fx_node(node)) is not None and s not in r:
+            r[s] = node
+    return r
+
+
+@dataclass(frozen=True)
+class Specialization:
+    """
+    This class is used in multi-graph compilation contexts where we generate
+    multiple specialized graphs and dispatch to the appropriate one at runtime.
+    This allows us to optimize the trade-off between performance and generality
+    by creating specialized versions for common patterns (e.g., x.shape[0] % 16 == 0)
+    while maintaining a general fallback.
+    """
+
+    source: TensorPropertySource
+    check_fn: Callable
+
+
+# Analogous to ConvertIntSource
+@dataclass(frozen=True)
+class ConvertIntKey:
+    def __str__(self) -> str:
+        return ".cast_symbool_to_symint_guardless()"
+
+    def get(self, b: bool) -> IntLikeType:
+        """Get the int value from bool"""
+        return cast_symbool_to_symint_guardless(b)
+
+
+@dataclass(frozen=True)
+class CallMethodKey:
+    name: str
+
+    def __str__(self) -> str:
+        return f".{self.name}()"
+
+    def get(self, o: Any) -> Any:
+        """Call the method on object"""
+        return getattr(o, self.name)()
+
+
+@dataclass(frozen=True)
+class InnerTensorKey:
+    inner_name: str
+
+    def __str__(self) -> str:
+        return f".{self.inner_name}"
+
+    def get(self, o: Any) -> Any:
+        """Get the inner tensor attribute"""
+        return getattr(o, self.inner_name)
+
+
+@dataclass(frozen=True)
+class DivideByKey:
+    divisor: IntLikeType
+
+    def __str__(self) -> str:
+        return f".__floordiv__({self.divisor})"
+
+    def get(self, o: int) -> int:
+        """Divide object by divisor"""
+        return o // self.divisor
+
+
+def _free_unbacked_symbols_with_path(
+    a: object,
+    path: pytree.KeyPath,
+    real: Optional[object] = None,
+    shape_env: Optional[ShapeEnv] = None,
+    pending: Optional[set[sympy.Symbol]] = None,
+    simplify: bool = False,
+) -> dict[sympy.Symbol, pytree.KeyPath]:
+    """
+    Recursively traverses a structure to find unbacked symbols and their access paths.
+
+    This function walks through tensors, lists, tuples, and symbolic values to locate
+    unbacked symbols that are in the pending set, and returns a mapping from those
+    symbols to their access paths in the structure.
+
+    Args:
+        a: The object to traverse (tensor, list, tuple, SymInt, etc.)
+        path: The current path in the object tree
+        real: Optional real tensor corresponding to the fake tensor being traversed
+        shape_env: Optional ShapeEnv to register unbacked values with
+        pending: Set of unbacked symbols to look for (will be modified in-place)
+        simplify: Whether to use simplified expressions
+
+    Returns:
+        A dictionary mapping unbacked symbols to their access paths
+    """
+    go = functools.partial(
+        _free_unbacked_symbols_with_path,
+        shape_env=shape_env,
+        pending=pending,
+        simplify=simplify,
+    )
+
+    def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr:
+        if simplify:
+            return s.node.expr
+        # (When called from compute_unbacked_bindings)
+        # NB: Intentionally access _expr, not expr, do not want
+        # simplification!
+        return s.node._expr
+
+    if pending is None:
+        pending = set()
+    r = {}
+
+    def match_tensor(a: torch.Tensor, real_tensor: Optional[torch.Tensor] = None):
+        r.update(
+            go(
+                a.size(),
+                path + (CallMethodKey("size"),),
+                real=real_tensor.size() if real_tensor is not None else None,
+            )
+        )
+        if a.layout not in [
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        ]:
+            r.update(
+                go(
+                    a.stride(),
+                    path + (CallMethodKey("stride"),),
+                    real=real_tensor.stride() if real_tensor is not None else None,
+                )
+            )
+        r.update(
+            go(
+                a.storage_offset(),
+                path + (CallMethodKey("storage_offset"),),
+                real=(
+                    real_tensor.storage_offset() if real_tensor is not None else None
+                ),
+            )
+        )
+
+    if isinstance(a, (tuple, list)):
+        # NB: real is apparently not always a tuple/list here
+        # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu
+        for i in range(len(a)):
+            r.update(
+                go(
+                    a[i],
+                    path + (pytree.SequenceKey(i),),
+                    real=real[i] if real is not None else None,  # type: ignore[index]
+                )
+            )
+    elif is_traceable_wrapper_subclass(a):
+        # TODO: Determine if this is correct
+        attrs, _ = a.__tensor_flatten__()
+        for attr in attrs:
+            sub = getattr(a, attr)
+            r.update(go(sub, path + (InnerTensorKey(attr),)))
+
+        # match DTensor outer shapes
+        if torch.distributed.is_available() and isinstance(
+            a, torch.distributed.tensor.DTensor
+        ):
+            match_tensor(a)
+    elif isinstance(a, torch.Tensor) and is_batchedtensor(a):
+        unwrapped_tensor = get_unwrapped(a)
+        r.update(go(unwrapped_tensor, path))
+    elif isinstance(a, torch.Tensor) and not is_batchedtensor(a):
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        assert isinstance(a, FakeTensor)
+        match_tensor(a, a.real_tensor)
+    elif (
+        isinstance(a, (torch.SymInt, torch.SymFloat))
+        and isinstance(s := expr(a), sympy.Symbol)
+        and s in pending
+    ):
+        r[s] = path
+        if shape_env and real is not None:
+            assert isinstance(real, (int, float))
+
+            shape_env.set_unbacked_var_to_val(s, real)
+
+        pending.remove(s)
+    # When an unbacked SymInt is perfectly divisible by an integer
+    # constant, we replace it with the integer constant to improve
+    # reasoning capabilities.  However, in synthetic examples, it is
+    # then possible that the factor never is explicitly allocated.
+    # Fortunately, we can compute it by division.
+    elif (
+        isinstance(a, torch.SymInt)
+        and isinstance(s := expr(a), sympy.Mul)
+        and len(s.args) == 2
+        and isinstance(lhs := s.args[0], (sympy.Integer, sympy.Symbol))
+        and isinstance(rhs := s.args[1], sympy.Symbol)
+        # support exactly one unbacked for now
+        and ((rhs in pending) ^ (lhs in pending))
+        # support constant coefficient or backed symbolic coefficient
+        and (
+            isinstance(coeff := lhs if lhs not in pending else rhs, sympy.Integer)
+            or shape_env
+            and coeff in shape_env.var_to_val
+        )
+    ):
+
+        def _symint_wrap(s: sympy.Symbol) -> SymInt:
+            return shape_env.create_symintnode(  # type: ignore[union-attr]
+                s,
+                hint=int(shape_env.var_to_val[s]),  # type: ignore[union-attr]
+                source=shape_env.var_to_sources.get(s, [None])[0],  # type: ignore[union-attr]
+            )
+
+        unbacked = lhs if lhs in pending else rhs
+        divisor: IntLikeType = (
+            int(coeff)
+            if shape_env and isinstance(coeff, sympy.Integer)
+            else _symint_wrap(coeff)
+        )
+        # TODO: DivideByKey needs to test divisibility at runtime!
+
+        r[unbacked] = path + (DivideByKey(divisor),)
+        if real is not None:
+            assert isinstance(real, int)
+            val = (
+                real // int(coeff)
+                if isinstance(coeff, sympy.Integer)
+                else CleanDiv(real, coeff)
+            )
+            if shape_env:
+                shape_env.set_unbacked_var_to_val(unbacked, val)
+        pending.remove(unbacked)
+    # The annoyance here arises from the fact that SymBool is
+    # allocated by allocating a SymInt and then testing if it's equal
+    # to one.  So you have a complicated binding site logic for this.
+    elif (
+        isinstance(a, torch.SymBool)
+        and isinstance(s := expr(a), sympy.Eq)
+        # This must match create_unbacked_symbool EXACTLY
+        and isinstance(s.lhs, sympy.Symbol)
+        and s.rhs == 1
+        and s.lhs in pending
+    ):
+        r[s.lhs] = path + (ConvertIntKey(),)
+        if real is not None:
+            assert type(real) is bool
+            if shape_env:
+                shape_env.set_unbacked_var_to_val(s, int(real))
+
+        pending.remove(s.lhs)
+
+    return r
+
+
+def compute_unbacked_bindings(
+    shape_env: Optional[ShapeEnv],
+    example_value: object,
+    old_example_value: Optional[object] = None,
+    peek: bool = False,
+) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
+    """
+    After having run fake tensor propagation and producing example_value
+    result, traverse example_value looking for freshly bound unbacked
+    symbols and record their paths for later.  It is an error if
+    we have allocated an unbacked SymInt but it cannot be found in
+    example_value.  (NB: this means if you have a multi-output
+    function, you must call this on the tuple of tensor output, you
+    cannot wait!)
+
+    The peek parameter lets you check out what the bindings are without
+    changing the affected list.  This is primarily useful for ensuring
+    unbacked_var_to_val is promptly populated when propagate_real_tensors is on.
+    """
+    if shape_env is None:
+        return None
+
+    fs = shape_env.pending_fresh_unbacked_symbols
+
+    pending = set(fs)
+    if not pending:
+        return None
+
+    if not peek:
+        log.info("compute_unbacked_bindings %s", fs)
+        fs.clear()
+
+    symbol_to_path = _free_unbacked_symbols_with_path(
+        example_value, (), shape_env=shape_env, pending=pending, simplify=False
+    )
+    if not peek and pending:
+        extra = (
+            repr((example_value.stride(), example_value.storage_offset()))
+            if isinstance(example_value, torch.Tensor)
+            else ""
+        )
+        raise PendingUnbackedSymbolNotFound(
+            f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n"
+            "Did you accidentally call new_dynamic_size() or item() more times "
+            "than you needed to in your fake implementation?\n"
+            "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit"
+        )
+
+    # Why do we have to do some rebinding here?  If the original FX node
+    # wasn't a binding site because you had a memo hit, but post
+    # translation you aren't a memo hit anymore, there's now a new binding
+    # site... but we know (because it's the same FX node) that the value
+    # is actually the same, they're just not obviously equal anymore.
+    #
+    # The logic here is written carefully, because unlike the
+    # bind_unbacked case, we are not guaranteed to have a symbol for
+    # old_sym.  If we have a symbol, do regular rename unbacked to; but if
+    # we don't, we need to specially eliminate the fresh unbacked symbol
+    # (NB: we are /trusting/ that the memoization is correct, and that we
+    # don't need to generate a new runtime assert.  This is load bearing,
+    # as repropagation can happen after we've frozen runtime asserts.)
+    if old_example_value is not None:
+        for keypath in symbol_to_path.values():
+            old_sym = pytree.key_get(old_example_value, keypath)
+            new_sym = pytree.key_get(example_value, keypath)
+            if isinstance(new_sym, SymTypes) and isinstance(
+                new_s := new_sym.node.expr, sympy.Symbol
+            ):
+                if (
+                    isinstance(old_sym, SymTypes)
+                    and (old_s := old_sym.node.expr) != new_s
+                ):
+                    # If old_s is not an unbacked_symbol,
+                    # we assume that the original unbacked symbol is replaced
+                    # by a backed symbol (old_s). This can happen
+                    # when this node reuses the original symbol (due to memoi)
+                    # and the original symbol gets replaced by the backed symbol.
+                    # When this happens we just replace new_s by the old_s
+                    # because we know the value is the same.
+
+                    if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s):
+                        shape_env._rename_unbacked_to(new_s, old_s)
+                    else:
+                        shape_env._eliminate_unbacked(new_s, old_s)
+                elif not isinstance(old_sym, SymTypes):
+                    shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))
+
+    return symbol_to_path
+
+
+# Note [guard_or_]
+# The following two functions are common utilities used while defining unbacked semantics
+# of various framework code. Those would be used in situations you prefer to guard and know
+# the result of the expression over not guarding, but in case you hit a data dependent error
+# you are ok with just returning true or false.
+#
+# When to use this?
+# (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting).
+#
+# (2) It can be used if the program would behave equivalently if _guard_or returned true or false.
+# Many inductor optimizations fall in this bracket for example.
+#
+# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the
+# change is semantics preserving.  It can be semantics preserving if the program errors in more
+# cases than it did previously (but otherwise behaves identically), or if it changes some quantity
+# in a way that doesn't matter (e.g., strides often fall in this bucket.)
+#
+# (4) Specialize for the general case and add a runtime assertion that would fail during
+#     runtime if the conditions for the general case are not satisfied. Examples for this are;
+#      assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path.
+#
+def _guard_or(a: BoolLikeType, default: bool) -> bool:
+    """
+    Try to guard a, if data dependent error encountered just return default.
+    """
+    if not isinstance(a, SymBool):
+        assert isinstance(a, bool)
+        return a
+
+    # if backed_size_oblivious is True we treat backed as unbacked here.
+    if torch.fx.experimental._config.backed_size_oblivious:
+        result = _static_eval_sym_bool(a)
+        return result if result is not None else default
+
+    shape_env = getattr(a.node, "shape_env", None)
+
+    # xla symnode path.
+    if shape_env is None:
+        return guard_bool(a)
+
+    sym_node = a.node
+    r = sym_node.shape_env.evaluate_sym_node(
+        sym_node, size_oblivious=False, fallback_value=default
+    )
+    return bool(r)
+
+
+def guard_or_false(a: BoolLikeType) -> bool:
+    """
+    Try to guard a, if data dependent error encountered just return false.
+    """
+    return _guard_or(a, False)
+
+
+def guard_or_true(a: BoolLikeType) -> bool:
+    """
+    Try to guard a, if data dependent error encountered just return true.
+    """
+    return _guard_or(a, True)
+
+
+def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
+    assert isinstance(x, SymBool)
+    expr = x.node.expr
+
+    try:
+        # Shape env access is inside the try on purpose. xla symnode does not
+        # have it on its attributes.
+        shape_env = x.node.shape_env
+        simplified = shape_env._maybe_evaluate_static(expr)
+        if simplified is not None:
+            return bool(simplified)
+        else:
+            return None
+    except Exception:
+        log.debug("Could not simplify %s", expr)
+        return None
+
+
+def statically_known_false(x: BoolLikeType) -> bool:
+    """
+    Returns True if x can be simplified to a constant and is False.
+    If x cannot be evaluated from static, we return False
+
+    .. note::
+        This function doesn't introduce new guards, so the expression may end
+        up evaluating to False at runtime even if this function returns False.
+
+    Args:
+        x (bool, SymBool): The expression to try statically evaluating
+    """
+    if not isinstance(x, SymBool):
+        assert isinstance(x, bool)
+        return not x
+
+    result = _static_eval_sym_bool(x)
+    if result is None:
+        return False
+
+    return not result
+
+
+def statically_known_true(x: BoolLikeType) -> bool:
+    """
+    Returns True if x can be simplified to a constant and is true.
+
+    .. note::
+        This function doesn't introduce new guards, so the expression may end
+        up evaluating to true at runtime even if this function returns False.
+
+    Args:
+        x (bool, SymBool): The expression to try statically evaluating
+    """
+    if not isinstance(x, SymBool):
+        assert isinstance(x, bool)
+        return x
+    result = _static_eval_sym_bool(x)
+    if result is None:
+        return False
+
+    return result
+
+
+def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
+    """
+    and, but for symbolic expressions, without bool casting.
+    """
+    if len(others) == 0:
+        return x
+    for y in others:
+        x = operator.and_(x, y)
+    return x
+
+
+def sym_eq(x: _T, y: _T) -> BoolLikeType:
+    """
+    Like ==, but when run on list/tuple, it will recursively test equality
+    and use sym_and to join the results together, without guarding.
+    """
+    if isinstance(x, (tuple, list)) and isinstance(y, (list, tuple)):
+        if len(x) != len(y):
+            return False
+        return functools.reduce(operator.and_, map(sym_eq, x, y), True)
+    elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
+        return x == y
+    else:
+        raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
+
+
+def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
+    """
+    or, but for symbolic expressions, without bool casting.
+    """
+    if len(others) == 0:
+        return x
+    for y in others:
+        x = operator.or_(x, y)
+    return x
+
+
+def guard_scalar(
+    a: Union[SymBool, SymInt, SymFloat, int, bool, float],
+) -> Union[bool, int, float]:
+    """
+    Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float.
+
+    This function dispatches to the appropriate guard function based on the type of the input.
+
+    Args:
+        a: A symbolic or concrete scalar value (bool, int, or float)
+
+    Returns:
+        The concrete value after guarding
+
+    Raises:
+        AssertionError: If the input is not a recognized scalar type
+    """
+    if isinstance(a, (SymBool, bool)):
+        return guard_bool(a)
+    elif isinstance(a, (SymInt, int)):
+        return guard_int(a)
+    elif isinstance(a, (SymFloat, float)):
+        return guard_float(a)
+    else:
+        raise AssertionError(f"unrecognized scalar {a}")
+
+
+def _advise_is_size(a: SymInt) -> None:
+    """
+    Don't use this directly; use torch._check_is_size instead.
+
+    This is a softer version of _constrain_range_for_size (with min=0,
+    max=Inf).  Instead of forcibly constraining a variable (and erroring if we
+    failed to constrain it), it will simply advise us that a size is
+    constrained in some way.  We will always defer a runtime assert for this
+    constraint if we cannot prove it at compile-time, but we we only
+    *sometimes* learn useful extra information at compile-time with this
+    information.  This is in contrast to constrain_range_for_size, where if
+    you don't call that on a fresh unbacked symint, chances are we will choke.
+
+    TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
+    code.  Right now this is only really used in code with AOTAutograd trace
+    through, so it is not a big problem that this isn't supported, but in
+    principle all of this code should be Dynamo'able too.
+
+    TODO: I didn't support min/max because I didn't have a use case where this
+    actually helped.  In principle we can support it, it just makes the
+    implementation below more complicated.
+    """
+
+    # This must always succeed, because the sole allowed caller _check_is_size
+    # was responsible for expect_true'ing this
+    # This assert triggers expensive sym compute, do not do it until its cheap.
+    # assert a >= 0
+
+    # NB: it's important not to constrain range for size for *hinted* SymInts,
+    # because it is not only unsound, it will immediately trip our asserts
+    # that hints have to be consistent with static analysis!  If you somehow
+    # have an unbounded SymInt that later constrains to 1, this will be
+    # inconsistent with the range
+    if (
+        isinstance(a, SymInt)
+        and isinstance(a.node, SymNode)
+        and isinstance(a.node.expr, sympy.Symbol)
+        and a.node.shape_env.is_unbacked_symint(a.node.expr)
+    ):
+        _constrain_range_for_size(a)
+
+
+def _advise_is_bounded(a: SymInt, upper_bound: IntLikeType) -> None:
+    if (
+        isinstance(a, SymInt)
+        and isinstance(a.node, SymNode)
+        and isinstance(a.node.expr, sympy.Symbol)
+        and a.node.shape_env.is_unbacked_symint(a.node.expr)
+        and isinstance(upper_bound, int)  # TODO: relax
+    ):
+        a.node.shape_env._constrain_is_bounded(a.node.expr, upper_bound)
+
+
+def _constrain_range_for_size(
+    a: SymInt, min: Optional[int] = None, max: Optional[int] = None
+) -> None:
+    """
+    This function is NOT INTENDED to be used by itself.
+    """
+
+    if isinstance(a, (SymFloat, SymBool)):
+        raise ValueError("Constraining SymFloat/SymBool is nyi")
+
+    assert isinstance(a, SymInt), "can only constrain range for SymInt"
+    assert isinstance(a.node.expr, sympy.Symbol), f"constraining non-Symbols NYI: {a}"
+
+    a.node.shape_env._constrain_range_for_size(a.node.expr, min, max)
+
+
+# inclusive both ways
+def constrain_range(
+    a: SymInt, *, min: Optional[int], max: Optional[int] = None
+) -> None:
+    """
+    Applies a constraint that the passed in SymInt must lie between min-max
+    inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning
+    that it can be used on unbacked SymInts).  If min/max are None, we assume
+    that the dimension is unbounded in that direction.  Repeated application
+    of constrain_range intersects the ranges.  This is a fairly low level API
+    that doesn't have a lot of safety guarantees (TODO: provide higher level
+    APIs).
+
+    Currently, we use this API in the following circumstance: when we allocate
+    an unbacked SymInt, denoting an integer quantity which is data dependent,
+    we ordinarily do not know anything about what values it may take.  This
+    means that any sort of guard on it will immediately fail.  However, in
+    many cases, we know something about the unbacked SymInt: for example, we
+    know that nonzero(x).size(0) must be >= 0.  We use constrain_range to
+    narrow the possible range, declaring that negative symbols are impossible.
+    This permits to definitely answer True to queries like 'nnz >= 0', even if
+    we don't know what the actual (hinted) value of 'nnz' is.  In fact, we
+    actually use constrain_range to unsoundly discharge common guards: for an
+    unbacked SymInt produced by nonzero, we will also assume that it is not
+    equal to 0/1 (even though these are perfectly possible values at runtime),
+    because we generally expect graphs that are valid for N=2 to also be valid
+    for N=1.
+    """
+    if min is None:
+        min = -int_oo
+    if max is None:
+        max = int_oo
+
+    if max < min:
+        raise ValueError(
+            "Maximum value to constrain_as_size can't be less than the specified min value, "
+            f"received min={min} and max={max}"
+        )
+
+    if isinstance(a, int):
+        if not (min <= a <= max):
+            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
+        return
+
+    a.node.shape_env._constrain_range(a.node.expr, min, max)
+
+
+def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None:
+    """
+    Given two SymInts, constrain them so that they must be equal.  NB:
+    this will not work with SymInts that represent nontrivial expressions
+    (yet!)
+    """
+    if not isinstance(a, SymInt):
+        if not isinstance(b, SymInt):
+            assert a == b
+            return
+        else:
+            shape_env = b.node.shape_env
+    else:
+        shape_env = a.node.shape_env
+
+    shape_env._constrain_unify(a, b)
+
+
+# Assume that a boolean is true for the purposes of subsequent symbolic
+# reasoning.  This will keep track of corresponding runtime checks to verify
+# that the result is upheld: either as a regular guard, or as a special set
+# of asserts which are triggered when an unbacked SymInt is allocated.
+#
+# DO NOT use this function for these cases:
+#
+#  - This is inappropriate for "branching" conditions (where both
+#    true and false result in valid programs).  We will always assume
+#    the condition evaluates true, and so it will never be possible
+#    to trace the false condition when you use it.  For true branching
+#    on unbacked SymInts, you must use torch.cond; if you incorrectly
+#    use expect_true in this case, you will make the false branch
+#    unreachable (as we will simply assume that only the true branch
+#    is ever exercised).
+#
+#  - This is inappropriate for situations where you know some other system
+#    invariant guarantees that this property holds, since you don't
+#    really need to insert a runtime check in that case.  Use something
+#    like constrain_range in that case.
+#
+# This API has a hitch.  To avoid having to reimplement error reporting
+# capabilities, this function CAN return False.  The invariant is that
+# the surrounding code must raise an error when this function returns
+# False.  This is quite low level, so we recommend using other functions
+# like check() which enforce this in a more intuitive way.
+#
+# By the way, this name is a nod to the __builtin_expect macro,
+# which is used similarly (but unlike __builtin_expect, you MUST fail
+# in the unlikely branch.)  (I think expect is a good name; in recent
+# versions of C++, this is replaced with [[likely]], which is weaker
+# and not accurate for this function!)
+def expect_true(a: BoolLikeType, skip: int = 0) -> bool:
+    if isinstance(a, SymBool):
+        # TODO: check perf implications of this
+        frame = inspect.currentframe()
+        for _ in range(skip + 1):  # always run this loop at least once
+            if frame is None:
+                break
+            frame = frame.f_back
+        return a.node.expect_true(
+            frame.f_code.co_filename if frame else "", frame.f_lineno if frame else 0
+        )
+    assert type(a) is bool, a
+    return a
+
+
+def guard_bool(a: BoolLikeType) -> bool:
+    if isinstance(a, SymBool):
+        return a.node.guard_bool("", 0)  # NB: uses Python backtrace
+    assert type(a) is bool, a
+    return a
+
+
+def guard_int(a: IntLikeType) -> int:
+    if isinstance(a, SymInt):
+        return a.node.guard_int("", 0)  # NB: uses Python backtrace
+    assert type(a) is int, a
+    return a
+
+
+def guard_float(a: FloatLikeType) -> float:
+    if isinstance(a, SymFloat):
+        return a.node.guard_float("", 0)  # NB: uses Python backtrace
+    assert isinstance(a, float), a
+    return a
+
+
+# Given a GraphModule, return all the FakeTensors for all the placeholders
+def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]:
+    return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"]
+
+
+def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]:
+    return [n.target for n in gm.graph.nodes if n.op == "placeholder"]
+
+
+# Given a GraphModule and arguments to run it with, evaluate that the guards
+# for its associated ShapeEnv are satisfied by the passed arguments.  This
+# WILL check for duck sizing.
+def eval_guards(
+    gm: torch.fx.GraphModule, *args: Tensor, ignore_static: bool = True
+) -> bool:
+    assert gm.shape_env is not None
+    return gm.shape_env.evaluate_guards_for_args(  # type: ignore[operator, union-attr]
+        fx_placeholder_vals(gm), args, ignore_static=ignore_static
+    )
+
+
+def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]:
+    assert gm.shape_env is not None
+    return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)  # type: ignore[operator, union-attr]
+
+
+class DimDynamic(Enum):
+    """
+    Controls how to perform symbol allocation for a dimension.  It is always
+    sound to default this to DYNAMIC, but the policies DUCK and STATIC can
+    result in better trace-time and compile-time performance, as they reduce
+    the number of allocated symbols and generally make your graph more static.
+
+    NB: If we notice you've applied a constraint to the dimension, we will
+    force it to DYNAMIC for simplicity.
+
+    DimDynamic is controlled by a variety of higher level UX features.
+    Currently:
+
+    - In eager mode, the default policy is DUCK.
+        - The default is changed to STATIC with assume_static_by_default.
+        - An individual dim is marked DYNAMIC if you mark_dynamic_dim.
+    - In export mode, the default policy is STATIC.
+        - An individual dim is marked DYNAMIC if you specify it in
+          dynamic_shapes passed to export.
+    """
+
+    # Treat the dimension symbolically
+    DYNAMIC = 0
+    # Treat the dimension symbolically, but if its hint matches another
+    # dynamic dimension, unify the two symbols ("duck sizing")
+    DUCK = 1
+    # Treat the dimension statically based on its hint
+    STATIC = 2
+    # Treat the dimension as a size-like unbacked
+    SIZE_LIKE_UNBACKED = 3
+    # Infer the strides from stride. If size is static, strides will be static as well.
+    INFER_STRIDE = 4
+    # Like SIZE_LIKE_UNBACKED, but there's a hint
+    OBLIVIOUS_SIZE = 5
+
+
+# NB: These constraints affect both clients and backends: given some
+# constraint C, the client must pass inputs that satisfy the constraint,
+# while a backend must not introduce guards BEYOND this constraint.
+# For clarity, we document the implications on both sides for both the client
+# and the backend.
+#
+# NB: These constraints are on a *single* dimension.  In principle, we could
+# also have multi-dimension constraints, but our guess is that this is not
+# actually useful and so we are not supporting it right now.
+#
+# NB: Strict constraints are typically only suitable for export, as in eager
+# a backend like inductor may validly introduce extra, discretionary guards
+# to improve performance of code.  A StrictMinMaxConstraint would be brittle
+# under future optimizations performed by inductor; we don't guarantee
+# eager code with StrictMinMaxConstraint will keep working in the future!
+
+
+@dataclass(frozen=True)
+class Constraint:
+    warn_only: bool
+
+
+@dataclass(frozen=True)
+class StrictMinMaxConstraint(Constraint):
+    """
+    For clients: the size at this dimension must be within 'vr' (which
+    specifies a lower and upper bound, inclusive-inclusive) AND it
+    must be non-negative and should not be 0 or 1 (but see NB below).
+
+    For backends: there must not be any guards on this dimension which
+    are not implied by the given lower and upper bound.  Regardless of
+    the lower bound, the backend can assume the size is non-negative
+    and that it is not 0 or 1.
+
+    An unbounded StrictMinMaxConstraint can be thought of as a strict version
+    of "RelaxedUnspecConstraint".
+
+    NB: Export will often unsoundly assume that a graph works for 0/1, even
+    though at trace time we assumed size is not 0 or 1.  The idea is that
+    if we produce a graph that works for a range of values, it will be OK
+    for N=0/1 too.
+    """
+
+    vr: ValueRanges
+
+    def render(self, source: Source) -> str:
+        """Format the constrain equation"""
+        # TODO: better printing for -oo and oo
+        return f"{self.vr.lower} <= {source.name} <= {self.vr.upper}"
+
+
+@dataclass(frozen=True)
+class RelaxedUnspecConstraint(Constraint):
+    """
+    For clients: no explicit constraint; constraint is whatever is implicitly
+    inferred by guards from tracing.
+
+    For backends: there must exist at least TWO possible values for the
+    size at this dimension which satisfy the guards for this dimension.
+
+    In other words, this constraint helps us distinguish between "we don't
+    care if this dimension specializes or not" versus "this dimension must be
+    unspecialized."  However, this constraint doesn't say very much about what
+    specialization is permitted; for example, if we guard on a size being
+    even, this would still be acceptable under an unspec constraint.  This
+    makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler
+    may add constraints to otherwise dynamic dimensions; we can't assert that
+    there are NO guards as this is brittle because compilers should be able to
+    add extra constraints.  If you want to assert that there are no guards,
+    use StrictMinMaxConstraint with an unbounded ValueRanges.
+    """
+
+    def render(self, source: Source) -> str:
+        return f"RelaxedUnspecConstraint({source.name})"
+
+
+# NB: None here indicates the client constraint is whatever is implicitly
+# inferred by guards from tracing, and that a backend can add whatever guards
+# it wants (including fully specializing the value).
+DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None]
+
+
+@dataclass(frozen=True)
+class EqualityConstraint(Constraint):
+    """
+    Represent and decide various kinds of equality constraints between input sources.
+
+    A "source pair" is a pair of input sources for dynamic dimensions that
+    are specified equal. We represent `source_pairs` in a union-find forest
+    so that we can efficiently check whether two such sources are transitively equal.
+
+    A "derived equality" relates an input source to an expression over a root.
+    The root can be another input source, corresponding to some dynamic dimension,
+    or a phantom symbol that does not directly represent any dynamic dimension. We
+    represent `derived_equalities` involving input sources in a transitively-closed map
+    so that we can efficiently check whether an input source is transitively equal to
+    a given expression over another input source.
+    (NOTE: In contrast, it is easy to decide whether an input source is transitively equal
+    to a given expression over a phantom symbol; such expressions are already in canonical
+    form and so the problem reduces to symbolic expression equality.)
+    """
+
+    source_pairs: list[tuple[Source, Source]]
+    derived_equalities: list[
+        tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]
+    ]
+    phantom_symbols: list[sympy.Symbol]
+    relaxed_sources: set[Source]
+
+    _parents: dict[Source, Source] = field(init=False)
+    _defs: dict[Source, sympy.Expr] = field(init=False)
+
+    def __post_init__(self) -> None:
+        """
+        Pre-processing to answer queries `is_equal` and `is_derived` below.
+
+        Example: Suppose we are given:
+          source_pairs [a = b, b = c]
+          derived_equalities [d = c + 1, e = d - 1]
+        We first construct a union find with source_pairs:
+          _parents = {a: a, b: a, c: a}
+        Then we compute canonical symbolic expressions, recursively applying derived_equalities
+        until we bottom out:
+          _defs = {d: c + 1, e: (c + 1) - 1 aka c}
+        """
+
+        # self._parents is a map from input sources to input sources where, conceptually,
+        # these are directed edges in a union-find forest
+        _parents: dict[Source, Source] = {}
+        object.__setattr__(self, "_parents", _parents)
+        # self._defs is a map from input sources to "canonical" symbolic expressions,
+        # i.e., unary expressions with symbols that corresponds to regular Dims (i.e.,
+        # not derived Dims)
+        _defs: dict[Source, sympy.Expr] = {}
+        object.__setattr__(self, "_defs", _defs)
+
+        for source1, source2 in self.source_pairs:
+            # preprocess into a union-find forest
+            self._union(self._find(source1), self._find(source2))
+        for source, root, fn in self.derived_equalities:
+            # preprocess into a transitively-closed map
+            # NOTE(avik): we reuse the union-find forest for canonicalizing input sources
+            if isinstance(root, (sympy.Symbol, sympy.Integer)):
+                self._defs[self._find(source)] = fn(root)
+            else:
+                self._defs[self._find(source)] = fn(self._rewrite(root))
+
+    def _find(self, source: Source) -> Source:
+        # chase edges to find the root of this equivalence class
+        if source in self._parents:
+            return self._find(self._parents[source])
+        else:
+            return source
+
+    def _union(self, root1: Source, root2: Source) -> None:
+        # merge two equivalence classes by adding an edge from one root to the other
+        if root1 != root2:
+            self._parents[root1] = root2
+
+    def _rewrite(self, src: Source) -> sympy.Expr:
+        # always represent the given source by the root of its equivalence class
+        src = self._find(src)
+        if src in self._defs:
+            # simply look up the definition if it exists
+            # NOTE(avik): This works because definitions are always transitively-closed;
+            # otherwise we would have to do recursive rewriting.
+            return self._defs[src]
+        else:
+            # otherwise, create a symbol representing the source
+            return sympy.Symbol(src.name)
+
+    def is_equal(self, source1: Source, source2: Source) -> bool:
+        return (
+            # check whether source1 and source2 have the same root
+            # or are relaxed
+            (src1 := self._find(source1)) in self.relaxed_sources
+            or (src2 := self._find(source2)) in self.relaxed_sources
+            or src1 == src2
+            # check whether source1 is derived equal to source2
+            or self.is_derived(source1, source2, lambda x: x)
+        )
+
+    def is_derived(
+        self, src: Source, symbol_src: Source, fn: Callable[[sympy.Expr], sympy.Expr]
+    ) -> bool:
+        # check whether both src and symbol_src have the same definition
+        return self._rewrite(src) == fn(self._rewrite(symbol_src))
+
+
+def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]:
+    assert isinstance(symbolic_context, SymbolicContext), (
+        "Invalid symbolic_context object"
+    )
+    assert type(symbolic_context) is not SymbolicContext, (
+        "Illegal usage of symbolic_context ABC"
+    )
+    return True
+
+
+def _is_supported_equivalence(expr: sympy.Expr) -> bool:
+    # Currently supported Dim ops are linear expressions with integer coefficients.
+    # So check that expr only contains +, *, ints, and a single occurrence of a symbol.
+    # (See also documentation of dynamic_shapes._DerivedDim.)
+    if isinstance(expr, (sympy.Add, sympy.Mul)):
+        if len(expr.args) > 2:
+            return False
+        lhs, rhs = expr.args
+        return (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or (
+            isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs)
+        )
+    return isinstance(expr, sympy.Symbol)
+
+
+def _has_uninterpretable_sympy_function(expr: sympy.Basic) -> bool:
+    """
+    Add functions that our sympy interpreter can't reify into FX nodes
+    """
+    return expr.has(
+        torch.utils._sympy.functions.ToFloat,
+        torch.utils._sympy.functions.TruncToInt,
+        torch.utils._sympy.functions.CeilToInt,
+    )
+
+
+@dataclass(frozen=True)
+class SymbolicContext:
+    """
+    Data structure specifying how we should create symbols in
+    ``create_symbolic_sizes_strides_storage_offset``; e.g., should
+    they be static or dynamic.
+
+    This is an abstract base class because we are probably going to add
+    another version of this that says "use exactly these SymInts, don't
+    allocate fresh symbols."
+    """
+
+
+@dataclass(frozen=True)
+class SymIntSymbolicContext(SymbolicContext):
+    """
+    Data structure specifying any constraints on a SymInt input
+    """
+
+    constraint: DimConstraint
+
+
+_P1 = ParamSpec("_P1")
+_T1 = TypeVar("_T1")
+
+
+@dataclass(frozen=True)
+class StatelessSymbolicContext(SymbolicContext, Generic[_P1, _T1]):
+    """
+    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
+    a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``.
+    This will cause fresh symbols to be allocated
+    """
+
+    dynamic_sizes: DimList[DimDynamic]
+    dynamic_strides: DimList[DimDynamic] = None  # type: ignore[assignment]
+    constraint_sizes: DimList[DimConstraint] = None  # type: ignore[assignment]
+    constraint_strides: DimList[DimConstraint] = None  # type: ignore[assignment]
+    specialize_on: Optional[list[list[Callable[_P1, _T1]]]] = None
+    # If the tensor is a view, this should be populated for the base. It contains
+    # information on how to allocate symbols when recursively fakeifying the base
+    # during view fake-ification.
+    view_base_context: Optional[SymbolicContext] = None
+    # TODO: add storage offset and stride symbolic_context
+
+    def __post_init__(self) -> None:
+        if self.specialize_on is None:
+            object.__setattr__(
+                self,
+                "specialize_on",
+                [[]] * len(self.dynamic_sizes),
+            )
+        if self.dynamic_strides is None:
+            object.__setattr__(
+                self,
+                "dynamic_strides",
+                [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes),
+            )
+        if self.constraint_sizes is None:
+            object.__setattr__(
+                self, "constraint_sizes", [None] * len(self.dynamic_sizes)
+            )
+        if self.constraint_strides is None:
+            object.__setattr__(
+                self, "constraint_strides", [None] * len(self.dynamic_sizes)
+            )
+        assert all(
+            stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK)
+            for stride in self.dynamic_strides
+        )
+
+
+# note [Tensor Fakification and Symbol Caching]
+#
+# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
+# The reason we do this is because there are certain classes of operations, namely,
+# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor
+# state at the end of a dynamo trace is different than the fake tensor state at the beginning
+# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation,
+# view relationships, etc.
+#
+# As we create a new fake mode, we also lose the memoization that comes with it. Rather than
+# transfer the memoization cache, we instead transfer the shape env. However, with this
+# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in
+# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across
+# recompilations.
+#
+# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass
+# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext.
+# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is
+# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors
+# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env
+# is used.
+# TODO(voz): Shape env validation
+@dataclass(frozen=True)
+class StatefulSymbolicContext(StatelessSymbolicContext):
+    """
+    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
+    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
+    will reuse a stored symbol, and a cache miss will write to this cache.
+
+    This behaves like StatelessSymbolicContext, except the cache supersedes the
+    other values - dynamic_sizes and constraint_sizes will not be read if we cache
+    hit.
+
+    It is the cache owner's responsibility to maintain the lifecycle of the cache
+    with respect to different shape_envs, clearing, etc.
+    """
+
+    tensor_source: Source = None  # type: ignore[assignment]
+    # Why is this keyed on int first?
+    # That integer is actually the id of the shape_env. This cache short-circuits symbol
+    # creation, and we must store it per shape env. Now, while tracing invariants are a single
+    # shape env per tracing context, and every new frame gets a new shape_env. So where would we have
+    # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events
+    # is invoked, and creates a new shape_env. Replaying events against this new shape_env will
+    # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
+    # get recorded in var_to_val, etc.
+    # TODO(voz): consider a weakref to the shape_env here
+    shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None  # type: ignore[assignment]
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        # The None default is annoying, but required because of dataclass limitations
+        assert self.tensor_source is not None
+        if not self.shape_env_to_source_to_symbol_cache:
+            object.__setattr__(self, "shape_env_to_source_to_symbol_cache", {})
+
+
+@dataclass(frozen=True)
+class SubclassSymbolicContext(StatefulSymbolicContext):
+    """
+    The correct symbolic context for a given inner tensor of a traceable tensor subclass
+    may differ from that of the outer symbolic context. This structure allows for this
+    flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
+    """
+
+    inner_contexts: dict[str, SymbolicContext] = None  # type: ignore[assignment]
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        if self.inner_contexts is None:
+            # pyrefly: ignore [bad-assignment]
+            self.inner_contexts = {}
+
+
+@dataclass
+class TrackedFake:
+    """
+    Tracks the sources of all fake tensors we wrap in Dynamo.
+    Used by shape guard computation.
+    """
+
+    fake: Union[FakeTensor, SymInt]
+    source: Source
+    symbolic_context: Optional[SymbolicContext]
+
+    def __hash__(self) -> int:
+        return hash((self.fake, self.source.name))
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, TrackedFake):
+            return self.fake is other.fake and self.source.name == other.source.name
+        return False
+
+
+def is_symbolic(
+    val: Union[int, SymInt, float, SymFloat, bool, SymBool],
+) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]:
+    if isinstance(val, (int, float, bool)):
+        return False
+    return val.node.is_symbolic()
+
+
+IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
+
+
+def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]:
+    """
+    Expand products of sums into sums of products.
+
+    This function takes a list of sympy expressions and separates them into
+    additive expressions (those with is_Add=True) and other expressions.
+    It then computes the distributive product, expanding (a+b)*(c+d) into a*c + a*d + b*c + b*d.
+
+    Args:
+        args: A list of sympy expressions to expand
+
+    Returns:
+        A tuple containing:
+        - The expanded expression as a sympy.Expr
+        - A boolean indicating whether expansion occurred (True if multiple additive
+          expressions were present or if there was at least one additive and one other expression)
+    """
+    adds, other = [], []
+    for arg in args:
+        if arg.is_Add:
+            adds.append(arg)
+        else:
+            other.append(arg)
+
+    result = [sympy.Mul(*other)]
+    for add in adds:
+        result = [a * b for a, b in itertools.product(result, add.args)]
+
+    result = sympy.Add(*result)
+    return result, len(adds) > 1 or (len(adds) > 0 and len(other) > 0)
+
+
+def _fast_expand(expr: _SympyT) -> _SympyT:
+    """
+    A faster implementation of sympy's expand function for common cases.
+
+    This function expands expressions like (a+b)^n or (a+b)*(c+d) into sums of products,
+    but avoids the expensive checks and features of sympy's full expand implementation.
+    It only recreates objects when necessary to avoid expensive operations.
+
+    Args:
+        expr: A sympy expression to expand
+
+    Returns:
+        The expanded expression
+    """
+
+    # The expand algorithm in sympy is slow due to all the features is supports
+    # For eg: e^(-x)*(x-1)/(x+1) is expanded to (x-1)/(e^x + e^x*x) if x is
+    # positive and (e^(-x)*x-e^(-x))/(x+1) if x is negative. We do not implement
+    # such features here to avoid expensive checks. We also make sure that we
+    # only re-create the objects if any of the args changed to avoid expensive
+    # checks when re-creating objects.
+    new_args = [_fast_expand(arg) for arg in expr.args]  # type: ignore[arg-type]
+    # pyrefly: ignore [missing-attribute]
+    if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
+        # pyrefly: ignore [missing-attribute]
+        return _fast_expand(expr.func(*new_args))
+
+    # pyrefly: ignore [missing-attribute]
+    if expr.is_Pow:
+        base: sympy.Expr
+        exp: sympy.Expr
+        base, exp = expr.args  # type: ignore[assignment]
+        if exp.is_Integer and base.is_Add:
+            if exp > 1:
+                return sympy.expand_multinomial(expr, deep=False)
+            elif exp < 0:
+                return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
+    # pyrefly: ignore [missing-attribute]
+    elif expr.is_Mul:
+        num: list[sympy.Expr] = []
+        den: list[sympy.Expr] = []
+        # pyrefly: ignore [missing-attribute]
+        for arg in expr.args:
+            if arg.is_Pow and arg.args[1] == -1:
+                den.append(S.One / arg)  # type: ignore[operator, arg-type]
+            else:
+                num.append(arg)  # type: ignore[arg-type]
+
+        num, num_changed = _expandsums(num)
+        den, den_changed = _expandsums(den)
+        if num_changed or den_changed:
+            return num / den
+
+    return expr
+
+
+@lru_cache(256)
+def safe_expand(r: _SympyT) -> _SympyT:
+    """
+    Expand the given symbolic expression by recursively rewriting product of
+    sums into sum of products (with the product being either a multiplication or
+    exponentiation).
+
+    NOTE: using this on an intermediate expression may prevent simplification
+    down the line, e.g., if we eagerly expand `(a + b)^2` into `a^2 + 2ab + b^2`,
+    we won't be able to simplify `(a^2 + 2ab + b^2) / (a + b)` as easily.
+    """
+    if hasattr(r, "expand"):
+        try:
+            return _fast_expand(r)
+        except RecursionError:
+            log.warning("RecursionError in _fast_expand(%s)", r)
+            return r
+    else:
+        return r
+
+
+class _SymbolInfo(NamedTuple):
+    k: sympy.Symbol
+    vr: Optional[ValueRanges]
+    val: Optional[sympy.Integer]
+    is_size_like: bool
+
+
+@lru_cache(None)
+def _maybe_evaluate_static_worker(
+    expr: _SympyT,
+    # NB: this is a tuple to ensure it can be LRU cached
+    symbol_info: tuple[_SymbolInfo, ...],
+    unbacked_only: bool,
+    size_oblivious: bool,
+) -> Optional[_SympyT]:
+    """
+    This variant of ShapeEnv._maybe_evaluate_static has no dependence on
+    ShapeEnv and thus can be cached indefinitely.  It does the "heavy" lifting
+    for static evaluation, including nontrivial reliance on Sympy simplification
+    that occurs when we reallocate the symbols
+    """
+
+    # Simplify making use of value range lower bound
+    new_shape_env = {}
+    new_range_env = {}
+    for idx, sinfo in enumerate(symbol_info):
+        k, vr, val, is_size_like = sinfo
+        if isinstance(val, SingletonInt):
+            # Skip var_ranges logic for SingletonInt which is only used
+            # for jagged layout NestedTensors today
+            continue
+        assert vr is not None
+        if size_oblivious and is_size_like:
+            lower = max(2, vr.lower)
+            # Clamping size-oblivious to some quantity below sys.maxsize
+            # helps us determine that f(u0) != sys.maxsize, which is a
+            # test that is looking for sys.maxsize as a sentinel, but you
+            # don't really want to worry about it for unbacked SymInts.
+            # This is similar to the flavor where size oblivious omits
+            # 0/1, it changes semantics but in a benign way.
+            upper = min(2**48, vr.upper)
+            # Excluding the very upper bound can be helpful
+            if upper > lower:
+                upper = upper - 1
+            # This is a bit dodgy: what this means is that there was a
+            # size-like unbacked symbol whose upper bound < 2.  This
+            # causes... problems.
+            if lower <= upper:
+                vr = ValueRanges(lower, upper)
+        else:
+            lower = vr.lower
+        # Don't do anything if we don't have a nontrivial lower bound
+        # Also don't do anything if we asked only to simplify unbacked
+        # SymInt
+        if lower is -int_oo or (unbacked_only and val is not None) or not vr.is_int:
+            new_range_env[k] = vr
+            continue
+        # The goal is to take our symbols which have various lower bounds
+        # and reallocate them into new symbols which are exactly positive;
+        # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
+        # [1, inf], where s0 = ess0 + 1.  This gives the most information
+        # to sympy for subsequent simplifications.
+        #
+        # Positive means >= 1
+        # Positive - 1 means >= 0
+        # Positive + lower - 1 means >= lower
+        # The new symbol 's' is "too low", so when we substitute it in
+        # we have to increase it by offset (and conversely, the new
+        # variables have to have their value range bounds adjusted as
+        # well)
+        s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)
+
+        # Note:
+        #   Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
+        #   Sympy might give unexpected results when comparing an integer with a non-integer
+        #   Therefore, we cast offset to int here.
+        #   For example:
+        #       shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
+        #       expr = sympy.Eq(shape_0 - 1/3, 4)
+        #       expr.xreplace({}) # False
+        offset = int(lower - 1)
+        new_shape_env[k] = s + offset
+        new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)
+
+    # TODO: remove this try catch (esp for unbacked_only)
+    try:
+        # pyrefly: ignore [missing-attribute]
+        new_expr = expr.xreplace(new_shape_env)
+    except RecursionError:
+        log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
+        return None
+
+    # We need to canonicalize, as after expand we may have something like `a + b = a` and
+    # sympy will not simplify the a. The two appeareances of the a will then make value ranges
+    # analysis give lose bounds
+    new_expr = canonicalize_bool_expr(safe_expand(new_expr))
+    if new_expr.is_number:
+        return new_expr
+
+    # Check if the range can solve it statically
+    out = bound_sympy(new_expr, new_range_env)
+    if out.is_singleton():
+        return out.lower
+
+    return new_expr if unbacked_only else None
+
+
+def error() -> NoReturn:
+    raise AssertionError("shouldn't be hit")
+
+
+# TODO: Deduplicate this with torch/_prims_common/__init__.py
+def eval_is_non_overlapping_and_dense(
+    sizes: Sequence[int], strides: Sequence[int]
+) -> int:
+    return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides)))
+
+
+def _eval_is_non_overlapping_and_dense(
+    sizes: Sequence[int], strides: Sequence[int]
+) -> bool:
+    """
+    Evaluates whether a tensor with the given sizes and strides is non-overlapping and dense.
+
+    A tensor is non-overlapping if there's no memory location that belongs to more than one element.
+    A tensor is dense if all elements are stored in memory without gaps.
+
+    Args:
+        sizes: Sequence of dimension sizes for the tensor
+        strides: Sequence of strides for the tensor
+
+    Returns:
+        True if the tensor is non-overlapping and dense, False otherwise
+    """
+    dim = len(sizes)
+
+    # Short-circuits for tensors of rank one, which are
+    # non-overlapping and "dense" if their stride is one
+    # or it is a 0/1 element tensor
+    if dim == 1:
+        return strides[0] == 1 or sizes[0] < 2
+
+    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
+    # Sorts (length, stride) pairs by stride
+    lengths_and_strides = sorted(zip(sizes, strides), key=operator.itemgetter(1))
+
+    # Unlike the C++ code, we don't move the 0/1 size dimensions to the
+    # end.  So we have to keep going for this code.
+    expected_stride = 1
+    for length, stride in lengths_and_strides:
+        if length == 1:
+            continue
+
+        if stride != expected_stride:
+            return False
+
+        expected_stride *= length
+
+    return True
+
+
+def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr:
+    return sympy.Piecewise((1, x), (0, True))
+
+
+def cast_symbool_to_symint_guardless(
+    symbool: Union[bool, torch.SymBool],
+) -> Union[int, torch.SymInt]:
+    """
+    Converts a SymBool or bool to a SymInt or int without introducing guards.
+
+    This function maps True to 1 and False to 0, preserving the symbolic nature
+    of the input when it's a SymBool. Unlike regular casting which might introduce
+    guards, this function performs the conversion without adding any guards.
+
+    Args:
+        symbool: A boolean value, either a concrete bool or symbolic SymBool
+
+    Returns:
+        The corresponding integer value (1 for True, 0 for False) as either
+        a concrete int or symbolic SymInt
+    """
+    if isinstance(symbool, bool):
+        return 1 if symbool else 0
+    int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr)
+    return symbool.node.shape_env.create_symintnode(
+        int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None
+    )
+
+
+SYMPY_INTERP = {
+    "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
+    "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
+    "math": math,
+    "torch": torch,
+}
+
+
+def _lru_cache(
+    fn: Callable[..., _T], maxsize: Optional[int] = None
+) -> functools._lru_cache_wrapper[_T]:
+    """
+    Wrapper around lru_cache that clears when new info about shapes has been
+    updated.
+
+    Use lru_cache if the output is always the same, regardless of the
+    constraints we know now (i.e. evaluate_expr)
+
+    Use _lru_cache otherwise.
+
+    Also note that this depends on _update_version_counter being called on the
+    shape environment whenever the constraints are updated, otherwise the cache
+    will not be cleared.
+    """
+    fn_cache = lru_cache(maxsize)(fn)
+    prior_version = 0
+
+    if config.validate_shape_env_version_key:
+        prior_key = None
+
+        @functools.wraps(fn)
+        def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T:
+            nonlocal prior_version, prior_key
+            if prior_key is None:
+                prior_key = self._get_key()
+
+            if prior_version != self._version_counter:
+                fn_cache.cache_clear()
+                prior_version = self._version_counter
+                prior_key = self._get_key()
+            else:
+                assert prior_key == self._get_key(), (
+                    "ShapeEnv cache key changed without version being updated!"
+                )
+
+            return fn_cache(self, *args, **kwargs)
+
+    else:
+
+        @functools.wraps(fn)
+        def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T:  # type: ignore[misc]
+            nonlocal prior_version
+            if prior_version != self._version_counter:
+                fn_cache.cache_clear()
+                prior_version = self._version_counter
+
+            return fn_cache(self, *args, **kwargs)
+
+    wrapper.cache_clear = fn_cache.cache_clear  # type: ignore[attr-defined]
+    wrapper.cache_info = fn_cache.cache_info  # type: ignore[attr-defined]
+    return wrapper  # type: ignore[return-value]
+
+
+@dataclass(frozen=True)
+class RuntimeAssert:
+    """
+    This is pretty similar to ShapeGuard but it also comes with a message,
+    and is exclusively used for things that MUST be true (unlike guards,
+    which can evaluate False, in which case you just choose not to use
+    a particular specialization)
+    """
+
+    expr: SympyBoolean
+    msg: str = field(repr=False)
+    stack: CapturedTraceback = field(repr=False)
+
+
+# Used for printing SymExprs in compile_fx
+class SymExprPrinter(PythonPrinter):
+    def _print_Float(self, expr: sympy.Float) -> str:
+        return str(float(expr))
+
+
+class _ShapeGuardPrinter(abc.ABC):
+    """
+    Abstract base class for printers that convert symbolic expressions to string representations.
+
+    This class provides common functionality for printing symbolic expressions with
+    special handling for symbols that represent tensor shapes, strides, etc.
+    Subclasses implement specific formatting for different output languages.
+
+    Args:
+        symbol_to_source: Mapping from sympy symbols to their source objects
+        source_ref: Function to convert a source to its string representation
+        var_to_sources: Mapping from sympy symbols to their source objects (for error reporting)
+    """
+
+    def __init__(
+        self,
+        symbol_to_source: Mapping[sympy.Symbol, list[Source]],
+        source_ref: Callable[[Source], str],
+        var_to_sources: Mapping[sympy.Symbol, list[Source]],
+    ) -> None:
+        self.symbol_to_source = symbol_to_source
+        self.source_ref = source_ref
+        self.var_to_sources = var_to_sources
+        super().__init__()
+
+    def _print_Float(self, expr: sympy.Float) -> str:
+        """Convert a sympy Float to a Python float string representation."""
+        return str(float(expr))
+
+    def _print_Symbol(self, expr: sympy.Symbol) -> str:
+        """
+        Convert a sympy Symbol to its source representation.
+
+        This method looks up the symbol in symbol_to_source mapping and returns
+        the string representation of its first source. If the symbol is not in
+        symbol_to_source (which can happen when symbols appear in guard expressions
+        through simplification or substitution), it falls back to var_to_sources.
+
+        Args:
+            expr: The sympy Symbol to convert
+
+        Returns:
+            String representation of the symbol's source
+
+        Raises:
+            AssertionError: If the symbol is not found in either mapping
+        """
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+
+        # Try symbol_to_source first, fall back to var_to_sources if not found
+        if source := self.symbol_to_source.get(expr):
+            return self.print_source(source[0])
+        elif source := self.var_to_sources.get(expr):
+            return self.print_source(source[0])
+        else:
+
+            def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str:
+                return repr(
+                    {
+                        symbol: [s.name for s in sources]
+                        for symbol, sources in src.items()
+                    }
+                )
+
+            raise RuntimeError(
+                f"{expr} not in {repr_sources(self.symbol_to_source)} or "
+                f"{repr_sources(self.var_to_sources)}.  This could be due to "
+                "the issue described in https://github.com/pytorch/pytorch/pull/90665"
+            )
+
+    @abc.abstractmethod
+    def print_source(self, source: Source) -> str:
+        """
+        Convert a source object to its string representation.
+
+        Args:
+            source: The source object to convert
+
+        Returns:
+            String representation of the source
+        """
+        ...
+
+    @abc.abstractmethod
+    def doprint(self, expr: sympy.Expr) -> str:
+        """
+        Convert a sympy expression to its string representation.
+
+        Args:
+            expr: The sympy expression to convert
+
+        Returns:
+            String representation of the expression
+        """
+        ...
+
+
+class ShapeGuardPythonPrinter(_ShapeGuardPrinter, PythonPrinter):
+    """
+    Python printer for shape guards that extends the base ShapeGuardPrinter.
+
+    This class provides functionality to print symbolic expressions as Python code,
+    with caching to improve performance when printing the same expressions multiple times.
+    It handles printing of sources and expressions according to Python syntax.
+
+    Args:
+        *args: Arguments passed to the parent classes.
+    """
+
+    def __init__(self, *args: Any) -> None:
+        super().__init__(*args)
+        self._print_cache: dict[sympy.Expr, str] = {}
+
+    def print_source(self, source: Source) -> str:
+        """
+        Convert a source object to its string representation using the source_ref function.
+
+        Args:
+            source: The source object to convert
+
+        Returns:
+            String representation of the source
+        """
+        return self.source_ref(source)
+
+    def doprint(self, expr: sympy.Expr) -> str:
+        """
+        Convert a sympy expression to its Python string representation with caching.
+
+        This method first checks if the expression is already in the cache.
+        If found, it returns the cached result; otherwise, it delegates to
+        PythonPrinter's doprint method and caches the result.
+
+        Args:
+            expr: The sympy expression to convert
+
+        Returns:
+            String representation of the expression in Python syntax
+        """
+        val = self._print_cache.get(expr, None)
+        if val is not None:
+            return val
+        else:
+            res = PythonPrinter.doprint(self, expr)
+            self._print_cache[expr] = res
+            return res
+
+
+@deprecated(
+    "`torch.fx.experimental.symbolic_shapes.ShapeGuardPrinter` is deprecated, "
+    "please use `torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter` instead.",
+    category=FutureWarning,
+)
+class ShapeGuardPrinter(ShapeGuardPythonPrinter):
+    pass
+
+
+class _ShapeGuardCppPrinter(_ShapeGuardPrinter, CppPrinter):
+    def __init__(self, *args: Any) -> None:
+        self.all_symbols: set[str] = set()
+        self.source_to_symbol: dict[Source, sympy.Symbol] = {}
+        super().__init__(*args)
+
+    def print_source(self, source: Source) -> str:
+        if source in self.source_to_symbol:
+            return self.source_to_symbol[source].name
+
+        source_name = source.name
+        mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name)
+        old_mangled_name = mangled_name
+        count = 0
+        while mangled_name in self.all_symbols:
+            mangled_name = f"{old_mangled_name}_{count}"
+            count += 1
+        self.source_to_symbol[source] = sympy.Symbol(mangled_name)
+        self.all_symbols.add(mangled_name)
+        return mangled_name
+
+    def doprint(self, expr: sympy.Expr) -> str:
+        return CppPrinter.doprint(self, expr)
+
+
+# A dataclass for storing shape guards
+@dataclass(frozen=True)
+class _ShapeGuardsHelper:
+    exprs: list[str]
+
+
+# A dataclass for storing C++ expressions and helper variables
+@dataclass(frozen=True)
+class _CppShapeGuardsHelper(_ShapeGuardsHelper):
+    source_to_symbol: dict[Source, sympy.Symbol]
+
+
+class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter):
+    def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]):
+        super().__init__(var_to_sources, lambda n: n.name, var_to_sources)
+
+
+class DynamicDimConstraintPrinter(PythonPrinter):
+    """
+    Printer for dynamic dim constraints.
+    - Instead of symbol s_k it prints its source t.size()[i]
+    - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc.
+
+    We use this to suggest code for specifying dynamic dim constraints.
+    """
+
+    def __init__(
+        self,
+        symbol_to_source: dict[sympy.Symbol, list[Source]],
+        source_name_to_debug_name: Mapping[str, str],
+    ):
+        super().__init__()
+        self.symbol_to_source = symbol_to_source
+        self.source_name_to_debug_name = source_name_to_debug_name
+
+    def _print_Symbol(self, expr: sympy.Symbol) -> str:
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+        assert self.symbol_to_source.get(expr), (
+            f"Unknown symbol {expr} created by constraints solver"
+        )
+        return self.symbol_to_source[expr][0].name
+
+
+class DimConstraints:
+    """
+    Custom solver for a system of constraints on symbolic dimensions.
+    Solutions are "static" values or simplified "dynamic" constraints.
+    """
+
+    def __init__(
+        self,
+        symbol_to_source: dict[sympy.Symbol, list[Source]],
+        var_to_val: Mapping[sympy.Symbol, sympy.Integer],
+        marked_dynamic: set[sympy.Symbol],
+        source_name_to_debug_name: Mapping[str, str],
+    ) -> None:
+        # We try to solve systems of inequalities with 1 free variable.
+        self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = (
+            defaultdict(set)
+        )
+        # Among them, we prioritize solving for a free variable that has equalities.
+        # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
+        # and removing a symbol from the former => removing it from the latter.
+        self._symbols_with_equalities: set[sympy.Symbol] = set()
+        # A solution of a free variable with equalities becomes a substitution.
+        # We use these substitutions to simplify other constraints.
+        # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
+        self._substitutions: dict[sympy.Symbol, sympy.Integer] = {}
+
+        # In general, constraints may have // and % operations.
+        # Of course, // can be expressed in terms of / and %.
+        # Our inequality solver can handle / but not %. So we need to transform them away.
+        # We do so by using the values of variables as hints to evaluate %.
+        # For soundness we record additional congruence guards and solve them separately.
+        self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val
+        self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set)
+
+        # We do not try to (directly) solve inequalities with > 1 free variables.
+        # NOTE: free variables in these inequalities cannot also be in _substitutions.
+        self._multivariate_inequalities: set[SympyBoolean] = set()
+
+        # We park external equalities between free variables here.
+        self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = []
+
+        # Solutions come in two forms:
+        # - (static) specializations
+        # - (dynamic) inequalities / congruences
+        self._static_results: set[str] = set()
+        self._dynamic_results: set[str] = set()
+
+        # printer for solutions
+        self._dcp = DynamicDimConstraintPrinter(
+            symbol_to_source, source_name_to_debug_name
+        )
+
+        # inconsistencies found on substituting with concrete values / static solutions
+        self._inconsistencies: list[str] = []
+
+        # symbols that are marked dynamic
+        self._marked_dynamic = marked_dynamic
+
+        # track supported sympy functions and subtract from list of all sympy functions
+        self._supported_sympy_functions: set[sympy.Function] = {
+            Application,
+            Mod,
+            PythonMod,
+            FloorDiv,
+        }
+        self._enumerate_sympy_functions()
+
+    def rewrite_with_congruences(self, s: sympy.Symbol, expr: _SympyT) -> _SympyT:
+        """
+        Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
+        This leaves rational operators (in particular of the form b / d) that our inequality solver can handle.
+        We solve the added congruences separately (using our congruence solver, see below).
+        """
+
+        def mod_handler(*args: sympy.Expr) -> sympy.Expr:
+            # Suppose that we have an expression of the form b % d with free variable s.
+            # Using the value of s as a "hint," we can evaluate b % d to a value k.
+            # Then we can rewrite b % d to k while adding the guard b % d == k.
+
+            # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF
+            # the original expression always evaluates to a constant value (i.e., it does not vary with s).
+            # In other words,
+            # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with
+            #   the original expression;
+            # - while it may be possible to find solutions of s with the original expression that are not
+            #   solutions with the rewritten expression, in that case the original expression cannot evaluate
+            #   to the same value for all solutions of s.
+            #
+            # Should we be worried about this incompleteness? No, because of the following reasons:
+            # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech
+            #    (i.e., "don't let perfect be the enemy of the good").
+            # 2. We already have a tradition of using hints to add guards in the compiler for making progress.
+            # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards
+            #    we generate (or simplify to) seem to be of the form b % d == k where k is a constant.
+            #
+            # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2.
+            # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
+            # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
+            base, divisor = args
+            base, divisor = (
+                self.rewrite_with_congruences(s, base),
+                self.rewrite_with_congruences(s, divisor),
+            )
+            mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
+                self._var_to_val
+            )
+            congruence = (base - mod_reduced) % divisor
+            if congruence != 0:
+                self._congruences[s].add(congruence)
+            return mod_reduced
+
+        def floor_div_handler(*args: sympy.Expr) -> sympy.Expr:
+            # Suppose that we have an expression of the form b // d with free variable s.
+            # Using the value of s, we can evaluate b % d to a value k.
+            # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k.
+
+            # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
+            # and eliminating b % d as above.
+            base, divisor = args
+            base, divisor = (
+                self.rewrite_with_congruences(s, base),
+                self.rewrite_with_congruences(s, divisor),
+            )
+            mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
+                self._var_to_val
+            )
+            congruence = (base - mod_reduced) % divisor
+            if congruence != 0:
+                self._congruences[s].add(congruence)
+            # NB: Must not be CleanDiv, it needs to be regular sympy division
+            # so inequality solver works.  This is sort of problematic for
+            # is_integer tests though haha
+            return (base - mod_reduced) / divisor
+
+        # pyrefly: ignore [missing-attribute]
+        if expr.has(Mod):
+            # pyrefly: ignore [missing-attribute]
+            expr = expr.replace(Mod, mod_handler)
+        # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
+        # arguments should be OK.
+        # pyrefly: ignore [missing-attribute]
+        if expr.has(PythonMod):
+            # pyrefly: ignore [missing-attribute]
+            expr = expr.replace(PythonMod, mod_handler)
+        # pyrefly: ignore [missing-attribute]
+        if expr.has(FloorDiv):
+            # pyrefly: ignore [missing-attribute]
+            expr = expr.replace(FloorDiv, floor_div_handler)
+        return expr
+
+    def _enumerate_sympy_functions(self) -> None:
+        module = torch.utils._sympy.functions
+        all_functions = set()
+        for attr in dir(module):
+            if isinstance(func := getattr(module, attr), sympy.FunctionClass):
+                all_functions.add(func)
+        self._unsupported_sympy_functions = all_functions.difference(
+            self._supported_sympy_functions
+        )
+
+    def _has_unsupported_sympy_function(self, expr: sympy.Basic) -> bool:
+        """
+        Tracks list of sympy.Functions the export solver doesn't know how to handle.
+        """
+        return expr.has(*self._unsupported_sympy_functions)
+
+    def add(self, expr: SympyBoolean) -> bool:
+        """Add an expression to the set of constraints.
+
+        Return whether the expression is a trivial constraint (i.e., an obvious tautology).
+        """
+        if expr == sympy.true:
+            return True
+        orig_expr = expr
+        orig_reduced = orig_expr.xreplace(self._var_to_val)
+        # TODO(avik): https://github.com/pytorch/pytorch/issues/101093
+        # It is possible that `expr` will fail the consistency check because of
+        # precision errors. Specifically, on substituting its free symbols with
+        # their concrete values, we might end up comparing floats. Until we have
+        # a fix for this issue, we delay raising such failures. See solve().
+        if orig_reduced == sympy.false:
+            self._inconsistencies.append(f"{orig_expr} is inconsistent!")
+        if isinstance(
+            expr, (sympy.Ne, sympy.Or, sympy.And)
+        ) or self._has_unsupported_sympy_function(expr):
+            # we're not going to do anything useful with these, so drop them
+            return False
+        free_symbols = expr.free_symbols
+        assert free_symbols, f"Did not expect constraint with no free variables: {expr}"
+        if len(free_symbols) > 1:
+            # multivariate: record and move on
+            self._multivariate_inequalities.add(expr)
+        else:
+            # univariate: can solve these immediately
+            s = next(iter(free_symbols))
+            # eliminate // and % (see documentation of `rewrite_with_congruences` above)
+            old_n_congruences = len(self._congruences[s])
+            expr = self.rewrite_with_congruences(s, expr)
+            new_n_congruences = len(self._congruences[s])
+            if expr == sympy.true:
+                return old_n_congruences == new_n_congruences
+            reduced = expr.xreplace(self._var_to_val)
+            if reduced == sympy.false:
+                self._inconsistencies.append(
+                    f"{expr}, obtained by rewriting {orig_expr} with congruences, "
+                    "is inconsistent!"
+                )
+            if isinstance(expr, sympy.Eq):
+                # special status for symbols that have equalities (see `solve` below)
+                self._symbols_with_equalities.add(s)
+            self._univariate_inequalities[s].add(expr)
+        return False
+
+    def add_equality(self, source: Source, expr: sympy.Expr) -> None:
+        """Add an equality constraint"""
+        if expr.is_number:
+            # specialization, right here
+            self._static_results.add(f"{source.name} == {expr}")
+        else:
+            # these will resolve to either specializations or dynamic equality constraints
+            self._symbolic_equivalences.append((source, expr))
+
+    def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]:
+        reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {}
+        for s, congruences in self._congruences.items():
+            remainder_modulus_pairs = []
+            congruences_to_check = set()
+            for congruence in congruences:
+                base, divisor = congruence.args
+                # We are given a congruence of the form base % divisor == 0 with a free variable s. So:
+                # - we transform this into an equation of the form base = divisor * tmp;
+                # - we solve this equation for s to get a linear solution with free variable tmp.
+                tmp = sympy.Symbol("reduce_congruences_tmp", integer=True)
+                symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
+                # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
+                # for how to interpret the results.
+                if s == symbol:
+                    # This means the solution is of the form s = modulus*tmp + remainder.
+                    modulus, remainder = sympy.polys.polytools.div(solution, tmp)
+                    if isinstance(modulus, sympy.Integer) and isinstance(
+                        remainder, sympy.Integer
+                    ):
+                        # Make sure 0 <= remainder <= modulus.
+                        remainder = remainder % modulus
+                        remainder_modulus_pairs.append((remainder, modulus))
+                        continue
+                # This means that we did not get a unique solution to the equation.
+                # No problem, we will check it.
+                congruences_to_check.add(congruence)
+            # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i).
+            # The solution will be a congruence of the form s = r mod m.
+            # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT.
+            if remainder_modulus_pairs:
+                remainder, modulus = sympy.ntheory.modular.solve_congruence(
+                    *remainder_modulus_pairs
+                )
+                reduced_congruences[s] = {(s - remainder) % modulus}
+                substitution = {
+                    s: modulus * sympy.Symbol("tmp", integer=True) + remainder
+                }
+                reduced_congruences[s].update(
+                    congruence
+                    for congruence in congruences_to_check
+                    if not sympy.checksol(congruence, substitution)
+                )
+            else:
+                reduced_congruences[s] = congruences_to_check
+
+        return reduced_congruences
+
+    def _raise_inconsistencies(self) -> None:
+        if self._inconsistencies:
+            msg = "\n".join(self._inconsistencies)
+            self._inconsistencies.clear()
+            raise ValueError(f"The following inconsistencies were found:\n{msg}")
+
+    def solve(self) -> None:
+        """Solve the system of constraint equations to find simplified constraints"""
+        self._raise_inconsistencies()
+        # as long as there are symbols with equalities, solve for them
+        # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
+        while self._symbols_with_equalities:
+            s = self._symbols_with_equalities.pop()
+            exprs = self._univariate_inequalities.pop(s)
+            solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
+            if isinstance(solution, sympy.And):
+                solution = next(
+                    (arg for arg in solution.args if isinstance(arg, sympy.Eq)),
+                    solution,
+                )
+            assert isinstance(solution, sympy.Eq), (
+                f"Expected an equality constraint for {s}, got {solution}"
+            )
+            symbol, val = solution.args
+            assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
+            # because this is univariate, the solution is a specialization
+            self._static_results.add(
+                f"{self._dcp.symbol_to_source[s][0].name} == {val}"
+            )
+            # add this as a substitution to simplify other constraints
+            self._substitutions[s] = val  # type: ignore[assignment]
+
+            # simplify multivariate inequalities: some of them will now become univariate!
+            multivariate_inequalities = self._multivariate_inequalities
+            self._multivariate_inequalities = set()
+            for expr in multivariate_inequalities:
+                self.add(expr.xreplace({s: self._substitutions[s]}))
+            self._raise_inconsistencies()
+
+        # solve linear congruences
+        # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
+        reduced_congruences = self._reduce_congruences()
+        for s, congruences in reduced_congruences.items():
+            for congruence in congruences:
+                # any congruence that cannot be checked becomes a dynamic constraint as well
+                if s not in self._substitutions or not sympy.checksol(
+                    congruence, {s: self._substitutions[s]}
+                ):
+                    if self._is_supported_congruence(congruence):
+                        base, divisor = congruence.args
+                        tmp_name = "_" + str(
+                            self._dcp.source_name_to_debug_name.get(
+                                self._dcp.symbol_to_source[s][0].name,
+                                self._dcp.symbol_to_source[s][0].name,
+                            )
+                        )
+                        tmp = sympy.Symbol(tmp_name, integer=True)
+                        from torch._dynamo.source import ConstantSource
+
+                        self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
+                        r = try_solve(sympy.Eq(base, divisor * tmp), s)
+                        assert r is not None
+                        self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
+
+        # remaining symbols have only pure inequalities (no equalities)
+        for s, exprs in self._univariate_inequalities.items():
+            try:
+                solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
+                # because this is univariate, the solution is a dynamic (range) constraint
+                if isinstance(solution, sympy.Or):
+                    solution = next(
+                        iter(
+                            arg
+                            for arg in solution.args
+                            if arg.xreplace(self._var_to_val)
+                        )
+                    )
+                if isinstance(solution, sympy.And):
+                    for arg in solution.args:
+                        self._dynamic_results.add(self._dcp.doprint(arg))
+                else:
+                    self._dynamic_results.add(self._dcp.doprint(solution))
+            except (NotImplementedError, AssertionError):
+                log.warning("Failed to reduce inequalities", exc_info=True)
+                for expr2 in exprs:
+                    self._dynamic_results.add(self._dcp.doprint(expr2))
+
+        # simplify symbolic equivalences: some of them will now become specializations!
+        symbolic_equivalences = self._symbolic_equivalences
+        self._symbolic_equivalences = []
+        for source, expr3 in symbolic_equivalences:
+            self.add_equality(source, expr3.xreplace(self._substitutions))
+
+        # remaining symbolic equivalences become dynamic equality constraints
+        for source, expr3 in self._symbolic_equivalences:
+            self._dynamic_results.add(f"{source.name} == {self._dcp.doprint(expr3)}")
+
+    @classmethod
+    def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool:
+        base, divisor = congruence.args
+        # Congruences that can be currently expressed with supported Dim ops are
+        # of the form (x + a) % b == 0, where x is a Dim and a and b are constants.
+        # This allows us to derive x as b*y - a for some Dim y.
+        # (See also documentation of dynamic_shapes._DerivedDim.)
+        if isinstance(base, sympy.Add):
+            lhs, rhs = base.args
+            cond = (
+                isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)
+            ) or (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol))
+        else:
+            cond = isinstance(base, sympy.Symbol)
+        cond = cond and isinstance(divisor, sympy.Integer)
+        return cond
+
+    def forced_specializations(self) -> dict[str, sympy.Expr]:
+        """Returns a dictionary of the names of symbols to their specialized value"""
+
+        def debug_name(src: Source) -> str:
+            name = src.name
+            if self._dcp.source_name_to_debug_name:
+                return f"{self._dcp.source_name_to_debug_name[name]} = {name}"
+            else:
+                return name
+
+        return {
+            debug_name(self._dcp.symbol_to_source[s][0]): val
+            for s, val in self._substitutions.items()
+            if s in self._marked_dynamic
+        }
+
+    def _is_derived_dim(
+        self, dim: object
+    ) -> TypeGuard[torch.export.dynamic_shapes._DerivedDim]:
+        return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
+
+    def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]:
+        return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance(
+            dim, torch.export.dynamic_shapes._DerivedDim
+        )
+
+    def _process_derived_dim_roots(
+        self,
+        results: dict[str, dict[str, Any]],
+        name_to_dim: dict[str, Any],
+    ) -> None:
+        """
+        Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
+        and 2) root swapping.
+
+        1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests
+        dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final
+        suggested fixes handle this correctly, but we can get intermediate results that look like
+        {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}}
+        and this routine prettifies this by unifying to a single root, and making each suggestion
+        either a derived dim or min/max range, not both.
+
+        2) With suggested fixes for derived dims, roots can be swapped,
+        e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name,
+        since this leads to messages like "dx - 1 = Dim("dx - 1", ...)".
+        Instead we evaluate the new root value, and remove results for its derivations.
+
+        First we find all the original roots (specified in dynamic_shapes), that are found in the
+        values of results (i.e. used for computing suggesting fix values). These original roots
+        (suppose `dx`) are either specialized, unchanged, refined, or swapped
+        (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value
+        in results, and remove suggestions for derivations of `dx`, assuming the derived relation
+        is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value,
+        and then do the same with `dx`'s derivations.
+
+        Assuming the originally specified derived relations are correct is valid, because:
+            1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1))
+               produce_guards() will catch this and crash before hand.
+            2) if the relations are numerically correct but do not match the emitted guard,
+               for example:
+
+                    def forward(self, x, y):
+                        return x.reshape([-1]) + y  # guard: s0 * 2 = s1
+                    inputs = (torch.randn(6, 2), torch.randn(12))
+                    dx = Dim("dx", min=2, max=32)
+                    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )}  # this matches values but not op
+
+               then this leads to 2 linear equations, and a) produce_guards() is able to solve for
+               the unique solution of dx = 6 and specialize, and b) the export constraint solver will
+               raise an issue due to range constraints (a unique solution means not all values in a
+               range satisfy a guard) and also force specializations.
+        """
+        from torch.export.dynamic_shapes import Dim
+
+        def _check_same_range(c: Mapping[str, int], dim: object) -> bool:
+            # returns True if c & dim are both min/max ranges with same values
+            return (
+                self._is_dim(dim)
+                and ("min" in c or "max" in c)
+                and (
+                    (dim.min < 2 and c.get("min", 2) == 2) or dim.min == c.get("min", 2)  # type: ignore[attr-defined]
+                )  # let pass if analysis min = 2 and specified min = 0/1
+                and dim.max == c.get("max", int_oo)  # type: ignore[attr-defined]
+            )
+
+        # 1) newly introduced roots
+        # this part we handle adding newly introduced roots
+        # these arise from guards like "x.shape[0] % 3 == 0"
+        # leading to suggested fixes like "dx = 3*_dx"
+        # extract _dx, and find appropriate min/max values
+        #
+        # before, we have something like:
+        # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
+        # we want instead:
+        # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
+        introduced_roots: dict[str, str] = {}  # map new root -> old root
+        for k, c in list(results.items()):
+            if "eq" in c and isinstance(c["eq"], sympy.Expr):  # derived dim
+                root = next(iter(c["eq"].free_symbols))
+                if str(root) not in name_to_dim:
+                    introduced_roots[str(root)] = k
+                    # calculate necessary min & max
+                    modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
+                    c_min = c.get("min", 2)
+                    min_ = math.ceil((c_min - remainder) / modulus)
+                    c_max = c.get("max", int_oo)
+                    max_ = math.floor((c_max - remainder) / modulus)
+                    # create result & dim
+                    results[str(root)] = {"min": min_, "max": max_}
+                    name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_)
+                    # remove old root min/max bounds
+                    c.pop("min", None)
+                    c.pop("max", None)
+
+        # alter derivations that depend on old root, to unify to new root
+        # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
+        for old_root in introduced_roots.values():
+            for c in results.values():
+                if (
+                    "eq" in c
+                    and isinstance(c["eq"], sympy.Expr)
+                    and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
+                ):  # derived dim with root = old_root
+                    new_root_expr = results[str(old_root)]["eq"]  # dx=3*_dx+1
+
+                    new_expr = c["eq"].subs({symbol: new_root_expr})  # dy=(3*_dx+1)+1
+                    c["eq"] = new_expr
+
+        # 2) root swapping
+        # collect all the original roots that are used for calculating values of suggested fixes
+        # this consists of:
+        # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
+        # 2) {"dy": "dx + 1"} -> dx: root for suggested fix
+        modified_roots: set[str] = set()
+        for k, c in results.items():
+            if k not in name_to_dim:  # _dynamo.export() may handle source directly
+                continue
+            if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c):  # case 1)
+                modified_roots.add(k)
+            elif "eq" in c and isinstance(c["eq"], sympy.Expr):  # case 2)
+                root = next(iter(c["eq"].free_symbols))
+                assert root is not None
+                modified_roots.add(str(root))
+
+        # exclude newly introduced roots, we've already processed these
+        modified_roots = modified_roots.difference(introduced_roots)
+
+        # evaluate the new value for each root
+        # this is now either 1) unchanged, 2) refined with a new range,
+        # or 3) specialized to a concrete value
+        modified_root_values: dict[str, dict[str, Any]] = {}
+        for mroot in modified_roots:
+            swapped_root = True
+            if mroot in results:
+                c = results[mroot]
+                if ("min" in c or "max" in c) or isinstance(  # range
+                    c["eq"], int
+                ):  # specialized
+                    # here, the original root is a root Dim or concrete value in results.
+                    # if it is a derived dim, it is swapped, and we handle that below.
+                    if not _check_same_range(
+                        c, name_to_dim[mroot]
+                    ):  # ignore if unchanged
+                        modified_root_values[mroot] = c
+                    swapped_root = False
+
+            if swapped_root:
+                # if the original root has been swapped in results, that means the new root
+                # is a range (if it had specialized, the original root would have too).
+                # find this new root, and solve for the original root's range.
+                for k, c in results.items():
+                    if k not in name_to_dim:
+                        continue
+                    dim = name_to_dim[k]
+                    if (
+                        dim.__class__.__name__ == "_DerivedDim"
+                        and dim.root.__name__ == mroot
+                    ):
+                        # only look for min/max root, otherwise root would have specialized
+                        if "min" in c or "max" in c:
+                            expr = sympy.sympify(k)
+                            s = next(iter(expr.free_symbols))
+                            result = {
+                                "min": try_solve(sympy.Eq(expr, c["min"]), s)[1],  # type: ignore[arg-type, index]
+                                "max": try_solve(sympy.Eq(expr, c["max"]), s)[1],  # type: ignore[arg-type, index]
+                            }
+                            if not _check_same_range(
+                                result,
+                                name_to_dim[mroot],  # type: ignore[index, arg-type]
+                            ):  # ignore if unchanged
+                                modified_root_values[mroot] = result  # type: ignore[index]
+                                break
+
+        # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4})
+        # we only want to suggest fixes for the root, to avoid derived names.
+        # also, remove anything in modified_roots, since we either add new modified values after this,
+        # or have decided they are unchanged.
+        for k in list(results.keys()):
+            if k not in name_to_dim:
+                continue
+            if self._is_derived_dim(name_to_dim[k]) or k in modified_roots:
+                del results[k]
+
+        # update results with modified root values
+        # now results has the following properties:
+        # - only contains original roots as keys
+        # - each root is now either specialized, refined, or derived from another original root
+        results.update(modified_root_values)
+
+    def prettify_results(
+        self,
+        original_signature: inspect.Signature,
+        dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
+        constraint_violation_error: object,
+        forced_specializations: dict[str, str],
+    ) -> str:
+        """Format a message for constraint violation errors"""
+        from torch.export.dynamic_shapes import _get_dim_name_mapping
+
+        if not self._dcp.source_name_to_debug_name:
+            # nothing to do
+            return ""
+
+        def transform(s: str, inverse: bool = False) -> str:
+            for k, v in self._dcp.source_name_to_debug_name.items():
+                s = s.replace(k, v) if not inverse else s.replace(v, k)
+            return s
+
+        results: defaultdict[str, dict[str, Any]] = defaultdict(dict)
+        if dynamic_shapes is None:
+            dynamic_shapes = {}
+
+        def flip(op: str) -> str:
+            if op == "<=":
+                return ">="
+            if op == ">=":
+                return "<="
+            if op == "<":
+                return ">"
+            if op == ">":
+                return "<"
+            assert op == "=="
+            return op
+
+        def relation_with_digit(expr: str, op: str, digit: int) -> None:
+            if op == "<=":
+                results[expr]["max"] = digit
+            elif op == "<":
+                results[expr]["max"] = digit - 1
+            elif op == ">=":
+                results[expr]["min"] = digit
+            elif op == ">":
+                results[expr]["min"] = digit + 1
+            else:
+                assert op == "=="
+                results[expr]["eq"] = digit
+
+        # retrieve dynamic shapes
+        name_to_dim = _get_dim_name_mapping(dynamic_shapes)
+
+        for s in self._static_results.union(self._dynamic_results):
+            t = transform(s)
+            if t == s:
+                continue
+            left, op, right = re.split(r"( == | <= | >= | < | > )", t)
+            op = op.strip()
+            if op == "==" and left == right:
+                continue
+            if right.isdigit():
+                relation_with_digit(left, op, int(right))
+            elif left.isdigit():
+                relation_with_digit(right, flip(op), int(left))
+            else:
+                assert op == "==", t
+                try:
+                    results[left]["eq"] = sympy.sympify(right)
+                except TypeError:  # rhs source is not linked to Dim name
+                    pass
+
+        # order forced specializations based on name
+        forced_specializations = {
+            k: forced_specializations[k]
+            for k in sorted(
+                forced_specializations.keys(),
+                key=lambda x: x.split(" = ")[1],
+            )
+        }
+
+        buf = ""
+        if forced_specializations:
+            debug_names = set()
+            for k in forced_specializations:
+                dim = name_to_dim[k.split(" = ")[0]]
+                if self._is_derived_dim(dim):
+                    debug_names.add(dim.root.__name__)  # type: ignore[attr-defined]
+                else:
+                    debug_names.add(dim.__name__)
+
+            buf += (
+                f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
+                'For more information, run with TORCH_LOGS="+dynamic".\n'
+            )
+            for s, val in forced_specializations.items():
+                buf += f"  - solving the guards generated for {s} resulted in a specialized value of {val}.\n"
+
+        self._process_derived_dim_roots(results, name_to_dim)
+
+        dims = []
+        others = []
+
+        # order results by source name
+        results2 = {
+            k: results[k]
+            for k in sorted(
+                results.keys(),
+                key=lambda x: transform(x, inverse=True),
+            )
+        }
+        for k, c in results2.items():
+            if "eq" in c:
+                other = c["eq"]
+                if isinstance(other, int):
+                    others.append(f"{k} = {other}")
+                elif _is_supported_equivalence(other):
+                    others.append(f"{k} = {other}")
+            else:
+                min_ = c.get("min", None)
+                if min_ == 2:
+                    min_ = None
+                max_ = c.get("max", None)
+                if min_ is not None and max_ is not None:
+                    dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
+                elif min_ is not None:
+                    dims.append(f"{k} = Dim('{k}', min={min_})")
+                elif max_ is not None:
+                    dims.append(f"{k} = Dim('{k}', max={max_})")
+                else:
+                    dims.append(f"{k} = Dim('{k}')")
+
+        # results2 will get filtered out if no new suggestions,
+        # this can happen if guards are too complex.
+        # in that case don't suggest fix
+        if dims or others:
+            buf += "\nSuggested fixes:\n  "
+            buf += "\n  ".join(dims + others)
+
+        return buf
+
+
+TLS = threading.local()
+
+
+@dataclass(frozen=True)
+class ShapeEnvSettings:
+    """
+    Encapsulates all shape env settings that could potentially affect
+    FakeTensor dispatch. Used when creating dispatch cache keys.
+    """
+
+    allow_scalar_outputs: bool
+    allow_dynamic_output_shape_ops: bool
+    assume_static_by_default: bool
+    specialize_zero_one: bool
+    duck_shape: bool
+    prefer_deferred_runtime_asserts_over_guards: bool
+    trace_asserts: bool
+
+
+@dataclass
+class ValueRangesSLoc:
+    """
+    Locations of the guards that triggered lower and upper bound.
+    """
+
+    lower: SLoc
+    upper: SLoc
+
+
+@contextmanager
+def _suppress_guards(shape_env: ShapeEnv) -> Iterator[None]:
+    shape_env._suppress_guards_enter()
+    try:
+        yield
+    finally:
+        shape_env._suppress_guards_exit()
+
+
+@dataclass
+class _FrameLocalResult:
+    loc: Optional[str] = None
+    locals: dict[str, Any] = field(default_factory=dict)
+    symbols: dict[str, str] = field(default_factory=dict)
+
+
+class ShapeEnv:
+    # This is a wrapper over the actual __init__ function.
+    #
+    # Where to add a new constructor parameter to ShapeEnv?
+    # =====================================================
+    # This __init__ function should be used only for parameters related to event recording.
+    # These are parameters that we don't wish to pass down the road to new ShapeEnv instances
+    # created from replaying events.
+    #
+    # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
+    # recording, do so in the _init function.
+    def __init__(
+        self,
+        *,
+        should_record_events: Optional[bool] = None,
+        tracked_fakes: Optional[list[Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        self._init(**kwargs)
+
+        # Disable event recording when replaying.
+        kwargs["should_record_events"] = False
+
+        from torch.fx.experimental.validator import translation_validation_enabled
+
+        self._translation_validation_enabled = translation_validation_enabled()
+
+        # If not specified, enable event recording if both:
+        #   - Translation validation is on
+        #   - Translation validation bisection is not disabled
+        self.should_record_events = (
+            should_record_events
+            if should_record_events is not None
+            else (
+                self._translation_validation_enabled
+                and not config.translation_validation_no_bisect
+            )
+        )
+
+        # Enable event recording check if both:
+        #   - It should record events
+        #   - The recording check is enabled
+        self.check_recorded_events = (
+            self.should_record_events and config.check_shape_env_recorded_events
+        )
+
+        # This will make sure we only record the top-level function call.
+        self.is_recording = False
+        # Keep track of the list of tracked fakes.
+        self.tracked_fakes = tracked_fakes
+        # List of events for reconstructing ShapeEnv at arbitrary points in time.
+        self.events: list[ShapeEnvEvent] = (
+            [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)]
+            if self.should_record_events
+            else []
+        )
+
+        # FakeTensor per-ShapeEnv operation cache. This is used for caching
+        # operations that contain symbolic shapes which have guards on the
+        # ShapeEnv (so are ShapeEnv-dependent).
+        #
+        # NOTE: It's important that SymNodes in this cache have their ShapeEnv
+        # stripped otherwise you end up with cycles which can only be cleaned
+        # with the GC.
+        self.fake_tensor_cache: dict[
+            torch._subclasses.fake_tensor._DispatchCacheKey,
+            torch._subclasses.fake_tensor._DispatchCacheEntry,
+        ] = {}
+
+    # Pro-tip: if you add new field to ShapeEnv, this affects some accept
+    # tests.  Accept their output with:
+    #
+    #   EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal
+    #
+    def _init(
+        self,
+        *,
+        allow_scalar_outputs: bool = True,
+        allow_dynamic_output_shape_ops: bool = True,
+        # NB: These are legacy configuration that help us make good choices
+        # when the constraint/dynamic dims are not explicitly passed to us.
+        # Ideally we will fix all call sites to be explicit and not have
+        # implicit choices, but this apparently was pretty involved.
+        assume_static_by_default: bool = False,
+        # Note - On 0/1 specialization
+        #
+        # The following options affect decisions we make about eager
+        # specialization.  Disabling them will increase trace time (as we do
+        # more symbolic reasoning) and can also harm the quality of generated
+        # code (because inductor may not be able to specialize for bounds
+        # being equal--although if we later respecialize because of a guard,
+        # your code may be just as good as it was before.)
+        #
+        # When True, eagerly specialize input sizes which have 0/1.
+        specialize_zero_one: bool = True,
+        # When True, assume input sizes which have the same size are
+        # symbolically equal.
+        duck_shape: Optional[bool] = None,
+        # For debugging
+        co_fields: Optional[dict[str, str]] = None,
+        # When True, whenever safe, we will generate a deferred runtime assert
+        # instead of a guard whenever we know that an expression must be True,
+        # otherwise it would be an error, even for backed SymInts (where we
+        # could ostensibly unconditionally generate guards).  This is useful
+        # for export, where preventing "error checking" sizes from showing up
+        # in guards is helpful, since these guards in some sense are overly
+        # pedantic.  See also https://github.com/pytorch/pytorch/issues/121749
+        prefer_deferred_runtime_asserts_over_guards: bool = False,
+        # XXX Add any new settings that could affect FakeTensor evaluation
+        # to: torch._subclasses.fake_tensor._ShapeEnvSettings
+        trace_asserts: bool = False,
+    ) -> None:
+        if duck_shape is None:
+            duck_shape = config.use_duck_shape
+
+        self.settings = ShapeEnvSettings(
+            # Not directly used by ShapeEnv; indirectly used by FakeTensor
+            allow_scalar_outputs=allow_scalar_outputs,
+            allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops,
+            # End
+            assume_static_by_default=assume_static_by_default,
+            specialize_zero_one=specialize_zero_one,
+            duck_shape=duck_shape,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+            trace_asserts=trace_asserts,
+        )
+
+        self.guards: list[ShapeGuard] = []
+        self.axioms: dict[sympy.Expr, sympy.Expr] = {}
+
+        # A set of ids that have already been allocated. This is used
+        # for when we allocate symbol ids using the hash of the source
+        # names to ensure we don't have collisions via linear probing
+        self.unique_ids: set[int] = set()
+        # Maps symbolic ints to their original concrete values
+        # Currently populated from tensors
+        self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
+        # Like var_to_val, but only set when propagate_real_tensors is on.
+        # Used as last resort to avoid GuardOnDataDependent error
+        self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
+        # Like above, but used exclusively for OBLIVIOUS_SIZE.  These
+        # potentially could be put together but I am not sure, writing out
+        # the logic individually before abstracting.
+        self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
+        # Maps symbolic ints to their min/max range.  These ranges
+        # are conservative: the int MUST fall in the range, but the
+        # range may contain ints which may not actually appear in
+        # practice
+        self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
+        self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
+        self.source_name_to_debug_name: dict[str, str] = {}
+        self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
+        # A set of unbacked symbols that are inputs (i.e: not data dependent).
+        self.unbacked_inputs: OrderedSet[sympy.Symbol] = OrderedSet()
+        self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
+        self.var_to_hint_override: dict[sympy.Symbol, int] = {}
+        # Maps a source to the *original* symbol that was assigned to it
+        self.source_to_var: dict[str, sympy.Symbol] = {}
+        # Maps from sympy ints to expressions representing them
+        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
+        self.replacements: dict[sympy.Symbol, sympy.Expr] = {}
+        # The sloc of the guard that triggered this replacement to be added
+        self.replacements_slocs: dict[sympy.Symbol, SLoc] = {}
+        self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {}
+        # Set holds a % b expressions that evaluate to 0.
+        self.divisible: set[sympy.Expr] = set()
+        # Set that holds "size-like" symbols.  When we perform
+        # "size-oblivious" tests, these can be assumed to be >= 2.
+        self.size_like: set[sympy.Symbol] = set()
+        # Duck-shaping says that if two input tensors have the same size,
+        # they get assigned the same symbolic variable
+        self.val_to_var: dict[int, sympy.Symbol] = {}
+        self.unbacked_symfloat_counter = 0
+        self.unbacked_symint_counter = 0
+        # Similar to guards, but these MUST evaluate to true and can
+        # only be evaluated at runtime midway through (i.e., they always
+        # involve unbacked symints)
+        #
+        # For efficiency reasons, we index in the following way.  Suppose you have
+        # a runtime assert i0 + i1 <= s1.  We pick the most recently allocated
+        # symbol in the source expression and add the assert to the list for
+        # that symbol e.g., {i1: [i0 + i1 <= s1]}.
+        #
+        # We access the runtime asserts in two situations:
+        #
+        #   - When we are guarding on an expression, we will attempt to
+        #     statically evaluate it, in case the unbacked SymInts can
+        #     simplify away.  If we have a runtime assert, we may be able
+        #     to discharge the guard entirely.  We only need to attempt
+        #     runtime asserts that mention freevars of the expression in
+        #     question.
+        #
+        #   - When we are performing codegen (in Inductor for eager, or
+        #     when finalizing the export FX graph), we need to know what
+        #     extra runtime asserts to insert.  Whenever an unbacked
+        #     SymInt comes into scope, all runtime asserts involving it
+        #     become eligible for insertion (so long as all of their other
+        #     free unbacked symbols are also in scope).  We technically
+        #     can handle any choice of key by kicking inexpressible asserts
+        #     to the next unbacked symbol to wait on, but if we choose the
+        #     latest key, an assert will only show up at the moment when
+        #     we can actually codegen it.
+        self.deferred_runtime_asserts: dict[
+            Optional[sympy.Symbol], list[RuntimeAssert]
+        ] = {}
+        # This exists so we can efficiently invalidate the cache (it's used as
+        # part of the cache key); otherwise we'd have to iterate through
+        # deferred_runtime_asserts to compute its length
+        self.num_deferred_runtime_asserts = 0
+        self.log = log
+        self.log.info("create_env")
+        self.frozen = False
+        self.runtime_asserts_frozen = False
+        self.dim_constraints: Optional[DimConstraints] = None
+        self.counter: Counter[str] = collections.Counter()
+        # Mapping from sympy.Symbol to the number of guards which mention this
+        # symbol
+        self.symbol_guard_counter: Counter[sympy.Symbol] = collections.Counter()
+        # A selection of important fields on co_field; solely used for
+        # signpost_event
+        self.co_fields = co_fields if co_fields else {}
+
+        # Whenever we allocate a fresh unbacked Symbol, we add it to this
+        # pending list.  Unbacked symbol allocation can occur at unpredictable
+        # points during meta tensor propagation, but at some point, we
+        # have to know what the binding site for an unbacked symbol is, and
+        # this is computed when we actually place the node in the graph. The
+        # important thing is that we always actually handle every unaccounted
+        # for unbacked symbol, so this list helps us keep track of them and
+        # then make sure they are all accounted for.
+        #
+        # We could potentially give rise to errors earlier by lexically
+        # scoping when we do propagation, and only allowing unbacked symbols
+        # to be allocated at this point in time.  However this is inconvenient
+        # to do in Dynamo, because fake tensor propagation is far from when we
+        # analyze binding sites (set_example_value), so we do it in a more
+        # mutatey way.
+        #
+        # NB: fresh unbacked symbols NEVER get substitutions applied to them,
+        # they are binding sites!
+        self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = []
+
+        # Version counter used to invalidate cached values
+        self._prev_cache_key = self._get_key()
+        self._version_counter = 0
+
+        # Each time divisible is changed this should be set to True, this is set in _update_version_counter.
+        self._resimplify_floor_div_axioms = True
+
+        # Cache for FX nodes.
+        # Maps an already built node a tuple of:
+        #   1. node's target
+        #   2. list of arguments
+        # This drastically reduces the size of the FX graph, avoiding
+        # duplicated nodes.
+        self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {}
+        self.source_to_symbol: dict[str, sympy.Symbol] = {}
+
+        # Suppose you want to replace an unbacked symbol with another
+        # unbacked symbol.  This is error prone because you can cause
+        # references to unbacked symbols to time travel backwards.  E.g.,
+        #
+        # u1 = x.item()
+        # ... use of u1 ...
+        # u2 = y.item()
+        # u3 = z.item()
+        # torch._check(u1 == u2 + u3)
+        #
+        # If you replace u1 with u2 + u3, then the use of u1 now
+        # references u2 and u3 prior to them actually being bound at
+        # runtime.
+        #
+        # To control for this, we track the order unbacked symbols
+        # were allocated, and only allow substitutions if they respect
+        # the dependency from this order; an unbacked symbol can only
+        # be substituted with unbacked symbols that come before it in the
+        # order.
+        #
+        # This also imposes an ordering on the unbacked symbol binding
+        # sites themselves: you are not allowed to reorder unbacked symbol
+        # bindings.  At the moment, this is not tracked, but we potentially
+        # could track this at the IR level using a higher order operator
+        # with something like effect token tracking.
+        self.unbacked_alloc_order: dict[sympy.Symbol, int] = {}
+
+        self.specialization_stacks: dict[Source, traceback.StackSummary] = {}
+
+        self.trace_asserts = trace_asserts
+
+        self.specializations: OrderedSet[Specialization] = OrderedSet()
+
+        from torch.fx.experimental.validator import translation_validation_enabled
+
+        self._translation_validation_enabled = translation_validation_enabled()
+
+        if self._translation_validation_enabled:
+            from torch.fx.experimental.validator import TranslationValidator
+
+            self.validator = TranslationValidator()
+            self.graph = torch.fx.Graph()
+            # Create an output graph and start inserting before that.
+            # This is needed when 'deepcopy'-ing this object.
+            self.graph.inserting_before(self.graph.output(None))
+
+            # Mapping of each node name to the node itself.
+            #
+            # This is useful for matching an FX node from a recorded ShapeEnv.graph
+            # to the FX node of the ShapeEnv we are running the event on.
+            #
+            # Whenever you add a node to self.graph, you must add a mapping to this
+            # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
+            # not be valid.
+            self.name_to_node: dict[str, torch.fx.Node] = {}
+
+    @property
+    def allow_scalar_outputs(self) -> bool:
+        return self.settings.allow_scalar_outputs
+
+    @property
+    def allow_dynamic_output_shape_ops(self) -> bool:
+        return self.settings.allow_dynamic_output_shape_ops
+
+    @property
+    def assume_static_by_default(self) -> bool:
+        return self.settings.assume_static_by_default
+
+    @property
+    def specialize_zero_one(self) -> bool:
+        return self.settings.specialize_zero_one
+
+    @property
+    def duck_shape(self) -> bool:
+        return self.settings.duck_shape
+
+    @property
+    def prefer_deferred_runtime_asserts_over_guards(self) -> bool:
+        return self.settings.prefer_deferred_runtime_asserts_over_guards
+
+    @contextmanager
+    def patch_source_specialization(
+        self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr]
+    ) -> Iterator[None]:
+        """
+        Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork"
+        and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph
+        compile so we can support various graphs with varying levels of specializations.
+
+        This context manager allows for temporarily adding constraints to the shape environment
+        based on a specialization function applied to a symbol associated with a source.
+
+        Args:
+            source: The source of the symbol to specialize
+            check_fn: A function that takes a sympy Symbol and returns a sympy expression
+                     representing a constraint/specialization to be applied
+        """
+        name = source.name
+        sym = self.source_to_var[name]
+        expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr
+        new_axioms = dict(self.get_implications(self.simplify(expr)))
+        added_replacements = {}
+
+        for axiom in new_axioms:
+            if (
+                isinstance(axiom, sympy.Eq)
+                and isinstance(axiom.lhs, sympy.Symbol)
+                and isinstance(axiom.rhs, sympy.Integer)
+                and axiom.lhs not in self.replacements
+            ):
+                self.replacements[axiom.lhs] = axiom.rhs
+                added_replacements[axiom.lhs] = axiom.rhs
+        self.axioms.update(new_axioms)
+
+        # We need to freeze the ShapeEnv because any additional modification of
+        # the ShapeEnv will cause unsoundness for subsequent specialization calls.
+        self.frozen = True
+        try:
+            yield
+        finally:
+            for k in new_axioms:
+                self.axioms.pop(k, None)
+            for k in added_replacements:
+                self.replacements.pop(k, None)
+            self.frozen = False
+
+    def check_equal(self, other: ShapeEnv) -> None:
+        """Compare another ShapeEnv for equivalence"""
+        # ShapeEnv fields that are not relevant for the outcome of
+        # ShapeEnv.produce_guards call:
+        #   - Debugging variables
+        #   - Translation validation related variables
+        #   - Events recording related variables
+        non_state_variable_names = (
+            "counter",
+            "log",
+            "var_to_stack",
+            "fx_node_cache",
+            "graph",
+            "validator",
+            "check_recorded_events",
+            "should_record_events",
+            "is_recording",
+            "tracked_fakes",
+            "events",
+            "source_name_to_debug_name",
+            "_prev_cache_key",
+            "_version_counter",
+            "dim_constraints",
+            # source locations are OK to diverge
+            "var_to_range_sloc",
+            "replacements_slocs",
+            "_resimplify_floor_div_axioms",
+            "_expr_sym_node_id",
+            "specialization_stacks",
+        )
+
+        # Mapping of the value of each to-be-compared field into the values that
+        # should actually be compared.
+        #
+        # You should modify this if, for example, the field that holds state and
+        # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr)
+        # and the stack when it was added to the set of guards. In order to compare
+        # it, we throw away the stack information.
+        def map_value(key: str, value: Any) -> Any:
+            if key == "guards":
+                # Transform the list of ShapeGuard into a list of expressions.
+                return [g.expr for g in value]
+            elif key == "deferred_runtime_asserts":
+                # Transform the list of RuntimeAsserts into a list of expressions.
+                return {s: [ra.expr for ra in ras] for s, ras in value.items()}
+            elif key == "name_to_node":
+                # Compare just the set of keys is the same.
+                return set(value.keys())
+            elif key in (
+                "symbol_guard_counter",
+                "pending_fresh_unbacked_symbols",
+                "fake_tensor_cache",
+            ):
+                # Skip this for comparisons
+                return None
+            return value
+
+        shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
+
+    def _snapshot_tracked_fakes(self) -> Optional[list[Any]]:
+        if self.tracked_fakes is None:
+            return None
+
+        from torch._dynamo.variables.builder import TrackedFake
+
+        def maybe_transform_fake(fake: TrackedFake) -> TrackedFake:
+            inner_fake = (
+                fake.fake
+                if isinstance(fake.fake, (torch.SymInt, torch.SymFloat))
+                else FakeTensorMeta.from_fake(fake.fake)
+            )
+            # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a
+            # FakeTensorMeta for two reasons:
+            #   1. this is all the information we need when recording ShapeEnvEvents.
+            #   2. it works even if each TrackedFake changes its metadata.
+            return TrackedFake(inner_fake, fake.source, fake.symbolic_context)  # type: ignore[arg-type]
+
+        return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
+
+    def _last_event_index(self) -> int:
+        return len(self.events) - 1
+
+    @contextmanager
+    def _recording(self) -> Iterator[None]:
+        self.is_recording = True
+        try:
+            yield
+        finally:
+            self.is_recording = False
+
+    @record_shapeenv_event()
+    def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr) -> None:
+        self._set_replacement(orig_s, new_s, "eliminate_unbacked")
+
+    @record_shapeenv_event()
+    def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None:
+        """Used only when propagate_real_tensors; registers a value for an
+        unbacked symbol, which can be used last resort to resolve hints."""
+        log.info("set_unbacked_var_to_val %s = %s", k, v)
+        self.unbacked_var_to_val[k] = sympy.sympify(v)
+
+    # Unlike set_replacement, this records a shapeenv event
+    @record_shapeenv_event()
+    def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol) -> None:
+        assert isinstance(orig_s, sympy.Symbol), orig_s
+        assert isinstance(new_s, sympy.Symbol), new_s
+        assert free_unbacked_symbols(new_s), new_s
+        assert free_unbacked_symbols(orig_s), orig_s
+        dest = self.replacements.get(orig_s)
+        if dest is not None:
+            assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}"
+        self._set_replacement(orig_s, new_s, "rename_unbacked_to")
+        self.unbacked_renamings[orig_s] = new_s
+        if dest is not None:
+            self._set_replacement(new_s, dest, "rename_unbacked_to_dest")
+
+    @record_shapeenv_event()
+    def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
+        # TODO: Do something nontrivial when upper_bound is expression
+        pass
+
+    @record_shapeenv_event()
+    def _constrain_range_for_size(
+        self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None
+    ) -> None:
+        if min is None:
+            min = 0
+        if max is None:
+            max = int_oo
+
+        if max < min:
+            raise ValueError(
+                "Maximum value to constrain_as_size can't be less than the specified min value, "
+                f"received min={min} and max={max}"
+            )
+
+        self.constrain_symbol_range(
+            a,
+            compiler_min=min,
+            compiler_max=max,
+        )
+        self.size_like.add(a)
+
+    @record_shapeenv_event()
+    def _constrain_range(self, a: sympy.Expr, min: int, max: int) -> None:
+        if isinstance(a, sympy.Integer):
+            if not (min <= int(a) <= max):
+                raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]")
+            return
+
+        # TODO: Shouldn't we install a guard if the symbol is backed?  Or is the
+        # semantics that this is an "unchecked" assert (but it this actually
+        # something useful?  Might be better to restrict only for unbacked
+        # SymInt).
+        if isinstance(a, sympy.Symbol):
+            self.constrain_symbol_range(
+                a,
+                compiler_min=min,
+                compiler_max=max,
+            )
+
+    @record_shapeenv_event()
+    def _constrain_unify(self, a: SymInt, b: SymInt) -> None:
+        """
+        Given two SymInts, constrain them so that they must be equal.  NB:
+        this will not work with SymInts that represent nontrivial expressions
+        (yet!)
+        """
+        # TODO: this does not install a deferred runtime assert yet
+
+        # TODO: Maybe dedupe this with _maybe_guard_rel?
+        # Update Feb 2024: this is extra important to do, this doesn't handle
+        # unbacked replacements properly nor does it generate deferred runtime
+        # asserts
+        if not isinstance(a, SymInt):
+            if not isinstance(b, SymInt):
+                assert a == b
+            else:
+                assert isinstance(b.node.expr, sympy.Symbol), (
+                    "constraining non-Symbols NYI"
+                )
+                assert b.node.shape_env is self
+                self.replacements[b.node.expr] = sympy.Integer(a)
+        else:
+            # TODO: Actually, we can support this as long as one of them is a symbol.
+            # NB: We can't actually do "unification" as our operators are not
+            # injective
+            assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            assert a.node.shape_env is self
+            if not isinstance(b, SymInt):
+                self.replacements[a.node.expr] = sympy.Integer(b)
+            else:
+                assert a.node.shape_env is b.node.shape_env
+                assert isinstance(b.node.expr, sympy.Symbol), (
+                    "constraining non-Symbols NYI"
+                )
+                new_var = self._find(a.node.expr)
+                self.replacements[b.node.expr] = new_var
+
+    def _ignore_fresh_unbacked_symbols_tls(self) -> bool:
+        return getattr(TLS, "ignore_fresh_unbacked_symbols", False)
+
+    @record_shapeenv_event()
+    def _ignore_fresh_unbacked_symbols_set(self, b: bool) -> bool:
+        prev = self._ignore_fresh_unbacked_symbols_tls()
+        TLS.ignore_fresh_unbacked_symbols = b
+        return prev
+
+    @contextmanager
+    def ignore_fresh_unbacked_symbols(self) -> Iterator[None]:
+        """
+        Indicates that the newly allocated unbacked SymInts are being
+        discarded
+        """
+        prev = self._ignore_fresh_unbacked_symbols_set(True)
+        try:
+            yield
+        finally:
+            self._ignore_fresh_unbacked_symbols_set(prev)
+
+    @record_shapeenv_event()
+    def freeze(self) -> None:
+        """Freeze this ShapeEnv to stop accumulating guards
+
+        A frozen ShapeEnv will ignore any further guards generated on it and
+        only emit a warning which may lead to accuracy problems.
+        """
+        self.frozen = True
+
+    @record_shapeenv_event()
+    def freeze_runtime_asserts(self) -> None:
+        """Freeze this ShapeEnv to stop adding deferred runtime asserts.
+
+        We will error if you try to install a new runtime assert when it is
+        frozen.  This would indicate a lowering violation, or perhaps something
+        we know statically is already True but we are checking it again in a way
+        that is not clearly dischargeable.
+        """
+        # self.prefer_deferred_runtime_asserts_over_guards = False
+        self.runtime_asserts_frozen = True
+
+    def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
+        if not self._translation_validation_enabled:
+            return None
+        srcname = source.name
+        if source not in self.source_to_symbol:
+            self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
+        return self.source_to_symbol[srcname]
+
+    def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None:
+        if self._translation_validation_enabled:
+            self.validator.add_var(symbol, type)
+
+    def _add_target_expr(self, expr: SympyBoolean) -> None:
+        if self._translation_validation_enabled:
+            self.validator.add_target_expr(expr)
+
+    def _add_assertion(self, expr: SympyBoolean) -> None:
+        if self._translation_validation_enabled:
+            self.validator.add_assertion(expr)
+
+    def _check_translation_validate(self) -> None:
+        if self._translation_validation_enabled:
+            self.validator.validate()
+
+    @record_shapeenv_event()
+    def _create_fx_call_function(
+        self,
+        op: Callable,
+        args: tuple,
+    ) -> tuple[Optional[torch.fx.Node], bool]:
+        # Cache this tuple in order to avoid duplicated nodes.
+        node_key = (op, args)
+        # Flags whether the returned node was cached or not.
+        fresh = False
+
+        if self._translation_validation_enabled and node_key not in self.fx_node_cache:
+            # Presence of None in the arguments implies that we should ignore this operation.
+            if any(a is None for a in args):
+                # We check if we are not mixing SymNode that should not be ignored
+                # (fx_node is not None) with those that should (fx_node is None).
+                assert all(not isinstance(a, torch.fx.Node) for a in args)
+                return None, fresh
+
+            fresh = True
+
+            # If translation validation is enabled, all arguments must have its
+            # own FX node.
+            assert all(a is not None for a in args), (
+                f"missing arg in FX graph ({op.__name__}): {args}"
+            )
+            node = self.fx_node_cache[node_key] = self.graph.call_function(op, args)
+            self.name_to_node[node.name] = node
+
+        return self.fx_node_cache.get(node_key, None), fresh
+
+    def _create_fx_placeholder_and_z3var(
+        self,
+        symbol: sympy.Symbol,
+        type: type,
+    ) -> Optional[torch.fx.Node]:
+        if not self._translation_validation_enabled:
+            return None
+
+        node_key = (self.graph.placeholder, (symbol,))
+
+        # Check if we haven't added this symbol already.
+        # If so, skip the placeholder creation, as it
+        # generates invalid Python code.
+        if node_key not in self.fx_node_cache:
+            # Add a Z3 variable according to 'type'.
+            self._add_z3var(symbol, type)
+            # Create the FX placeholder out of a mangled name.
+            mangled_name = re.sub(
+                r"[^a-zA-Z0-9]", "_", re.sub(r"[()]", "", symbol.name)
+            )
+            node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name)
+            self.name_to_node[node.name] = node
+            # Attach the 'symbol' to the placeholder so that we can retrieve
+            # the Z3 variable later.
+            node.meta["symbol"] = symbol
+
+        return self.fx_node_cache[node_key]
+
+    def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
+        if self._translation_validation_enabled and node is not None:
+            self.name_to_node.pop(node.name)
+            self.graph.erase_node(node)
+
+    def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
+        from torch._dynamo.utils import get_current_node
+
+        if self.should_record_events:
+            node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
+            node.meta[CURRENT_NODE_KEY] = get_current_node()
+
+    @staticmethod
+    def _suppress_guards_tls() -> bool:
+        return getattr(TLS, "suppress_guards", False)
+
+    @record_shapeenv_event()
+    def _suppress_guards_enter(self) -> None:
+        if not hasattr(TLS, "suppress_guards_stack"):
+            TLS.suppress_guards_stack = []
+        old = self._suppress_guards_tls()
+        TLS.suppress_guards_stack.append(old)
+        TLS.suppress_guards = True
+
+    @record_shapeenv_event()
+    def _suppress_guards_exit(self) -> None:
+        old = (
+            TLS.suppress_guards_stack.pop()
+            if len(TLS.suppress_guards_stack) > 0
+            else False
+        )
+        TLS.suppress_guards = old
+
+    def suppress_guards(self) -> _GeneratorContextManager[None]:
+        """Context manager to ignore all guards generated inside"""
+        return _suppress_guards(self)
+
+    def _get_key(self) -> tuple[int, int, int, int]:
+        """
+        Defines the current "state" of the guards we've accumulated in this ShapeEnv.
+        Determines when we need to invalidate our cache
+        """
+        return (
+            len(self.replacements),
+            len(self.divisible),
+            self.num_deferred_runtime_asserts,
+            len(self.unbacked_var_to_val),
+        )
+
+    def _update_version_counter(self) -> None:
+        # if the change to shape env effects self.divisible set
+        # _resimplify_floor_div_axioms.
+        # This is used to trigger a resimplication of FloorDiv to CleanDivs
+        # in implication inside the function resimplify_floor_div.
+        if len(self.divisible) != self._prev_cache_key[1]:
+            self._resimplify_floor_div_axioms = True
+
+        # The shape environment is queried orders of magnitude more often than
+        # it is changed, so we summarise the cache key into a linearly
+        # increasing version counter which is cheaper to check in _lru_cache
+
+        # Only update version counter if the state actually changed
+        cur_key = self._get_key()
+
+        if self._prev_cache_key != cur_key:
+            self._prev_cache_key = cur_key
+            self._version_counter += 1
+
+    def _produce_dyn_sizes(
+        self,
+        ex_size: Sequence[IntLikeType],
+        source: Source,
+        symbolic_context: SymbolicContext,
+    ) -> list[sympy.Expr]:
+        return self._produce_dyn_sizes_from_int_tuple(
+            tuple(ex_size), source, symbolic_context
+        )
+
+    def _produce_dyn_sizes_from_int_tuple(
+        self,
+        tensor_size: Sequence[IntLikeType],
+        source: Source,
+        symbolic_context: SymbolicContext,
+        hint_overrides: Optional[dict[int, int]] = None,
+    ) -> list[sympy.Expr]:
+        assert all(not is_symbolic(val) for val in tensor_size), (
+            f"Expect size to be a plain tuple of ints but got {tensor_size}"
+        )
+        from torch._dynamo.source import TensorProperty, TensorPropertySource
+
+        if not hint_overrides:
+            hint_overrides = {}
+
+        _assert_symbol_context(symbolic_context)
+        dynamic_dims = symbolic_context.dynamic_sizes  # type: ignore[attr-defined]
+        constraint_dims = symbolic_context.constraint_sizes  # type: ignore[attr-defined]
+        size = []
+        for i, val in enumerate(tensor_size):
+            sym = self.create_symbol(
+                hint_overrides.get(i, val),
+                TensorPropertySource(source, TensorProperty.SIZE, i),
+                dynamic_dims[i],
+                constraint_dims[i],
+                do_not_specialize_zero_one=config.backed_size_oblivious,
+                symbolic_context=symbolic_context,
+            )
+            if (
+                isinstance(symbolic_context, StatelessSymbolicContext)
+                and symbolic_context.specialize_on
+            ):
+                for specialization in symbolic_context.specialize_on[i]:
+                    self.specializations.add(
+                        Specialization(
+                            TensorPropertySource(source, TensorProperty.SIZE, i),
+                            specialization,
+                        )
+                    )
+            if (
+                config.backed_size_oblivious
+                and isinstance(sym, sympy.Symbol)  # could be static
+                and symbol_is_type(sym, SymT.SIZE)
+            ):
+                self.size_like.add(sym)
+            size.append(sym)
+        return size
+
+    def create_symbolic_sizes_strides_storage_offset(
+        self,
+        ex: torch.Tensor,
+        source: Source,
+        *,
+        symbolic_context: Optional[SymbolicContext] = None,
+    ) -> tuple[
+        tuple[IntLikeType, ...],
+        tuple[IntLikeType, ...],
+        IntLikeType,
+    ]:
+        """
+        Returns a list of symbolic sizes and strides for the given tensor.
+        We try our best to express stride in terms of the sizes, so as to not
+        introduce new symbolic variables.
+        """
+
+        ex_size = tuple(
+            self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size()
+        )
+        ex_stride = tuple(
+            self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride()
+        )
+        ex_storage_offset = self._maybe_specialize_sym_int_with_hint(
+            ex.storage_offset()
+        )
+
+        return self._create_symbolic_sizes_strides_storage_offset(
+            ex_size,
+            ex_stride,
+            ex_storage_offset,
+            [_is_dim_dynamic(ex, i) for i in range(ex.dim())],
+            source,
+            symbolic_context=symbolic_context,
+        )
+
+    # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
+    # We create symbols in shape_env using the backed hints behind SymInt.
+
+    # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
+    # produce_guards will trigger specializations on the outer stuff
+
+    # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint().
+    #
+    # It's probably good for now but it's important to note that this approach has implications for
+    # the original shape_env when checking guards in different order.
+
+    # Example:
+    # ---------
+    # Consider a function "opt_f" as shown below:
+
+    # @torch.compile()
+    # def opt_f(x: bool, y: Tensor):
+    #   if x == True:
+    #     return y + torch.randn([4])
+    #   else:
+    #     return y
+    # Depending on the sequence of calls, we might install two different sets of guards:
+
+    # 1. opt_f(False, y):
+    #    - "x == False" (always works for any size y)
+
+    # 2. opt_f(True, y):
+    #    - Triggers recompilation and results in guards like:
+    #      - "x == True and y.size(0) == 4"
+    #      - (or "y.size(0) == 4 and x == True")
+
+    # The order of checking the guards matters. In this specific example:
+    # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
+    # we may have an unnecessary shape specialization for y.
+    def _maybe_specialize_sym_int_with_hint(
+        self, maybe_sym: IntLikeType
+    ) -> IntLikeType:
+        assert isinstance(maybe_sym, (int, torch.SymInt))
+        if is_symbolic(maybe_sym):
+            assert maybe_sym.node.shape_env is not self, (
+                "expect the symbol is created from an shape env other than current one."
+            )
+            return maybe_sym.node.require_hint()
+        return maybe_sym
+
+    @record_shapeenv_event()
+    def _create_symbolic_sizes_strides_storage_offset(
+        self,
+        # NB: SymInt is allowed here due to nested int, normally you don't
+        # actually pass true symbolic sizes to this function
+        ex_size: Sequence[IntLikeType],
+        ex_stride: Sequence[IntLikeType],
+        ex_storage_offset: IntLikeType,
+        is_dim_dynamic: Sequence[bool],
+        source: Source,
+        *,
+        symbolic_context: Optional[SymbolicContext] = None,
+        hint_overrides: Optional[dict[int, int]] = None,
+    ) -> tuple[
+        tuple[IntLikeType, ...],
+        tuple[IntLikeType, ...],
+        IntLikeType,
+    ]:
+        dim = len(ex_size)
+
+        if not hint_overrides:
+            hint_overrides = {}
+
+        # Reimplement the legacy behavior
+        if symbolic_context is None:
+            constraint_sizes: list[DimConstraint] = [None] * dim
+            constraint_strides: list[DimConstraint] = [None] * dim
+            dynamic_dims = []
+            dynamic_strides = []
+            for i in range(dim):
+                # NB: This is encapsulation breaking!  Legacy behavior was
+                # bad.
+                if is_dim_dynamic[i]:
+                    r = DimDynamic.DYNAMIC
+                elif self.assume_static_by_default:
+                    r = DimDynamic.STATIC
+                else:
+                    r = DimDynamic.DUCK
+                dynamic_dims.append(r)
+                dynamic_strides.append(r)
+            dynamic_dims = [DimDynamic.DUCK] * dim
+            dynamic_strides = [DimDynamic.INFER_STRIDE] * dim
+            # symbolic_context is None - set one
+            symbolic_context = StatelessSymbolicContext(
+                dynamic_sizes=dynamic_dims,
+                dynamic_strides=dynamic_strides,
+                constraint_sizes=constraint_sizes,
+                constraint_strides=constraint_strides,
+            )
+        # We got a StatelessSymbolicContext
+        _assert_symbol_context(symbolic_context)
+        constraint_sizes = symbolic_context.constraint_sizes  # type: ignore[attr-defined]
+        constraint_strides = symbolic_context.constraint_strides  # type: ignore[attr-defined]
+        dynamic_sizes = symbolic_context.dynamic_sizes  # type: ignore[attr-defined]
+        dynamic_strides = symbolic_context.dynamic_strides  # type: ignore[attr-defined]
+
+        # TODO: make this configurable from outside symbolic_context; we made a symbolic_context
+        # decision here where if all sizes are static, we are going to
+        # specialize all of the inner strides/offset too. We don't have to
+        # do this, and arguably we should ALWAYS allow for dynamic offset,
+        # this is cheap.
+        # TODO: This should be DYNAMIC, using DUCK for BC
+        dynamic_offset = (
+            DimDynamic.STATIC
+            if all(r == DimDynamic.STATIC for r in dynamic_sizes)
+            else DimDynamic.DUCK
+        )
+        are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes)
+
+        assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}"
+        assert len(dynamic_strides) == dim, f"{len(dynamic_sizes)} != {dim}"
+        assert len(constraint_sizes) == dim
+        assert len(constraint_strides) == dim
+
+        from torch._dynamo.source import TensorProperty, TensorPropertySource
+
+        size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
+            ex_size, source, symbolic_context, hint_overrides=hint_overrides
+        )
+        stride = self._compute_symbolic_stride(
+            source,
+            size,
+            ex_size,
+            ex_stride,
+            dynamic_strides,
+            constraint_strides,
+            are_sizes_static,
+            symbolic_context,
+        )
+
+        sym_sizes = [
+            self.create_symintnode(
+                sym,
+                hint=hint_overrides.get(i, hint),
+                source=TensorPropertySource(source, TensorProperty.SIZE, i),
+            )
+            for i, (sym, hint) in enumerate(zip(size, ex_size))
+        ]
+
+        for i, sym in enumerate(sym_sizes):
+            if isinstance(sym, torch.SymInt) and i in hint_overrides:
+                self.var_to_hint_override[sym.node.expr] = hint_overrides[i]
+
+        sym_stride = []
+        for i, stride_expr in enumerate(stride):
+            # NB: Don't duck size the stride; instead use the expression
+            # we computed
+            assert stride_expr is not None
+            sym_stride.append(
+                self.create_symintnode(
+                    stride_expr,
+                    hint=ex_stride[i],
+                    source=TensorPropertySource(source, TensorProperty.STRIDE, i),
+                )
+            )
+        sym_storage_offset = self.create_symintnode(
+            self.create_symbol(
+                ex_storage_offset,
+                TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
+                dynamic_dim=dynamic_offset,
+                constraint_dim=None,
+                symbolic_context=symbolic_context,
+            ),
+            hint=ex_storage_offset,
+            source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
+        )
+        return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
+
+    def _compute_symbolic_stride(
+        self,
+        source: Source,
+        size: Sequence[sympy.Expr],
+        ex_size: Sequence[IntLikeType],
+        ex_stride: Sequence[IntLikeType],
+        dynamic_strides: Sequence[DimDynamic],
+        constraint_strides: Sequence[
+            Optional[Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]]
+        ],
+        are_sizes_static: bool,
+        symbolic_context: SymbolicContext,
+    ) -> list[sympy.Expr]:
+        from torch._dynamo.source import TensorProperty, TensorPropertySource
+
+        stride: list[Optional[sympy.Expr]] = [None] * len(size)
+        candidates: dict[IntLikeType, sympy.Expr] = {}
+
+        # iterate over unbound strides in val ascending order with
+        # index descending as a tie breaker since for cases like
+        # [(1, 1), (1, 0)], we want to fill in the right most
+        # stride first.
+        val_list = [(val, -i) for i, val in enumerate(ex_stride)]
+        val_list.sort(key=_nested_int_aware_sort)
+
+        for val, neg_i in val_list:
+            i = -neg_i
+            contiguous_stride = (
+                i != len(ex_stride) - 1
+                and ex_stride[i] == ex_size[i + 1] * ex_stride[i + 1]
+            )
+            if val in (0, 1) and not contiguous_stride:
+                out_stride = sympy.Integer(val)
+            else:
+                dynamic_stride = dynamic_strides[i]
+                if dynamic_stride == DimDynamic.INFER_STRIDE and val in candidates:
+                    # Set stride to a candidate only for DimDynamic.INFER_STRIDE
+                    out_stride = candidates[val]
+                else:
+                    # Set INFER_STRIDE to STATIC or DUCK depending on sizes
+                    dyn_stride = dynamic_stride
+                    if dynamic_stride == DimDynamic.INFER_STRIDE:
+                        dyn_stride = (
+                            DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK
+                        )
+                    out_stride = self.create_symbol(
+                        val,
+                        TensorPropertySource(source, TensorProperty.STRIDE, i),
+                        dynamic_dim=dyn_stride,
+                        constraint_dim=constraint_strides[i],
+                        symbolic_context=symbolic_context,
+                    )
+            stride[i] = out_stride
+            candidates[ex_size[i] * val] = size[i] * out_stride
+
+        assert all(x is not None for x in stride)
+        return stride
+
+    @record_shapeenv_event()
+    def create_symintnode(
+        self,
+        sym: sympy.Expr,
+        *,
+        hint: Optional[int],
+        source: Optional[Source] = None,
+    ) -> IntLikeType:
+        """Create a SymInt value from a symbolic expression
+
+        If you know what the current hint value of the SymInt to be created
+        is, pass it into hint.  Otherwise, pass None and we will make our best
+        guess
+
+        """
+        if self._translation_validation_enabled and source is not None:
+            # Create a new symbol for this source.
+            symbol = self._create_symbol_for_source(source)
+            assert symbol is not None
+
+            # Create a new FX placeholder and Z3 variable for 'symbol'.
+            fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
+
+            # Add an equality assertion for the newly created symbol and 'sym'.
+            self._add_assertion(sympy.Eq(symbol, sym))
+        else:
+            fx_node = None
+
+        out: IntLikeType
+        if isinstance(sym, sympy.Integer):
+            if hint is not None:
+                assert int(sym) == hint
+            out = int(sym)
+        else:
+            # How can this occur? When we mark_unbacked, we end up with a real
+            # tensor that has hints for all sizes, but we MUST NOT create a
+            # SymNode with a hint, because we're hiding the hint from our eyes
+            # with the unbacked Symbol.  And in fact, the hint compute may be
+            # inconsistent with size oblivious tests.
+            if free_unbacked_symbols(sym):
+                hint = None
+            out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
+        return out
+
+    @record_shapeenv_event()
+    def create_symfloatnode(
+        self,
+        sym: sympy.Expr,
+        *,
+        hint: Optional[int | float | bool],
+        source: Optional[Source] = None,
+    ) -> FloatLikeType:
+        """Create a SymFloat value from a symbolic expression"""
+        if self._translation_validation_enabled and source is not None:
+            # Create a new symbol for this source.
+            symbol = self._create_symbol_for_source(source)
+            assert symbol is not None
+
+            # Create a new FX placeholder and Z3 variable for 'symbol'.
+            fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
+
+            # Add an equality assertion for the newly created symbol and 'sym'.
+            self._add_assertion(sympy.Eq(symbol, sym))
+        else:
+            fx_node = None
+
+        out: FloatLikeType
+        if isinstance(sym, sympy.Float):
+            if hint is not None:
+                assert float(sym) == hint
+            out = float(sym)
+        else:
+            # You could give this the same treatment as SymInt above if
+            # you supported mark_unbacked on a float, but it's a kind of
+            # strange thing to do though because floats don't get 0/1
+            # specialization anyway
+            if free_unbacked_symbols(sym):
+                assert hint is None, sym
+            out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node))
+        return out
+
+    @record_shapeenv_event()
+    def create_unspecified_symint_and_symbol(
+        self, value: int, source: Source, dynamic_dim: DimDynamic
+    ) -> IntLikeType:
+        """Create a SymInt wrapping a new unspecified symbol"""
+        return self.create_symintnode(
+            self.create_unspecified_symbol(
+                value,
+                source=source,
+                dynamic_dim=dynamic_dim,
+            ),
+            hint=value,
+            source=source,
+        )
+
+    def create_symboolnode(self, sym: sympy.Expr) -> SymBool:
+        """Create a SymBool object from a sympy boolean expression"""
+        # This function is only being used in serialization, so we do not track it
+        # for validation.
+        return SymBool(SymNode(sym, self, bool, None))
+
+    def _log_create_unbacked_symbol(
+        self,
+        prefix: str,
+        symbol: sympy.Symbol,
+        vr: ValueRanges,
+        source: Optional[Source] = None,
+        sym_node: Optional[SymNode] = None,
+    ) -> None:
+        is_debug = config.extended_debug_create_symbol is not None and str(
+            symbol
+        ) in config.extended_debug_create_symbol.split(",")
+        sloc: Union[str, SLoc]
+        if source is None:
+            sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
+        else:
+            sloc, maybe_extra_debug = source.name, ""
+        log.info(
+            "%s %s [%s, %s] %s%s",
+            prefix,
+            symbol,
+            vr.lower,
+            vr.upper,
+            sloc,
+            maybe_extra_debug,
+            stack_info=is_debug,
+        )
+        trace_structured(
+            "create_unbacked_symbol",
+            metadata_fn=lambda: {
+                "symbol": str(symbol),
+                "node_id": id(sym_node),
+                "vr": f"[{vr.lower}, {vr.upper}]",
+                "user_stack": structured.get_user_stack(3),
+                "stack": structured.get_framework_stack(),
+            },
+        )
+
+    @record_shapeenv_event()
+    def create_unbacked_symfloat(self) -> SymFloat:
+        """Create a symbolic float without a hint value"""
+        symbol: sympy.Symbol = make_symbol(
+            SymT.UNBACKED_FLOAT, self.unbacked_symfloat_counter
+        )
+        self.unbacked_symfloat_counter += 1
+        self.counter["create_unbacked_symbol"] += 1
+        if not self._ignore_fresh_unbacked_symbols_tls():
+            self.pending_fresh_unbacked_symbols.append(symbol)
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
+        vr = self.var_to_range[symbol] = ValueRanges.unknown()
+        assert vr.is_float
+        sloc = self._get_sloc()
+        self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc)
+
+        # Create a new FX placeholder and Z3 variable for 'symbol'.
+        fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
+
+        sym_node = SymNode(symbol, self, float, None, fx_node=fx_node)
+        self._log_create_unbacked_symbol(
+            "create_unbacked_symfloat", symbol, vr, sym_node=sym_node
+        )
+
+        return SymFloat(sym_node)
+
+    @record_shapeenv_event()
+    def create_unbacked_symint(self, source: Optional[Source] = None) -> SymInt:
+        """Create a symbolic integer without a hint value"""
+        symbol: sympy.Symbol = make_symbol(
+            SymT.UNBACKED_INT, self.unbacked_symint_counter, integer=True
+        )
+        self.unbacked_symint_counter += 1
+        if not self._ignore_fresh_unbacked_symbols_tls():
+            self.pending_fresh_unbacked_symbols.append(symbol)
+        self.counter["create_unbacked_symbol"] += 1
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
+        vr = self.var_to_range[symbol] = self._default_unspecified_value_range()
+        assert vr.is_int
+        sloc = self._get_sloc()
+        self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc)
+
+        # Create a new FX placeholder and Z3 variable for 'symbol'.
+        fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
+
+        sym_node = SymNode(symbol, self, int, None, fx_node=fx_node)
+        self._log_create_unbacked_symbol(
+            "create_unbacked_symint", symbol, vr, source, sym_node=sym_node
+        )
+        return SymInt(sym_node)
+
+    def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
+        """Check if a sympy symbol matches the naming convention for unbacked symbols"""
+        return symbol_is_type(symbol, SymT.UNBACKED_INT)
+
+    @record_shapeenv_event()
+    def create_unbacked_symbool(self) -> SymBool:
+        """Create a symbolic boolean without a hint value"""
+        symbol: sympy.Symbol = make_symbol(
+            SymT.UNBACKED_INT, self.unbacked_symint_counter, integer=True
+        )
+        self.unbacked_symint_counter += 1
+        if not self._ignore_fresh_unbacked_symbols_tls():
+            self.pending_fresh_unbacked_symbols.append(symbol)
+        self.counter["create_unbacked_symbol"] += 1
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
+        vr = self.var_to_range[symbol] = ValueRanges(0, 1)
+        assert vr.is_int
+        sloc = self._get_sloc("default value range for unbacked SymBool")
+        self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc)
+
+        # Create a new FX placeholder and Z3 variable for 'symbol'.
+        fx_node = self._create_fx_placeholder_and_z3var(symbol, bool)
+
+        sym_node = SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node)
+        self._log_create_unbacked_symbol(
+            "create_unbacked_symbool", symbol, vr, sym_node=sym_node
+        )
+
+        return SymBool(sym_node)
+
+    @record_shapeenv_event()
+    def create_unspecified_symbol(
+        self,
+        val: Union[int, SymInt, float, SymFloat],
+        source: Source,
+        dynamic_dim: DimDynamic = DimDynamic.DUCK,
+        constraint_dim: DimConstraint = None,  # NB: includes None
+        symbolic_context: Optional[StatelessSymbolicContext] = None,
+    ) -> sympy.Expr:
+        """
+        Create a symbol with an unspecified value
+
+        Compared to standard symbols we do not assume the value is positive,
+        nor do we specialze on zero or one values.
+        """
+        # 'positive' is None for unspecified symbols, since we can't
+        # assume that it will be neither positive nor negative.
+
+        # We don't want to specialize zero one val for unspecified symbol
+        # so that we can always get a new symbol despite val.
+        return self.create_symbol(
+            val,
+            source,
+            dynamic_dim,
+            constraint_dim,
+            positive=None,
+            do_not_specialize_zero_one=True,
+            symbolic_context=symbolic_context,
+        )
+
+    @record_shapeenv_event()
+    def create_symbol(
+        self,
+        val: int,
+        source: Source,
+        dynamic_dim: DimDynamic = DimDynamic.DUCK,
+        constraint_dim: DimConstraint = None,  # NB: includes None
+        positive: Optional[bool] = True,
+        do_not_specialize_zero_one: bool = False,
+        symbolic_context: Optional[StatelessSymbolicContext] = None,
+    ) -> sympy.Expr:
+        """Create a new symbol which is tracked by this ShapeEnv"""
+        # check if constraint_dim is actually static integer
+        if (
+            isinstance(constraint_dim, StrictMinMaxConstraint)
+            and constraint_dim.vr.lower == constraint_dim.vr.upper
+        ):
+            dynamic_dim = DimDynamic.STATIC
+            if constraint_dim.vr.lower != val:
+                raise ConstraintViolationError(
+                    f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, "
+                    f"for {source.name}"
+                )
+            if symbolic_context:
+                from torch._dynamo.source import TensorPropertySource
+
+                assert isinstance(source, TensorPropertySource)
+                # TODO: storage_offset handling?
+                assert source.idx is not None
+                symbolic_context.dynamic_sizes[source.idx] = dynamic_dim
+                symbolic_context.constraint_sizes[source.idx] = None
+            constraint_dim = None
+
+        # see note [Tensor Fakification and Symbol Caching]
+        source_name = source.name
+        if (
+            isinstance(symbolic_context, StatefulSymbolicContext)
+            and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache
+        ):
+            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {}
+
+        if (
+            isinstance(symbolic_context, StatefulSymbolicContext)
+            and source_name
+            and (
+                source_name
+                in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)]
+            )
+        ):
+            return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
+                source_name
+            ]
+
+        if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE):
+            out = self.create_unbacked_symint(source).node.expr
+            self._constrain_range_for_size(out)
+
+            self.unbacked_inputs.add(out)
+
+            if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
+                symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
+                    source_name
+                ] = out
+            if dynamic_dim is DimDynamic.OBLIVIOUS_SIZE:
+                self.oblivious_var_to_val[out] = val
+            return out
+
+        if do_not_specialize_zero_one:
+            specialize_zero_one = False
+        else:
+            specialize_zero_one = self.specialize_zero_one
+
+        assert isinstance(source, Source), f"{type(source)} {source}"
+        assert not (positive and val < 0), f"positive set for negative value: {val}"
+        # It's always sound to allocate a symbol as DYNAMIC.  If the user
+        # constrained the symbol, force the symbolic_context to DYNAMIC, because our
+        # constraint code will do weird stuff if, e.g., it's duck shaped
+        if constraint_dim is not None:
+            dynamic_dim = DimDynamic.DYNAMIC
+
+        if dynamic_dim is DimDynamic.STATIC:
+            out = sympy.Integer(val)
+            if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
+                symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
+                    source_name
+                ] = out
+            return out
+
+        elif dynamic_dim is DimDynamic.DUCK:
+            # duck_shape can be used to globally turn off duck shaping, even
+            # if it was requested
+            duck = self.duck_shape
+        elif dynamic_dim is DimDynamic.DYNAMIC:
+            duck = False
+        else:
+            raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")
+
+        sloc = self._get_sloc()
+
+        if val in (0, 1) and specialize_zero_one:
+            if val == 0:
+                return sympy.S.Zero
+            else:
+                return sympy.S.One
+        elif not duck or val not in self.val_to_var:
+            # If we're not duck shaping, we always create a new symbol
+            # Even if we're duck shaping, if we haven't seen this particular
+            # value before, we also create a new symbol
+            symbol_id = self._generate_unique_id(source.name)
+            if type(val) is int or is_nested_int(val):
+                sympy_expr = make_symbol(
+                    SymT.SIZE, symbol_id, positive=positive, integer=True
+                )
+            else:
+                sympy_expr = make_symbol(
+                    SymT.FLOAT, symbol_id, positive=positive, real=True
+                )
+            self.source_to_var[source_name] = sympy_expr
+            # We always associate vars to vals
+            if isinstance(val, int):
+                self.var_to_val[sympy_expr] = sympy.Integer(val)
+            elif isinstance(val, float):
+                self.var_to_val[sympy_expr] = sympy.Float(val)
+            else:
+                # Only used for jagged layout nested tensors
+                self.var_to_val[sympy_expr] = SingletonInt(
+                    val.node.nested_int(), coeff=val.node.nested_int_coeff()
+                )
+
+            # Do the appending later, because we always want to populate this
+            self.var_to_sources[sympy_expr] = []
+            # Create a Z3 variable for the new symbol.
+            self._add_z3var(sympy_expr, int)
+
+            if duck:
+                # Make sure to reuse this symbol for subsequent duck shaping
+                # pyrefly: ignore [unsupported-operation]
+                self.val_to_var[val] = sympy_expr
+
+            if isinstance(val, int):
+                if positive:
+                    # Add assertions for the newly created symbols
+                    self._add_assertion(sympy_expr > 1)
+
+                    # Apply default range, which assumes not zero-one
+                    self.var_to_range[sympy_expr] = self._default_value_range(
+                        do_not_specialize_zero_one
+                    )
+                    self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(
+                        self._get_sloc(
+                            "user code shown is first use of this value--the guard itself is not "
+                            "due user code but due to 0/1 specialization in the framework; to "
+                            "avoid specialization try torch._dynamo.decorators.mark_unbacked(tensor, dim)"
+                            if self.specialize_zero_one
+                            else None
+                        ),
+                        sloc,
+                    )
+                else:
+                    self.var_to_range[sympy_expr] = (
+                        self._default_unspecified_value_range()
+                    )
+                    self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc)
+
+                # Small performance optimization: if we have a min-max constraint,
+                # we can proactively narrow to that range
+                if isinstance(constraint_dim, StrictMinMaxConstraint):
+                    assert not duck
+                    self._update_var_to_range(
+                        sympy_expr, constraint_dim.vr, is_constraint=True
+                    )
+
+                vr = self.var_to_range[sympy_expr]
+                assert vr.is_int
+
+                if val not in vr:
+                    raise ConstraintViolationError(
+                        f"{val} not in range [{vr.lower}, {vr.upper}]"
+                    )
+
+                range_str = f"[{vr.lower}, {vr.upper}]"
+            elif isinstance(val, float):
+                self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo)
+                self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc)
+                range_str = f"[{vr.lower}, {vr.upper}]"
+                assert vr.is_float
+            else:
+                # Skip var_range logic for SingletonInt
+                # Only used for jagged layout nested tensors
+                range_str = ""
+
+            r = sympy_expr
+
+            is_debug = config.extended_debug_create_symbol is not None and str(
+                sympy_expr
+            ) in config.extended_debug_create_symbol.split(",")
+            maybe_more_info = ""
+            if not is_debug and os.getenv("TORCHDYNAMO_EXTENDED_ADVICE", "1") not in (
+                "0",
+                "",
+            ):
+                maybe_more_info = (
+                    ", for more info run with "
+                    f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}" '
+                    "or to suppress this message run with "
+                    'TORCHDYNAMO_EXTENDED_ADVICE="0"'
+                )
+            sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
+            self.log.info(
+                "create_symbol %s = %s for %s %s %s%s%s",
+                sympy_expr,
+                val,
+                source.name,
+                range_str,
+                sloc,
+                maybe_more_info,
+                maybe_extra_debug,
+                stack_info=is_debug,
+            )
+            trace_structured(
+                "create_symbol",
+                metadata_fn=lambda: {
+                    "symbol": str(sympy_expr),
+                    "val": repr(val),
+                    "vr": range_str,
+                    "source": source.name,
+                    "user_stack": structured.from_traceback(
+                        TracingContext.extract_stack()
+                    ),
+                    "stack": structured.from_traceback(
+                        CapturedTraceback.extract(skip=1).summary()
+                    ),
+                },
+            )
+
+            self.counter["create_symbol"] += 1
+        else:
+            # This implements duck-shaping: input sizes that match are assigned
+            # the same symint
+            r = self.val_to_var[val]
+            self.source_to_var[source_name] = r
+            self.log.debug("create_symbol %s duck sized %s", r, source.name)
+
+        if isinstance(r, sympy.Symbol):
+            r_sources = self.var_to_sources[r]
+            r_sources.append(source)
+            if not source.is_ephemeral() and r_sources[0].is_ephemeral():
+                # prefer non-ephemeral source first since it may be guarded on later
+                r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]
+
+            # This ensures we get zeros in symbol_guard_counts, which makes
+            # some queries simpler (since we will accumulate mass on 0 this
+            # way)
+            self.symbol_guard_counter[r] = 0
+
+        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
+            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
+                source_name
+            ] = r
+        return r
+
+    def add_var_to_val(self, expr: sympy.Symbol, val: int) -> None:
+        """Adds a new symbol to the symbolic environment."""
+        log.debug("add_var_to_val %s %s", expr, val, stack_info=True)
+        assert expr not in self.var_to_val, f"{expr} already exists"
+        self.var_to_val[expr] = sympy.Integer(val)
+
+    def _debug_name(self, source: Source) -> str:
+        src_name = source.name
+        return self.source_name_to_debug_name.get(src_name, src_name)
+
+    def _render_range_for_constraint_violation(
+        self, source: Source, c: Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]
+    ) -> str:
+        if isinstance(c, StrictMinMaxConstraint):
+            lower, upper = c.vr.lower, c.vr.upper
+            default = self._default_value_range()
+            if lower <= default.lower:
+                lower = None
+            if upper >= default.upper:
+                upper = None
+            c_render = (
+                f"{self._debug_name(source)} = {source.name} in the specified range"
+            )
+            if lower is not None and upper is not None:
+                c_render += f" {lower} <= {self._debug_name(source)} <= {upper}"
+            elif lower is None and upper is not None:
+                c_render += f" {self._debug_name(source)} <= {upper}"
+            elif lower is not None and upper is None:
+                c_render += f" {lower} <= {self._debug_name(source)}"
+            return c_render
+        return c.render(source)
+
+    def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]:
+        """
+        Like produce_guards_verbose, but only returns the non-verbose python guard expressions
+        (no verbose guards produced.)
+        """
+        return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
+
+    def produce_guards_verbose(
+        self,
+        placeholders: Sequence[FakeTensor],
+        sources: Sequence[Source],
+        source_ref: Callable[[Source], str] = lambda n: n.name,
+        *,
+        guards: Optional[list[ShapeGuard]] = None,
+        input_contexts: Optional[DimList[SymbolicContext]] = None,
+        # Encodes user-specified input shape equations of the form s = s' and s = fn(s').
+        # (See docs on EqualityConstraint for details of the encoding.)
+        equalities_inputs: Optional[EqualityConstraint] = None,
+        _simplified: bool = False,
+        # Indicates if we should produce guards for known static values.
+        ignore_static: bool = True,
+        langs: tuple[str, ...] = ("python", "verbose_python"),
+    ) -> list[_ShapeGuardsHelper]:
+        """
+        Generates a list of guards strings which, when evaluated in a context that
+        defines tensors for all the sources, returns True or False depending
+        on if the guards in the list evaluated to True or not.  Primarily used by Dynamo,
+        but this is also helpful for manual testing of guards (see
+        evaluate_guards_for_args)
+
+        For convenience in testing, a source is allowed to be a str,
+        in which case we will assume it is a LocalSource
+
+        simplified lets you omit duck sizing, equality and 0/1 guards.
+        This is useful for testing when you don't care about the boilerplate
+        guards, and it may be helpful for user output too (be careful though;
+        some equality guards are nontrivial!  It would be nice to get simplified
+        output to print them too).  It's private because it's not
+        intended for normal use
+
+        Returns guards in python and python with verbose comments (verbose) by
+        default.
+        """
+        self.log.info("produce_guards")
+
+        # Check if we get to the same ShapeEnv state by replaying the recorded events.
+        # This will create a new ShapeEnv instance, and call all recorded function
+        # calls on this new instance. Finally, it will check whether this new instance
+        # has equal state.
+        #
+        # It's important that we do it in the beginning of this function, since it modifies
+        # self.dim_constraints through its execution. Changes that happen in this method
+        # aren't interesting, since this is the function call we wish to reproduce at the
+        # end. If we wish to simply reproduce ShapeEnv instances even after this call,
+        # this method should also be recorded.
+        if self.check_recorded_events:
+            shape_env = replay_shape_env_events(self.events)
+            self.check_equal(shape_env)
+
+        assert len(placeholders) == len(sources), (
+            f"len({placeholders}) != len({sources})"
+        )
+        Tensorlike = (torch.Tensor, FakeTensorMeta)
+
+        def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext:
+            return StatelessSymbolicContext(
+                # Ignored; only the constraints part is relevant below.
+                dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
+                dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(),
+                constraint_sizes=[None] * t.dim(),
+                constraint_strides=[None] * t.dim(),
+            )
+
+        # Expand optional inputs, or verify invariants are upheld
+        if input_contexts is None:
+            # pyrefly: ignore [bad-assignment]
+            input_contexts = [
+                # pyrefly: ignore [bad-argument-type]
+                _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
+                for t in placeholders
+            ]
+        else:
+            assert len(input_contexts) == len(placeholders)
+
+            for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
+                if isinstance(t, Tensorlike):
+                    if context is None:
+                        # pyrefly: ignore [bad-argument-type]
+                        input_contexts[i] = _create_no_constraints_context(t)
+                else:
+                    assert isinstance(t, (SymInt, int, SymFloat, float))
+                    assert not isinstance(context, list)
+
+        # It took a lot of sweat to figure out the algorithm here.  Let's
+        # explain how it works.
+        #
+        # The ShapeEnv lifecycle looks something like this:
+        #
+        # - For each input, you either generate a fresh Sympy symbol (s0) to
+        #   represent its value (a binding site), or you reuse some
+        #   preexisting symbol or expression, skipping the symbol allocation
+        #   (e.g., duck sizing to a preexisting symbol, or expressing a
+        #   stride as a multiplication of a separate stride and size.)
+        #   Naively, you might expect to bind a fresh Sympy symbol for
+        #   every input, but this is fairly wasteful as most of these
+        #   symbols immediately simplify away, and if you don't eagerly
+        #   specialize, e.g., 0/1 symbols, you end up with very complicated
+        #   expressions that are not optimizable in practice.
+        #
+        # - You perform some compute on these symbols, occasionally
+        #   introducing guards on boolean expressions on these symbols.
+        #   In particular, whenever we guard on equality (_maybe_guard_rel),
+        #   we can simplify shapes; e.g., when s0 == s1 * 2, we can now
+        #   replace all occurrences of s0 with s1 * 2.  Sometimes, a
+        #   boolean expression evaluation doesn't introduce a guard, as
+        #   the guard is already entailed by the simplifications we have
+        #   applied.
+        #
+        # - In the end, you have a bunch of replacements (saying how to
+        #   simplify shapes) and a bunch of guards (all the equality guards
+        #   are trivial, because they're covered by the replacements).
+        #
+        # From the ShapeEnv, we must generate a Python expression that, when
+        # evaluated on a set of inputs, tells us whether or not these boolean
+        # expressions would have evaluated in the same way.  However,
+        # we cannot easily compute this, as we elide recording boolean
+        # expressions when we think they are vacuously true.  Thus, we seek
+        # an approximation: we must generate an expression, if true, would have
+        # produced an "equivalent" ShapeEnv, which would answer guard
+        # expressions in the same way.
+        #
+        # Our notion of equivalence is a bit subtle.  For example, consider
+        # the ShapeEnv created from an input of size (5, 4) versus (4, 4)
+        # (no other guards.)  Duck sizing would generate (s0, s1) in the first
+        # case but (s0, s0) in the second.  We do NOT assume that size
+        # variables are disjoint; so in fact a graph that assumes the input
+        # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not
+        # vice versa.  However, consider an analogous case (1,) versus (2,).
+        # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT
+        # subsume the (1,) graph because we assume that any size variables
+        # is NOT 0/1 (and make simplifications according to this; e.g., if
+        # we queried s0 == 0, we would immediately return False without
+        # returning a guard.)
+        #
+        # So, it is perhaps easier to flip things on their head: the guard
+        # expressions we generate here say what simplifications are valid,
+        # and what are not. Below, we explain each of the guard expressions
+        # we generate
+
+        # TODO: Make this more efficient by binding all the size/stride/offsets
+        # to locals before performing tests on them.
+
+        from torch._dynamo.source import TensorProperty, TensorPropertySource
+
+        # Actual codegen must be delayed as we don't necessarily know what
+        # the symbol mapping is
+        input_guards = []
+
+        symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
+            list
+        )
+        symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = (
+            collections.defaultdict(set)
+        )
+        constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []
+
+        printers: list[_ShapeGuardPrinter] = []
+        py_printer = ShapeGuardPythonPrinter(
+            symbol_to_source, source_ref, self.var_to_sources
+        )
+        for lang in langs:
+            if lang in ["python", "verbose_python"]:
+                printers.append(py_printer)
+            elif lang == "cpp":
+                printers.append(
+                    _ShapeGuardCppPrinter(
+                        symbol_to_source, source_ref, self.var_to_sources
+                    )
+                )
+            else:
+                raise NotImplementedError(f"Unknown lang: {lang}")
+
+        def record_constraint_violation(
+            warn_only: bool,
+            debug_name: str,
+            msg: str,
+            hint: Optional[Callable[[], str]] = None,
+        ) -> None:
+            constraint_violations.append(
+                (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg)
+            )
+
+        def is_dim(src: object) -> TypeGuard[TensorPropertySource]:
+            return (
+                isinstance(src, TensorPropertySource)
+                and src.prop is TensorProperty.SIZE
+            )
+
+        if equalities_inputs:
+            source_index = {}
+            for i, src in enumerate(sources):
+                source_index[src.name] = i
+
+            def get_expression(tensor_dim_src: Source) -> sympy.Expr:
+                fake = placeholders[source_index[tensor_dim_src.base.name]]  # type: ignore[attr-defined]
+                assert tensor_dim_src.idx is not None  # type: ignore[attr-defined]
+                symint = fake.shape[tensor_dim_src.idx]  # type: ignore[attr-defined]
+                if isinstance(symint, torch.SymInt):
+                    return symint.node.expr
+                else:
+                    assert type(symint) is int, f"Expected int, got {type(symint)}"
+                    return sympy.Integer(symint)
+
+            for src1, src2 in equalities_inputs.source_pairs:
+                expr1, expr2 = get_expression(src1), get_expression(src2)  # type: ignore[]
+                # Check whether given input shape values satisfy a specified equation s = s'.
+                # - Raise when the equation was violated by the given input shape values.
+                # - Otherwise issue a guard to constrain them.
+                concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
+                if not concrete_val:
+                    raise ConstraintViolationError(
+                        f"{src1.name} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}"
+                        " is not equal to "
+                        f"{src2.name} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}"
+                    )
+
+            for srcEq, root, fn in equalities_inputs.derived_equalities:
+                expr1 = get_expression(srcEq)
+                # recall that root is either a phantom symbol or an input source
+                if isinstance(root, sympy.Symbol):
+                    expr2, debug_name = root, self.var_to_sources[root][0].name
+                elif isinstance(root, sympy.Integer):
+                    expr2, debug_name = root, str(root)
+                else:
+                    expr2, debug_name = get_expression(root), self._debug_name(root)
+                expr2_ = fn(expr2)
+                # Check whether given input shape values satisfy a specified equation s = fn(s').
+                # - Raise when the equation was violated by the given input shape values.
+                # - Otherwise issue a guard to constrain them.
+                concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
+                if not concrete_val:
+                    raise ConstraintViolationError(
+                        f"Expected input {srcEq.name} to be equal to "
+                        f"{fn(sympy.Symbol(debug_name))}, "
+                        f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, "
+                        f"but got {expr1.xreplace(self.var_to_val)}"
+                    )
+
+            for phantom_symbol in equalities_inputs.phantom_symbols:
+                if isinstance(phantom_symbol, sympy.Symbol):
+                    # we created additional phantom symbols that are not input shape dimensions
+                    symbol_to_source[phantom_symbol].extend(
+                        self.var_to_sources[phantom_symbol]
+                    )
+
+        # How do we know what the value of s0 is?  Fresh variables can only be
+        # bound by inputs, so there MUST be some other input which binds the
+        # variable.  If there is no such input, this is an error in our
+        # system.  We record where all symbols come from, to help you diagnose
+        # why those symbols didn't occur.
+        #
+        # In fact, generally speaking it is only possible for the "outermost"
+        # user of a ShapeEnv to evaluate the guards, because some inputs may
+        # not be available to inner levels.  For example, Dynamo can guard on
+        # tensors that never actually become graph arguments (they are
+        # pruned).  In this case, only Dynamo knows about these arguments.
+        def track_symint(
+            source: Source, val: IntLikeType, constraint: DimConstraint = None
+        ) -> None:
+            log.debug(
+                "track_symint %s %s %s",
+                LazyString(lambda: source.name),
+                val,
+                constraint,
+            )
+            assert not isinstance(val, SymInt) or is_symbolic(val)
+
+            if isinstance(val, SymInt) and val.node.maybe_as_int() is not None:
+                val = val.node.maybe_as_int()
+
+            if isinstance(val, SymInt):
+                s = val.node.expr
+                if isinstance(s, sympy.Symbol):
+                    symbol_to_source[s].append(source)
+                    if constraint is not None and not isinstance(
+                        constraint, RelaxedUnspecConstraint
+                    ):
+                        symbol_to_constraints[s].add(constraint)
+                else:
+                    constraint_violated = False
+                    if isinstance(constraint, StrictMinMaxConstraint):
+                        # try inferring the ranges of the expr s
+                        sym_vrs = {
+                            x: self.var_to_range.get(x, None) for x in s.free_symbols
+                        }
+                        if any(vr is None for vr in sym_vrs.values()):
+                            # some of the free symbols in s don't have ranges
+                            constraint_violated = True
+                    elif isinstance(constraint, RelaxedUnspecConstraint):
+                        if s.is_number:
+                            i = int(s)
+                            # Don't complain about 0/1 specialization, we
+                            # expect to have to compile in this case anyway
+                            if i not in (0, 1):
+                                constraint_violated = True
+                    if constraint_violated:
+                        assert constraint is not None
+
+                        def hint(s: sympy.Expr) -> str:
+                            sexpr = py_printer.doprint(s)
+                            return f"{sexpr}."
+
+                        var_with_range = self._render_range_for_constraint_violation(
+                            source, constraint
+                        )
+                        msg = (
+                            f"Not all values of {var_with_range} are valid because "
+                            f"{self._debug_name(source)} was inferred to be equal to "
+                        )
+                        record_constraint_violation(
+                            constraint.warn_only,
+                            self._debug_name(source),
+                            msg,
+                            hint=functools.partial(hint, s),
+                        )
+
+                input_guards.append((source, s))
+            else:
+                s = sympy.Integer(val)
+                input_guards.append((source, s))
+                constraint_violated = False
+                if isinstance(constraint, StrictMinMaxConstraint):
+                    if not (
+                        s == constraint.vr.lower == constraint.vr.upper
+                    ):  # allow static constraints
+                        constraint_violated = True
+                elif isinstance(constraint, RelaxedUnspecConstraint):
+                    # Don't complain about 0/1 specialization, we
+                    # expect to have to compile in this case anyway
+                    if val not in (0, 1):
+                        constraint_violated = True
+                if constraint_violated:
+                    assert constraint is not None
+                    var_with_range = self._render_range_for_constraint_violation(
+                        source, constraint
+                    )
+                    user_stack = self.specialization_stacks.get(source, None)
+                    msg = (
+                        f"You marked {self._debug_name(source)} as dynamic but your code "
+                        f"specialized it to be a constant ({val}). If you're using mark_dynamic, "
+                        f"either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, "
+                        f"replace it with either Dim.STATIC or Dim.AUTO."
+                        + (
+                            "\n\nUser stack:\n" + "".join(user_stack.format())
+                            if user_stack
+                            else ""
+                        )
+                    )
+                    record_constraint_violation(
+                        constraint.warn_only, self._debug_name(source), msg
+                    )
+
+        def track_symfloat(source: Source, val: FloatLikeType) -> None:
+            log.debug("track_symfloat %s %s", LazyString(lambda: source.name), val)
+            assert not isinstance(val, SymFloat) or is_symbolic(val)
+
+            if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None:
+                val = val.node.maybe_as_float()
+
+            if isinstance(val, SymFloat):
+                s = val.node.expr
+                if isinstance(s, sympy.Symbol):
+                    symbol_to_source[s].append(source)
+                input_guards.append((source, s))
+            else:
+                s = sympy.Float(val)
+                input_guards.append((source, s))
+
+        # pyrefly: ignore [no-matching-overload]
+        for t, source, context in zip(placeholders, sources, input_contexts):
+            if isinstance(source, str):
+                from torch._dynamo.source import LocalSource
+
+                source = LocalSource(source)
+            assert isinstance(source, Source)
+            if t is None:
+                continue
+            if isinstance(t, (SymInt, int)):
+                constraint = (
+                    None if context is None else getattr(context, "constraint", None)
+                )
+                track_symint(source, t, constraint)
+                continue
+            elif isinstance(t, (SymFloat, float)):
+                track_symfloat(source, t)
+                continue
+            assert isinstance(t, Tensorlike)
+            if is_traceable_wrapper_subclass(t):
+                from torch._dynamo.source import AttrSource
+
+                assert isinstance(context, SubclassSymbolicContext)
+
+                # For subclasses, we need to track symints on BOTH the outer
+                # and inner tensors.
+                # TODO: type this better
+                sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [
+                    (source, t, context.constraint_sizes, context.constraint_strides)
+                ]
+                attrs, _ = t.__tensor_flatten__()
+                for attr in attrs:
+                    inner_t = getattr(t, attr)
+                    inner_context = context.inner_contexts[attr]
+                    sources_tensors_constraints.append(
+                        (
+                            AttrSource(source, attr),
+                            inner_t,
+                            inner_context.constraint_sizes,  # type: ignore[attr-defined]
+                            inner_context.constraint_strides,  # type: ignore[attr-defined]
+                        )
+                    )
+            else:
+                sources_tensors_constraints = [
+                    (source, t, context.constraint_sizes, context.constraint_strides)  # type: ignore[attr-defined]
+                ]
+
+            for (
+                src,
+                curr_t,
+                constraint_size,
+                constraint_stride,
+            ) in sources_tensors_constraints:
+                if is_sparse_any(curr_t):
+                    for i, ss in enumerate(curr_t.size()):
+                        property_source = TensorPropertySource(
+                            src, TensorProperty.SIZE, i
+                        )
+                        track_symint(property_source, ss, constraint_size[i])
+                else:
+                    for i, ss in enumerate(curr_t.size()):
+                        property_source = TensorPropertySource(
+                            src, TensorProperty.SIZE, i
+                        )
+                        track_symint(property_source, ss, constraint_size[i])
+
+                    for i, ss in enumerate(curr_t.stride()):
+                        property_source = TensorPropertySource(
+                            src, TensorProperty.STRIDE, i
+                        )
+                        track_symint(property_source, ss, constraint_stride[i])
+                    track_symint(
+                        TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
+                        curr_t.storage_offset(),
+                    )
+
+        # 1. Every input must equal the final simplified symbolic expression
+        #    stored on the placeholder.  Given a placeholder (s0*2, s1),
+        #    if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
+        #    This does a lot of work: it covers duck sizing and equality guards.
+        all_exprs: list[list[str]] = [[] for _ in langs]
+
+        self.dim_constraints = DimConstraints(
+            symbol_to_source,
+            self.var_to_val,
+            set(symbol_to_constraints.keys()),
+            self.source_name_to_debug_name,
+        )
+
+        if not _simplified:
+            for source, expr in input_guards:
+                srcname = source.name
+                if self._translation_validation_enabled:
+                    # Ignore sources that were not turned into SymInts.
+                    if srcname in self.source_to_symbol:
+                        self._add_target_expr(
+                            sympy.Eq(self.source_to_symbol[srcname], expr)
+                        )
+
+                # Small optimization
+                if (
+                    isinstance(expr, sympy.Symbol)
+                    and symbol_to_source.get(expr)
+                    and source == symbol_to_source[expr][0]
+                ):
+                    continue
+
+                # This logic excludes static values found on tensors from guarding, because
+                # dynamo's check_tensor_fn does that (see guards.cpp).
+                # However, for non tensor sources, we still need to guard here.
+                if ignore_static and isinstance(source, TensorPropertySource):
+                    if expr.is_number:
+                        self.log.debug(
+                            "Skipping guard %s", f"{source_ref(source)} == {expr}"
+                        )
+                        continue
+
+                if is_dim(source):
+                    self.dim_constraints.add_equality(source, expr)
+
+                for exprs, printer, lang in zip(all_exprs, printers, langs):
+                    res = f"{printer.print_source(source)} == {printer.doprint(expr)}"
+
+                    if lang == "verbose_python":
+                        if (s0 := self.source_to_var.get(srcname)) is not None:
+                            if source != self.var_to_sources[s0][0]:
+                                res = (
+                                    f"{res}  # duck sizing added this equality because these "
+                                    f"variables had the same size {self.var_to_val[s0]} "
+                                    "(to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)"
+                                )
+                            elif (sloc := self.replacements_slocs.get(s0)) is not None:
+                                res = f"{res}  # {sloc}"
+                            else:
+                                res = f"{res}  # (unknown var {s0}, please file a bug)"
+                        else:
+                            res = f"{res}  # (unknown source {srcname}, please file a bug)"
+                    exprs.append(res)
+
+                if (
+                    isinstance(source, TensorPropertySource)
+                    and source.prop is TensorProperty.SIZE
+                    and equalities_inputs
+                    and len(expr.free_symbols) == 1
+                ):
+                    symbol = next(iter(expr.free_symbols))
+                    if (
+                        isinstance(expr, sympy.Symbol)
+                        and expr in symbol_to_constraints
+                        and not equalities_inputs.is_equal(
+                            source, symbol_to_source[expr][0]
+                        )
+                    ):
+                        msg = (
+                            f"The values of {self._debug_name(source)} = {source.name} and "
+                            f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name} "
+                            "must always be equal."
+                        )
+                        record_constraint_violation(
+                            equalities_inputs.warn_only, self._debug_name(source), msg
+                        )
+
+                    if (
+                        not isinstance(expr, sympy.Symbol)
+                        and symbol in symbol_to_constraints
+                        and not equalities_inputs.is_derived(
+                            source,
+                            symbol_to_source[symbol][0],
+                            lambda x: expr.xreplace({symbol: x}),
+                        )
+                    ):
+                        src = symbol_to_source[symbol][0]
+                        msg = (
+                            f"The values of {self._debug_name(source)} = {source.name} must always be related to "
+                            f"the values of {self._debug_name(src)} = {src.name} by "
+                            f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}."
+                        )
+                        record_constraint_violation(
+                            equalities_inputs.warn_only, self._debug_name(source), msg
+                        )
+
+                # NB: Not necessary to report constraint violations here:
+                # constraints are guaranteed to be on symbols (we've already
+                # caught constants and non-atomic expressions), so we only
+                # have relational constraints, but we don't support those
+                # at the moment
+
+        # 2. Every guard must evaluate to True (but remember many guards
+        #    like s0 == s1*2 because trivial due to simplification)
+        issued = set()
+
+        def issue_guard(guard: ShapeGuard) -> None:
+            expr = self.simplify(guard.expr)
+
+            # Avoid re-issuing the same guard.
+            if expr in issued:
+                return
+
+            issued.add(expr)
+
+            try:
+                is_trivial = False
+                if any(
+                    is_dim(source)
+                    for s in expr.free_symbols
+                    for source in symbol_to_source[s]
+                ):
+                    assert self.dim_constraints is not None
+                    is_trivial = self.dim_constraints.add(expr)
+
+                for exprs, printer, lang in zip(all_exprs, printers, langs):
+                    guard_expr = printer.doprint(expr)
+                    if lang == "verbose_python":
+                        guard_expr = f"{guard_expr}  # {guard.sloc}"
+                    exprs.append(guard_expr)
+
+                self._add_target_expr(expr)
+                # A non-relational constraint on a single sizevar can violate
+                # a constraint
+                if not is_trivial and len(expr.free_symbols) == 1:
+                    symbol = next(iter(expr.free_symbols))
+                    source = symbol_to_source[symbol][0]
+                    constraints = symbol_to_constraints[symbol]
+                    for c in constraints:
+                        if isinstance(c, StrictMinMaxConstraint):
+                            var_with_range = (
+                                self._render_range_for_constraint_violation(source, c)
+                            )
+                            msg = (
+                                f"Not all values of {var_with_range} "
+                                f"satisfy the generated guard {py_printer.doprint(expr)}."
+                            )
+                            record_constraint_violation(
+                                c.warn_only, self._debug_name(source), msg
+                            )
+                        elif isinstance(c, RelaxedUnspecConstraint):
+                            # This is fine, we allow guards here as long as it
+                            # didn't constrain it to one value  (we don't
+                            # actually know this; this depends on our
+                            # ValueRanges reasoning capability)
+                            pass
+                        else:
+                            raise AssertionError(f"unrecognized constraint {c}")
+            except Exception:
+                self.log.warning("Failing guard allocated at %s", guard.sloc)
+                raise
+
+        # First, issue all guards.
+        # This removes all the checks that follow from bounds
+        # We could simply emit those and also the bounds 2 <= size when necessary
+        for guard in guards if guards is not None else self.guards:
+            if (
+                self._maybe_evaluate_static(
+                    guard.expr, axioms=(), size_oblivious=guard.size_oblivious
+                )
+                is not None
+            ):
+                continue
+
+            issue_guard(guard)
+
+        # Because there are guards that export's constraint solver can suggest good fixes for, that we may have
+        # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards),
+        # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts,
+        # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide
+        # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph).
+        for ra in self.deferred_runtime_asserts.get(None, []):
+            if self._maybe_evaluate_static(ra.expr, axioms=()) is not None:
+                continue
+            expr = self.simplify(ra.expr)
+
+            self.dim_constraints.add(expr)
+
+        # 3. Every symbol must be within its value range (this handles 0/1
+        # specialization too).
+        for symbol, sources in symbol_to_source.items():
+            r = self.var_to_range.get(symbol)
+            if r is None:
+                continue
+            vr_sloc = self.var_to_range_sloc[symbol]
+
+            assert sources
+            bounds = []
+            rf = source_ref(sources[0])
+            verbose_expr = ""
+            if r.lower not in (-sympy.oo, -int_oo):
+                if any(is_dim(source) for source in sources):
+                    self.dim_constraints.add(sympy.Ge(symbol, r.lower))
+                # Only print lower bound in simplified mode if it is not the
+                # default
+                if not _simplified or r.lower != self._default_value_range().lower:
+                    bounds.append(sympy.Le(r.lower, symbol, evaluate=False))
+                verbose_expr = f"{r.lower} <= {rf}  # {vr_sloc.lower}"
+            if r.upper not in (sympy.oo, int_oo):
+                if any(is_dim(source) for source in sources):
+                    self.dim_constraints.add(sympy.Le(symbol, r.upper))
+                # nontrivial upper bound is always interesting
+                bounds.append(sympy.Le(symbol, r.upper, evaluate=False))
+                if verbose_expr:
+                    verbose_expr = f"{r.lower} <= {rf} <= {r.upper}  # {vr_sloc.lower} and {vr_sloc.upper}"
+                else:
+                    verbose_expr = f"{rf} <= {r.upper}  # {vr_sloc.upper}"
+            if bounds:
+                bound = sympy.And(*bounds, evaluate=False)
+
+                for exprs, printer, lang in zip(all_exprs, printers, langs):
+                    if lang == "verbose_python":
+                        exprs.append(verbose_expr)
+                    else:
+                        exprs.append(printer.doprint(bound))
+                # NB: verbose_exprs are done above
+
+                # Check constraints
+                constraints = symbol_to_constraints[symbol]
+                for c in constraints:
+                    if isinstance(c, StrictMinMaxConstraint):
+                        # TODO: With int_oo, I think this condition is a noop
+                        # now
+                        if not (c.vr & self._default_value_range()).issubset(r):
+                            source = sources[0]
+
+                            expr = sympy.And(
+                                sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper)
+                            )
+                            guard_expr = py_printer.doprint(expr)
+                            var_with_range = (
+                                self._render_range_for_constraint_violation(source, c)
+                            )
+                            msg = f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}"
+                            record_constraint_violation(
+                                c.warn_only,
+                                self._debug_name(source),
+                                msg,
+                            )
+            # We NaN specialize, which means similar to 0/1 specialization we
+            # should assume that the float is NOT nan.  This is load bearing
+            # if you have something like an equality guard, nan will play
+            # merry hell with the reasoning.
+            if symbol_is_type(symbol, SymT.FLOAT):
+                res = f"not math.isnan({py_printer.print_source(sources[0])})"
+                for exprs, printer, lang in zip(all_exprs, printers, langs):
+                    if lang == "verbose_python":
+                        exprs.append(
+                            f"{res}  # implicit guard for float input due to NaN specialization in the framework"
+                        )
+                    elif lang == "python":
+                        exprs.append(res)
+                    elif lang == "cpp":
+                        exprs.append(f"~std::isnan({printer.print_source(sources[0])})")
+                    else:
+                        raise NotImplementedError(f"Unimplemented for lang: {lang}")
+
+        if constraint_violations:
+            warn_msgs: list[str] = []
+            error_msgs: list[str] = []
+            debug_names = set()
+            for warn_only, debug_name, msg_cb in constraint_violations:
+                if warn_only:
+                    str_msg = f"  {len(warn_msgs) + 1}. {msg_cb()}"
+                    warn_msgs.append(str_msg)
+                else:
+                    str_msg = f"  - {msg_cb()}"
+                    error_msgs.append(str_msg)
+                    # pyrefly: ignore [bad-argument-type]
+                    debug_names.add(debug_name)
+            if len(error_msgs) > 0:
+                debug_names_str = ", ".join(sorted(debug_names))
+                err = "\n".join(error_msgs)
+                raise ConstraintViolationError(
+                    f"Constraints violated ({debug_names_str})! "
+                    'For more information, run with TORCH_LOGS="+dynamic".\n'
+                    f"{err}"
+                )
+            elif len(warn_msgs) > 0:
+                log.debug("%s Warning only constraints violated", len(warn_msgs))
+
+        signpost_event(
+            "dynamic",
+            "produce_guards",
+            {
+                **self.co_fields,
+                **self.counter,
+                "num_guards": len(all_exprs[0]),
+                "free_symbols": sum(1 for v in symbol_to_source.values() if v),
+                # The keys are meaningless from an aggregate perspective, so
+                # don't include them.  Biggest first.
+                "symbol_guard_counts": sorted(
+                    self.symbol_guard_counter.values(), reverse=True
+                ),
+            },
+        )
+
+        if self._translation_validation_enabled:
+            from torch.fx.experimental.validator import PopulateValidator
+
+            # Add all deferred runtime assertions; these are not technically
+            # handled by produce_guards but we need to put them in the target
+            # set
+            for ras in self.deferred_runtime_asserts.values():
+                for ra in ras:
+                    self._add_target_expr(ra.expr)
+
+            # Add value range bound guards for all symbols with no trivial bounds.
+            # Reason: '_maybe_evaluate_static' may eliminate guards based on the
+            # refined value ranges.
+            for sym, vr in self.var_to_range.items():
+                if vr.lower not in (-sympy.oo, -int_oo):
+                    self._add_target_expr(sympy.Le(vr.lower, sym))
+                if vr.upper not in (sympy.oo, int_oo):
+                    self._add_target_expr(sympy.Le(sym, vr.upper))
+
+            # Before validating, populate the input of the validator with the
+            # built FX graph.
+            with fx_traceback.preserve_node_meta():
+                PopulateValidator(self.graph, self.validator).run()
+
+        # Only run translation validation when we are not passing custom guards
+        if guards is None:
+            self._check_translation_validate()
+
+        helpers: list[_ShapeGuardsHelper] = []
+        for exprs, printer, lang in zip(all_exprs, printers, langs):
+            if lang == "cpp":
+                assert isinstance(printer, _ShapeGuardCppPrinter)
+                helpers.append(_CppShapeGuardsHelper(exprs, printer.source_to_symbol))
+            else:
+                helpers.append(_ShapeGuardsHelper(exprs))
+        return helpers
+
+    def produce_guards_expression(
+        self,
+        placeholders: Sequence[Union[SymInt, FakeTensor]],
+        *,
+        guards: Optional[list[ShapeGuard]] = None,
+        ignore_static: bool = True,
+    ) -> Optional[str]:
+        """
+        Expected to be used with evaluate_guards_expression(). Produces the guards
+        for the given placeholders and returns a string expression to be evaluated
+        by evaluate_guards_expression given concrete values for the placeholders.
+        """
+        from torch._dynamo.source import LocalSource
+
+        arg_names = [f"t{i}" for i in range(len(placeholders))]
+        produced_guards = self.produce_guards(
+            placeholders,
+            [LocalSource(a) for a in arg_names],
+            guards=guards,
+            ignore_static=ignore_static,
+        )
+        if produced_guards:
+            return " and ".join(produced_guards)
+        return None
+
+    def evaluate_symexpr(self, code: str) -> Union[int, float, bool]:
+        """
+        To be used by compile_fx to evaluate symexprs
+        """
+        args = {str(e): val for e, val in self.var_to_val.items()}
+        return eval(code, SYMPY_INTERP, args)
+
+    def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]:
+        """
+        To be used by compile_fx to deserialize symexprs
+        """
+        args = {
+            str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None))
+            for e, val in self.var_to_val.items()
+        }
+        return eval(code, SYMPY_INTERP, args)
+
+    def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool:
+        """
+        Expected to be used with produce_guards_expression(). Evaluates an expression
+        generated by produce_guards_expression for the given concrete args.
+        """
+        arg_names = [f"t{i}" for i in range(len(args))]
+        return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
+
+    def evaluate_guards_for_args(
+        self,
+        placeholders: Sequence[FakeTensor],
+        args: Sequence[Tensor],
+        *,
+        ignore_static: bool = True,
+    ) -> bool:
+        """Generate guards for a graph's placeholder values and evaluate the guards with args"""
+        code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
+        if code:
+            return self.evaluate_guards_expression(code, args)
+        return True
+
+    def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]:
+        """
+        Get a list of guards, but pruned so it only provides guards that
+        reference symints from the passed in input
+        """
+        # pyrefly: ignore [bad-assignment]
+        symints = {
+            s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
+        }
+        guards = [
+            g for g in self.guards if all(s in symints for s in g.expr.free_symbols)
+        ]
+        return guards
+
+    def bind_symbols(
+        self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor]
+    ) -> dict[sympy.Symbol, int]:
+        """
+        Given a paired list of placeholders (fake tensors with
+        symbolic sizes) and concrete arguments (regular tensors
+        with real sizes), returns a dictionary mapping each
+        symbol to its real value.  So for example, if you
+        have a placeholder with size (s0, s1), binding
+        (2, 4) to it will give you {s0: 2, s1: 4}.  This is
+        not guaranteed to bind ALL symbols in the ShapeEnv;
+        we can't bind a symbol if it doesn't occur in any placeholder,
+        and symbols that already have replacements won't get bindings.
+
+        This is a little duplicative with evaluate_guards but
+        it's different enough that it seemed cleanest to make
+        another copy.  This assumes the guards are already checked,
+        though if it's cheap we'll check for shenanigans
+        """
+        bindings: dict[sympy.Symbol, int] = {}
+
+        def bind_symint(arg: object, val: object) -> None:
+            if isinstance(val, SymInt):
+                assert isinstance(arg, int)
+                s = val.node.expr
+
+                if isinstance(s, sympy.Symbol):
+                    if s in bindings:
+                        assert bindings[s] == arg, f"{bindings[s]} != {arg}"
+                    else:
+                        bindings[s] = arg
+                elif isinstance(-s, sympy.Symbol):
+                    if -s in bindings:
+                        assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}"
+                    else:
+                        bindings[-s] = -arg
+
+        for t, arg in zip(placeholders, args):
+            if t is None:
+                continue
+            if isinstance(t, SymInt):
+                bind_symint(arg, t)
+                continue
+            assert isinstance(t, torch.Tensor)
+            for i, s in enumerate(t.size()):
+                bind_symint(arg.size(i), s)
+            for i, s in enumerate(t.stride()):
+                bind_symint(arg.stride(i), s)
+            bind_symint(arg.storage_offset(), t.storage_offset())
+
+        return bindings
+
+    def get_nontrivial_guards(self) -> list[SympyBoolean]:
+        """Returns a list of guard expressions that aren't statically known (i.e. not trivial)"""
+        return [
+            self.simplify(guard.expr)
+            for guard in self.guards
+            if self._maybe_evaluate_static(
+                guard.expr, axioms=(), size_oblivious=guard.size_oblivious
+            )
+            is None
+        ]
+
+    def format_guards(self, verbose: bool = False) -> str:
+        """Format this shape env's guard expressions with optional traceback info if verbose"""
+
+        return "\n".join(
+            f" - {guard.expr}{' ' + str(guard.sloc) if verbose else ''}"
+            for guard in self.guards
+        )
+
+    def bound_sympy(
+        self, expr: sympy.Expr, size_oblivious: bool = False
+    ) -> ValueRanges:
+        """Given a sympy expression, computes a ValueRanges bound for what values it can be"""
+        # TODO: maybe it's guaranteed x in is var_to_range?
+        var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
+        if size_oblivious:
+            # Clamp values of size-like variables
+            # NB: discarding the old upper bound in intentional, per
+            # https://github.com/pytorch/pytorch/pull/123675
+            for x in self.size_like & var_to_range.keys():
+                if var_to_range[x] is not None:
+                    # NB: do NOT set upper to 2 ** 48, we're using this solely
+                    # to determine if we can do size-like replacement, the
+                    # upper bound is irrelevant here
+                    var_to_range[x] = ValueRanges(2, int_oo)
+        return bound_sympy(expr, var_to_range)  # type: ignore[arg-type]
+
+    @_lru_cache
+    def get_axioms(
+        self,
+        symbols: Optional[tuple[sympy.Symbol]] = None,
+        compute_hint: bool = False,
+    ) -> tuple[SympyBoolean, ...]:
+        """
+        Given the symbols in an expression, it returns all the runtime asserts that have those symbols
+        concatenated with all the guards.
+        If symbols is None, it returns all the runtime asserts (and all the guards)
+        """
+        if symbols is None:
+            runtime_asserts = (
+                r.expr for rs in self.deferred_runtime_asserts.values() for r in rs
+            )
+        else:
+            runtime_asserts = (
+                r.expr
+                for s in symbols
+                if s not in self.var_to_val
+                for r in self.deferred_runtime_asserts.get(s, ())
+            )
+        guards: Iterator[SympyBoolean] = (g.expr for g in self.guards)
+        axioms: Iterator[SympyBoolean] = itertools.chain(guards, runtime_asserts)
+        if compute_hint:
+            axioms = (
+                canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms
+            )
+        return tuple(dict.fromkeys(axioms).keys())
+
+    @lru_cache(None)
+    def get_implications(
+        self, e: SympyBoolean
+    ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
+        """Given a expression, it returns a list of predicates that follow from it"""
+        equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}
+
+        def add_expr(expr: SympyBoolean) -> None:
+            expr = canonicalize_bool_expr(expr)
+            if isinstance(expr, (sympy.Eq, sympy.Ne)):
+                # No need to canonicalize
+                # TODO We could further canonicalize Eq ordering the lhs and rhs somehow
+                # With this, we could remove the need for the commutativity part
+                opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne
+                # Commutativity of == and !=
+                equiv[type(expr)(expr.lhs, expr.rhs, evaluate=False)] = sympy.true
+                equiv[type(expr)(expr.rhs, expr.lhs, evaluate=False)] = sympy.true
+                equiv[opposite(expr.lhs, expr.rhs, evaluate=False)] = sympy.false
+                equiv[opposite(expr.rhs, expr.lhs, evaluate=False)] = sympy.false
+            else:
+                # Expr and negation
+                equiv[expr] = sympy.true
+                # we do not pass evaluate=False like others on purpose here!
+                # we want not(a=b and not ~(a Optional[sympy.Basic]:
+        """
+        Tries to evaluate expr without introducing guards
+
+        If unbacked_only == True, then we only do substitutions on
+        unbacked SymInts (leaving regular hinted integers alone).  This could
+        result in an expression that still contains backed SymInts, which you
+        could then potentially guard on.
+
+        Use compute_hint == True if you are trying to compute a non-binding
+        hint for the particular hint values of backed and unbacked SymInts,
+        e.g., if s0 happens to be 3 this run, compute_hint will substitute s0 with 3.
+        """
+
+        # axioms with compute hint NYE
+        assert not compute_hint or not axioms
+        expr = self.simplify(expr, size_oblivious)
+
+        if compute_hint:
+            expr = expr.xreplace(self.var_to_val).xreplace(self.unbacked_var_to_val)
+
+        expr = canonicalize_bool_expr(expr)
+
+        def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
+            if not self._resimplify_floor_div_axioms:
+                return
+            self._resimplify_floor_div_axioms = False
+            new_items = {}
+            for k, v in list(axioms.items()):
+                # A FloorDiv in implications could have became CleanDiv at this point, due to new facts
+                # to the shapeEnv. This handles such issue but its not ideal. This is the only expression
+                # simplification that depends on the global state of shape env.
+                # TODO try to get rid of CleanDiv since it breaks the invariant that's simplifications of sympy
+                # expressions only depend on the expression itself.
+                if k.has(FloorDiv):
+                    new_items.update({self.simplify(k): v})
+            axioms.update(new_items)
+
+        # Pattern matching
+        if axioms is None:
+            resimplify_floor_div(self.axioms)
+            subst = self.axioms
+        else:
+            subst = {}
+            for e in axioms:
+                if e.free_symbols.issubset(expr.free_symbols):
+                    subst.update(dict(self.get_implications(self.simplify(e))))
+
+            resimplify_floor_div(subst)
+
+        expr = expr.xreplace(subst)
+        # TODO: compute hint might have gotten broken here
+
+        fs = expr.free_symbols
+
+        if not fs and (expr.is_number or expr.is_Boolean):
+            return expr
+
+        if var_to_range is None:
+            var_ranges = self.var_to_range
+        else:
+            var_ranges = dict(var_to_range)
+
+        symbol_info = tuple(
+            _SymbolInfo(
+                s,
+                var_ranges.get(s),
+                self.var_to_val.get(s),
+                s in self.size_like,
+            )
+            for s in sorted(fs, key=str)  # TODO: speed up sort?
+        )
+
+        r = _maybe_evaluate_static_worker(
+            expr, symbol_info, unbacked_only, size_oblivious
+        )
+        return r
+
+    @_lru_cache
+    def replace(self, expr: _SympyT) -> _SympyT:
+        """
+        Apply symbol replacements to any symbols in the given expression.
+        """
+        replacements = {}
+        # pyrefly: ignore [missing-attribute]
+        for s in expr.free_symbols:
+            r = self._find(s)
+
+            # Micro-optimization: only do replacements if r and s are different
+            # Otherwise, xreplace is not a no-op and will trigger expensive
+            # assumption queries if expr has a relational node.
+            if not r.is_Symbol or r != s:
+                replacements[s] = r
+        if replacements:
+            # pyrefly: ignore [missing-attribute]
+            return safe_expand(expr.xreplace(replacements))
+        else:
+            return expr
+
+    @_lru_cache
+    def _update_divisible(self) -> None:
+        new_divisible = set()
+        for k in self.divisible:
+            res = self.replace(k)
+            if not res.is_number:
+                new_divisible.add(k)
+
+        self.divisible = new_divisible
+        self._update_version_counter()
+
+    @_lru_cache
+    def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT:
+        """Use known constraints and replacements to simplify the given expr"""
+        expr = safe_expand(expr)
+        expr = self.replace(expr)
+
+        # Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced
+        # expression when creating contiguous strides.
+        if not size_oblivious:
+            min_max_replacements = {}
+            for atom in expr.atoms(Max):  # type: ignore[has-type]
+                if len(atom.args) > 2:
+                    continue
+                a, b = atom.args
+                if b == 1 or b == 0:
+                    a, b = b, a
+
+                if a == 1 and self._maybe_evaluate_static(sympy.Ge(b, 1)):
+                    min_max_replacements[atom] = b
+                if a == 0 and self._maybe_evaluate_static(sympy.Ge(b, 0)):
+                    min_max_replacements[atom] = b
+            if min_max_replacements:
+                expr = expr.xreplace(min_max_replacements)
+
+        if expr.has(TruncToInt):
+            trunc_replacements = {}
+            for atom in expr.atoms(TruncToInt):
+                if isinstance(atom.args[0], IntTrueDiv):
+                    base, divisor = atom.args[0].args
+                    if base % divisor == 0:
+                        trunc_replacements[atom] = CleanDiv(base, divisor)
+                    else:
+                        # TruncToInt(IntTrueDiv(a,b)) == FloorDiv(a, b)
+                        trunc_replacements[atom] = FloorDiv(base, divisor)
+            if trunc_replacements:
+                expr = expr.xreplace(trunc_replacements)
+
+        # TODO it would seem that this pass is not necessary given the
+        # below replacement of // with /, but for nested FloorDivs
+        # the non-recursive replacement doesn't work, and
+        # recursive makes it hard to look up divisibility,
+        # because existing divisibility info has FloorDiv in it, not /
+        # for now just do a separate pass to catch common nested case
+        if expr.has(FloorDiv):
+            self._update_divisible()
+            div_replacements = {}
+            for atom in expr.atoms(FloorDiv):
+                base, divisor = atom.args
+                if isinstance(divisor, FloorDiv):
+                    base1, divisor1 = divisor.args
+                    if (
+                        self.replace(Mod(base, divisor)) in self.divisible
+                        and base == base1
+                        and self.replace(Mod(base1, divisor1)) in self.divisible
+                    ):
+                        div_replacements[atom] = divisor1
+            if div_replacements:
+                expr = expr.xreplace(div_replacements)
+                expr = safe_expand(expr)
+        if expr.has(FloorDiv):
+            div_replacements = {}
+            pows = expr.atoms(sympy.Pow)
+            rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
+            for fd in expr.atoms(FloorDiv):
+                base, divisor = fd.args
+                if self.replace(Mod(base, divisor)) in self.divisible:
+                    div_replacements[fd] = CleanDiv(base, divisor)
+            if div_replacements:
+                new_expr = expr.xreplace(div_replacements)
+                new_expr = safe_expand(new_expr)
+                new_pows = new_expr.atoms(sympy.Pow)
+                new_rationals = new_expr.atoms(sympy.Rational).difference(
+                    new_expr.atoms(sympy.Integer)
+                )
+                # divisions simplified away
+                if new_pows.issubset(pows) and new_rationals.issubset(rationals):
+                    expr = new_expr
+        return expr
+
+    # TODO: overload for allow_none literal
+    @lru_cache(256)
+    def size_hint(
+        self, expr: sympy.Basic, *, allow_none: bool = False
+    ) -> Optional[sympy.Basic]:
+        """
+        Gets a size hint for a given expression from the underlying shapes we had.
+        Does not introduce a guard, so only use this when you can guarantee that
+        your code is still valid for arbitrary shapes (such as optimization decisions)
+        """
+        result_expr = safe_expand(expr).xreplace(self.var_to_val)
+        if not result_expr.is_number:
+            from torch.utils._sympy.singleton_int import SingletonInt
+
+            if isinstance(result_expr, SingletonInt):
+                return None
+            r = self._maybe_evaluate_static(result_expr, compute_hint=True)
+            if r is not None:
+                return r
+            if allow_none:
+                return None
+
+            if self.oblivious_var_to_val:
+                # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+                correct_hint = result_expr.xreplace(self.oblivious_var_to_val)
+                counterfactual_hint = result_expr.xreplace(
+                    {k: max(v, 2) for k, v in self.oblivious_var_to_val.items()}
+                )
+                if (
+                    not correct_hint.free_symbols
+                    and not counterfactual_hint.free_symbols
+                ):
+                    if correct_hint == counterfactual_hint:
+                        log.info("oblivious_size hit %s -> %s", expr, correct_hint)
+                        return correct_hint
+                    else:
+                        log.info(
+                            "oblivious_size counterfactual failed %s -> %s != %s",
+                            expr,
+                            correct_hint,
+                            counterfactual_hint,
+                        )
+                else:
+                    log.info(
+                        "oblivious_size miss %s -> %s (counterfactual: %s)",
+                        expr,
+                        correct_hint,
+                        counterfactual_hint,
+                    )
+
+            if self.unbacked_var_to_val:
+                unsound_expr = result_expr.xreplace(self.unbacked_var_to_val)
+                if not unsound_expr.free_symbols:
+                    log.warning(
+                        "propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr
+                    )
+                    trace_structured(
+                        "propagate_real_tensors",
+                        metadata_fn=lambda: {
+                            "expr": repr(expr),
+                            "result": repr(unsound_expr),
+                            "stack": structured.from_traceback(
+                                CapturedTraceback.extract(skip=1).summary()
+                            ),
+                        },
+                    )
+                    self.guard_or_defer_runtime_assert(
+                        sympy.Eq(result_expr, unsound_expr),
+                        f"propagate_real_tensors: {result_expr} == {unsound_expr}",
+                    )
+                    return unsound_expr
+
+            raise self._make_data_dependent_error(result_expr, expr)
+        return result_expr
+
+    # NB: keep in sync with size_hint
+    @lru_cache(256)
+    def has_hint(self, expr: sympy.Expr) -> bool:
+        result_expr = safe_expand(expr).xreplace(self.var_to_val)
+        return (
+            result_expr.is_number
+            or self._maybe_evaluate_static(result_expr) is not None
+        )
+
+    def _make_data_dependent_error(
+        self,
+        expr: sympy.Basic,
+        unhinted_expr: sympy.Basic,
+        *,
+        expr_sym_node_id: Optional[int] = None,
+    ) -> GuardOnDataDependentSymNode:
+        # TODO: in a Dynamo context, having user code, and having the
+        # name of the local, will be much better
+        size_like_symbols = []
+        for s in expr.free_symbols:
+            stacktrace = "".join(self.var_to_stack[s].format())
+            self.log.debug(
+                "Data dependent variable '%s' allocated at:\n%s", s, stacktrace
+            )
+            if s in self.size_like:
+                size_like_symbols.append(s)
+        size_oblivious_result_msg = ""
+        sloc, maybe_extra_debug = self._get_stack_summary(True)
+        if expr.is_integer:  # type: ignore[attr-defined]
+            desc = (
+                "Could not extract specialized integer from data-dependent expression"
+            )
+        else:
+            desc = "Could not guard on data-dependent expression"
+            size_oblivious_result_msg = (
+                "consider using data-dependent friendly APIs such as "
+                "guard_or_false, guard_or_true and statically_known_true."
+            )
+
+        msg = (
+            f"{desc} {expr} (unhinted: {unhinted_expr}).  "
+            f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
+            f"{size_oblivious_result_msg}\n"
+            f"Caused by: {sloc}\n"
+            'For more information, run with TORCH_LOGS="dynamic"\n'
+            "For extended logs when we create symbols, also add "
+            f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n'
+            "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
+            "For more debugging help, see "
+            "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n"
+            + maybe_extra_debug
+            # TODO: Help text about how to use our runtime tests to fix this
+            # problem
+        )
+
+        dtrace_structured(
+            "guard_on_data_dependent_error",
+            metadata_fn=lambda: {
+                "expr": repr(expr),
+                "unhinted_expr": repr(unhinted_expr),
+                "expr_id": self._expr_sym_node_id,
+                "stack": structured.from_traceback(
+                    CapturedTraceback.extract(skip=1).summary()
+                ),
+            },
+        )
+        return GuardOnDataDependentSymNode(expr, msg)
+
+    def _update_var_to_range(
+        self,
+        symbol: sympy.Symbol,
+        vr: ValueRanges,
+        vr_sloc: Optional[ValueRangesSLoc] = None,
+        *,
+        is_constraint: bool = False,
+    ) -> None:
+        lower, upper = vr.lower, vr.upper
+
+        # If we have a size-like unbacked SymInt, refuse to refine the range to be
+        # less than two.  This is because when we intersect this range
+        # with [2, inf] for size oblivious tests, the range would be
+        # unsatisfiable.  In other words, once you have a size-like
+        # unbacked SymInt, we can never learn that it is exactly zero or one,
+        # because we would now give inconsistent results for all size
+        # oblivous tests!
+        if upper < 2 and symbol in self.size_like:
+            vr = ValueRanges(lower, 2)
+
+        # Updates the range and the guards corresponding to each bound of the symbol.
+        if symbol not in self.var_to_range:
+            self.log.debug("_update_var_to_range %s = %s (new)", symbol, vr)
+            self.var_to_range[symbol] = vr
+            if vr_sloc is None:
+                sloc = self._get_sloc()
+                vr_sloc = ValueRangesSLoc(sloc, sloc)
+            self.var_to_range_sloc[symbol] = vr_sloc
+        else:
+            old = self.var_to_range[symbol]
+            new = old & vr
+            if new != old:
+                if vr_sloc is None:
+                    sloc = self._get_sloc()
+                    vr_sloc = ValueRangesSLoc(sloc, sloc)
+                if new.lower != old.lower:
+                    self.var_to_range_sloc[symbol].lower = vr_sloc.lower
+                if new.upper != old.upper:
+                    self.var_to_range_sloc[symbol].upper = vr_sloc.upper
+                self.var_to_range[symbol] = new
+                self.log.debug("_update_var_to_range %s = %s (update)", symbol, new)
+
+        if (v := self.var_to_val.get(symbol)) is not None:
+            r = self.var_to_range[symbol]
+            if v not in r:
+                # For constraint failure, delay this for later
+                # TODO: Rework all of this, the constraint logic is very
+                # duplicative with regular reasoning
+                if not is_constraint:
+                    assert v in r, f"{v} not in {r}"
+
+    def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None:
+        """
+        Adds or updates a replacement for a symbol.
+        Use this instead of `self.replacements[a] = tgt`.
+        """
+
+        if tgt == self.replacements.get(a, None):
+            return
+
+        if a in tgt.free_symbols:
+            return
+
+        # Precondition: a == tgt
+        assert isinstance(a, sympy.Symbol)
+
+        if (
+            self.prefer_deferred_runtime_asserts_over_guards
+            and not _is_supported_equivalence(tgt)
+        ):
+            return  # continuing leads to placeholder shapes having complex expressions that we can't resolve
+
+        # Handles nested tensor symbolic variables which don't have
+        # var_to_range bounds
+        tgt_bound = None
+        if a in self.var_to_range:
+            src_bound = self.var_to_range[a]
+
+            # First, refine the value range of a based on the computed value range
+            # of tgt.  This is always OK to do, even if we decide not to do the
+            # substitution in the end.  This might be a no-op, if a already has
+            # a tighter bound
+            tgt_bound = self.bound_sympy(tgt)
+            self._update_var_to_range(a, tgt_bound)
+
+            # Next, check if we can update the range of free symbols in tgt
+            # based on the range in a. But only do it if:
+            #  - the source bound non-trivially improves over what we get out of
+            #    the existing bounds.
+            #  - the replacement is univariate and we can invert the tgt expression
+            if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
+                b = next(iter(tgt.free_symbols))
+                # Try to invert the equality
+                r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
+                if r is not None:
+                    self.log.debug(
+                        "set_replacement: solve for %s in %s == %s gives %s",
+                        b,
+                        a,
+                        tgt,
+                        r,
+                    )
+                    # The solution here can be non-integral, for example, if
+                    # we have s0 = 2*s1, then s1 = s0/2.  What we would like
+                    # to do is calculated the bounds in arbitrary precision,
+                    # and then requantize the bound to integers when we are
+                    # done.
+                    rat_b_bound = self.bound_sympy(r[1])
+                    b_bound = ValueRanges(
+                        CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)
+                    )
+                    self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
+                    tgt_bound = self.bound_sympy(tgt)
+                    assert tgt_bound.issubset(src_bound), (
+                        f"{tgt_bound=} not a subset of {src_bound=}"
+                    )
+
+            # TODO: Should we propagate size-like-ness?
+            #
+            # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
+            # to become size-like.
+            #
+            # Cons: if u0 is size-like, what about u0 - 1 == u1?  You CAN'T
+            # propagate in this case, because what if u0 == 0, then u1 is negative
+            # and clearly isn't a size.  So, at minimum, any f(x) whose value
+            # range isn't [0, inf] given x in [0, inf] cannot propagate
+            # size-like-ness.  But there are many situations where you could
+            # imagine u1 is going to be size-like and actually you just didn't
+            # have a refined enough value range on u0.  Since even innocuous
+            # looking arithmetic operations can destroy size-like-ness, it's
+            # best to not propagate it at all and force the user to annotate it
+            # as necessary.
+            #
+            # Compromise: we preserve size-like-ness only for exact equality
+            # and nothing else.
+            if a in self.size_like and isinstance(tgt, sympy.Symbol):
+                self.size_like.add(tgt)
+            elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
+                self.size_like.add(a)
+
+            # Now, decide if we will do the substitution.
+            #
+            #  - If the source has a non-trivial range, only substitute if
+            #    we preserve this range.  Note that we may have propagated
+            #    the src_range to free variables in tgt when tgt is univariate
+            #    and we could find an inverse, which helps us achieve this.
+            #    This ensures we never "forget" about user defined ranges,
+            #    even if they end up being defined on composite formulas
+            #    like s0 + s1.
+            #
+            #  - If the variable is unbacked, only substitute if the substitution
+            #    would preserve the bounds also under size-like-ness conditions.
+
+            if not tgt_bound.issubset(src_bound):
+                self.log.debug(
+                    "skipped set_replacement %s = %s (%s) [%s not subset of %s]",
+                    a,
+                    tgt,
+                    msg,
+                    tgt_bound,
+                    src_bound,
+                )
+                return
+            elif a in self.size_like:
+                tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
+                src_bound_so = self.bound_sympy(a, size_oblivious=True)
+                if not tgt_bound_so.issubset(src_bound_so):
+                    self.log.debug(
+                        "skipped set_replacement %s = %s (%s) "
+                        "[%s not subset of %s (size-oblivious conditions)]",
+                        a,
+                        tgt,
+                        msg,
+                        tgt_bound_so,
+                        src_bound_so,
+                    )
+                    return
+
+        if isinstance(tgt, (sympy.Integer, sympy.Float)):
+            # specializing to a constant, which is likely unexpected (unless
+            # you specified dynamic=True)
+
+            user_tb = TracingContext.extract_stack()
+            trace_structured(
+                "symbolic_shape_specialization",
+                metadata_fn=lambda: {
+                    "symbol": repr(a),
+                    "sources": [s.name for s in self.var_to_sources.get(a, [])],
+                    "value": repr(tgt),
+                    "reason": msg,
+                    "stack": structured.from_traceback(
+                        CapturedTraceback.extract(skip=1).summary()
+                    ),
+                    "user_stack": (
+                        structured.from_traceback(user_tb) if user_tb else None
+                    ),
+                },
+            )
+
+            for source in self.var_to_sources.get(a, []):
+                if user_tb:
+                    self.specialization_stacks[source] = user_tb
+
+            if config.print_specializations:
+                self.log.warning(
+                    "Specializing %s to %s", self.var_to_sources[a][0].name, tgt
+                )
+                self.log.debug("SPECIALIZATION", stack_info=True)
+        log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
+        self.replacements[a] = tgt
+        # NB: the replacement may get refined, but the user will find the
+        # FIRST one most useful (TODO: Maybe we could consider tracking all of
+        # them)
+        if a not in self.replacements_slocs:
+            self.replacements_slocs[a] = self._get_sloc()
+        self._update_version_counter()
+
+        # When specializing 'a == tgt', the equality should be also conveyed to
+        # Z3, in case an expression uses 'a'.
+        self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))
+
+    def _add_divisible(self, expr: sympy.Expr) -> None:
+        self.divisible.add(expr)
+        self._update_version_counter()
+
+    @_lru_cache
+    @record_shapeenv_event()
+    def _find(self, a: sympy.Symbol) -> sympy.Expr:
+        """
+        Implements a DSU-like algorithm to find the variable that represents a
+        Also handles transitive non-identity replacements.
+
+        a: b + c
+        c: d
+        """
+        if a not in self.replacements:
+            return a
+        res = self.replacements[a]
+        cur_replace = {s: self._find(s) for s in res.free_symbols}
+        replaced, changed = self.replacements[a]._xreplace(cur_replace)
+        if changed:
+            self._set_replacement(a, replaced, "find")
+        return self.replacements[a]
+
+    @lru_cache(256)
+    def _maybe_guard_rel(self, expr: sympy.Expr) -> None:
+        """
+        The relational guard is guarded to be true.  Use this information to
+        simplify shapes (i.e. a == b or a % 5 == 0)
+        """
+        if isinstance(expr, sympy.And):
+            for arg in expr.args:
+                self._maybe_guard_rel(arg)
+            return
+        elif not isinstance(expr, sympy.Rel):
+            return
+
+        # A good example of what goes wrong if you don't do this is
+        # python test/functorch/test_aotdispatch.py -k
+        # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
+        if isinstance(expr, sympy.Ne):
+            return
+
+        free = list(expr.free_symbols)
+
+        assert len(free) > 0, (
+            f"The expression should not be static by this point: {expr}"
+        )
+        # In case of really gnarly expression, we don't blow up
+        if len(free) > 5:
+            return
+
+        # Prioritize unbacked symints for solving by ordering them last.
+        # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
+        #   (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
+        # Prefer to simplify out symbols with ephemeral sources.
+        def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]:
+            has_only_ephemeral_sources = x in self.var_to_sources and all(
+                s.is_ephemeral() for s in self.var_to_sources[x]
+            )
+            # NB: size_hint is int, not sympy.Expr, do not use int_oo here
+            hint_size = self.size_hint(x, allow_none=True)
+            if hint_size is None:
+                size = sys.maxsize
+            elif symbol_is_type(x, SymT.SIZE):
+                assert isinstance(hint_size, sympy.Expr)
+                size = int(hint_size)
+            else:
+                size = sys.maxsize
+            name = x.name
+            # 1 puts ephemeral sourced symbols first when sorting in reverse
+            return (1 if has_only_ephemeral_sources else 0, size, name)
+
+        free = sorted(free, key=_smart_symbol_sort, reverse=True)  # type: ignore[attr-defined]
+        lhs = expr.lhs
+        rhs = expr.rhs
+
+        self._refine_ranges(expr)
+
+        # The rest of this stuff is for equality only
+        if not isinstance(expr, sympy.Eq):
+            return
+
+        if not expr.has(Mod):
+            try:
+                floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
+                if len(floor_div_atoms) > 0 and any(
+                    a.divisor != 1 for a in floor_div_atoms
+                ):
+                    raise NotImplementedError
+
+                # Never replace unbacked symbols with other unbacked symbols that are
+                # not function arguments. (ex:mark_unbacked symbols are fine to replace
+                # other unbacked, but not those coming from .item() calls).
+
+                # This is error prone because you can cause references to
+                # unbacked symbols to time travel backwards.  E.g.,
+                #
+                # u1 = x.item()
+                # ... use of u1 ...
+                # u2 = y.item()
+                # u3 = z.item()
+                # torch._check(u1 == u2 + u3)
+                #
+                # If you replace u1 with u2 + u3, then the use of u1 now
+                # references u2 and u3 prior to them actually being bound at
+                # runtime.  It's pretty inconvenient to setup control
+                # dependencies for substitutions, so ban it entirely.
+                def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool:
+                    if isinstance(lhs, sympy.Symbol):
+                        if free_unbacked_symbols(
+                            lhs
+                        ) and not _free_non_source_unbacked_symbols(
+                            rhs, self.unbacked_inputs
+                        ):
+                            return True
+                        if symbol_is_type(lhs, SymT.FLOAT):
+                            return True
+                        # TODO: Maybe trivial solutions for int should also be
+                        # done?
+                    return False
+
+                # short-circuit when no solving is needed
+                if trivial_solve(lhs, rhs):
+                    self._set_replacement(lhs, self._find(rhs), "trivial_lhs")
+                elif trivial_solve(rhs, lhs):
+                    self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
+                else:
+                    r = try_solve(expr, free[0], floordiv_inequality=False)
+                    if r is not None and all(
+                        t.is_integer for t in sympy.preorder_traversal(r[1])
+                    ):
+                        new_var = self._find(r[1])
+                        ok = len(free_unbacked_symbols(new_var)) == 0
+                        if ok:
+                            self._set_replacement(free[0], new_var, "solve")
+
+            except NotImplementedError:
+                pass
+        else:
+            # expression has mod.
+            mod_expr = next(iter(expr.atoms(Mod)))
+            try:
+                r = try_solve(expr, mod_expr, floordiv_inequality=False)
+                if r is not None and r[1] == 0:
+                    self._add_divisible(mod_expr)
+            except NotImplementedError:
+                pass
+        return
+
+    # See: Note - On 0/1 specialization
+    def _default_value_range(
+        self, do_not_specialize_zero_one: bool = False
+    ) -> ValueRanges:
+        lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2
+        return ValueRanges(lower, int_oo)
+
+    def _default_unspecified_value_range(self) -> ValueRanges:
+        return ValueRanges.unknown_int()
+
+    @_lru_cache
+    def _simplify_floor_div(self, expr: sympy.Expr) -> sympy.Expr:
+        floor_divs = tuple(expr.atoms(FloorDiv))
+        # we expect floor_divs to be exact,
+        # and thus add the guards for the exact floordivs,
+        # even if tracing doesn't require them otherwise
+        for fd in reversed(floor_divs):
+            base, divisor = fd.args
+            mod_expr = Mod(base, divisor)
+            eq_expr = sympy.Eq(mod_expr, 0)
+            # add necessary mod guards
+            self.evaluate_expr(eq_expr)
+        return self.simplify(expr)
+
+    # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen
+    # and if so issue a warning
+    def _check_frozen(self, expr: sympy.Basic, concrete_val: sympy.Basic) -> None:
+        if self.frozen:
+            self.counter["ignored_backward_guard"] += 1
+            signpost_event(
+                "dynamic",
+                "evaluate_expr_frozen",
+                {
+                    **self.co_fields,
+                    "ignored_guard": f"{expr} == {concrete_val}",
+                    # no version = original state (this signpost is expected)
+                    # version 2 = dynamic backwards is eagerly compiled
+                    "version": 2,
+                },
+            )
+            log.info(
+                "Ignored guard %s == %s, this could result in accuracy problems",
+                expr,
+                concrete_val,
+                # only print stack trace when debug mode is on (e.g. TORCH_LOGS="dynamic")
+                stack_info=log.getEffectiveLevel() < logging.WARNING,
+            )
+
+    def _get_user_frame(self) -> Optional[types.FrameType]:
+        frame = inspect.currentframe()
+        while frame is not None:
+            if frame.f_code.co_filename not in uninteresting_files():
+                return frame
+            frame = frame.f_back
+        return frame
+
+    def _get_stack_summary(
+        self, is_debug: bool = False, framework_loc: Optional[str] = None
+    ) -> tuple[SLoc, str]:
+        floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc
+        if floc is None:
+            frame = self._get_user_frame()
+            try:
+                if frame is not None:
+                    floc = traceback.FrameSummary(
+                        frame.f_code.co_filename,
+                        frame.f_lineno,
+                        frame.f_code.co_name,
+                    )
+            finally:
+                del frame
+
+        # NB: this stack is truncated, but it's fine because the main
+        # stack_info will give you the rest of the info you need
+        maybe_user_loc = None
+        user_tb = TracingContext.extract_stack()
+        if user_tb:
+            idx = len(user_tb) - 1
+            while idx > 0 and user_tb[idx].filename in uninteresting_files():
+                idx -= 1
+            maybe_user_loc = format_frame(user_tb[idx], line=True)
+
+        maybe_extra_debug = ""
+        if is_debug and user_tb:
+            maybe_extra_debug = (
+                "\nUser Stack (most recent call last):\n"
+                + "  (snipped, see stack below for prefix)\n"
+                + "".join(traceback.format_list(user_tb))
+            )
+        if is_debug and config.extended_debug_cpp:
+            cpp_stack = CapturedTraceback.extract(cpp=True)
+            maybe_extra_debug += "\nC++ stack trace:\n" + "".join(cpp_stack.format())
+        elif is_debug:
+            maybe_extra_debug += (
+                "\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1"
+            )
+
+        return SLoc(floc, maybe_user_loc), maybe_extra_debug
+
+    # Pass in framework_loc to override the framework location info
+    def _get_sloc(self, framework_loc: Optional[str] = None) -> SLoc:
+        sloc, _ = self._get_stack_summary(framework_loc=framework_loc)
+        return sloc
+
+    def _generate_unique_id(self, source_name: str) -> int:
+        attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
+        while attempt in self.unique_ids:
+            attempt += 1
+        self.unique_ids.add(attempt)
+        return attempt
+
+    def _find_frame_locals(self) -> _FrameLocalResult:
+        """
+        Given the current user code frame, finds the relevant lines of code,
+        values of symbolic locals, and free symbols involved.
+        """
+        frame_locals: dict[str, Any] = {}
+        frame_symbols: dict[str, str] = {}
+
+        if (
+            frame := _find_user_code_frame()
+        ) is None or frame.f_code.co_filename == "":
+            return _FrameLocalResult()
+
+        # find bytecode instructions relevant to the frame
+        instructions = list(dis.Bytecode(frame.f_code))
+        co_lines, offset = inspect.getsourcelines(frame.f_code)
+        start, end, cur = None, None, None
+        # pyrefly: ignore [bad-assignment]
+        for i, instr in enumerate(instructions):
+            if instr.starts_line is not None:
+                cur = instr.starts_line
+            if cur != frame.f_lineno:
+                continue
+            if start is None:
+                start = end = i
+            else:
+                end = i
+
+        if start is None or end is None:  # no instructions found
+            return _FrameLocalResult()
+
+        # track involved locals and free symbols
+        def go(x: Any) -> Optional[str]:
+            if isinstance(x, torch.Tensor):
+                for y in x.size():
+                    go(y)
+                for y in x.stride():
+                    go(y)
+                go(x.storage_offset())
+                return (
+                    f"Tensor(shape: {x.size()}, "
+                    f"stride: {x.stride()}, "
+                    f"storage_offset: {x.storage_offset()})"
+                )
+            elif isinstance(x, (SymBool, SymInt, SymFloat)):
+                for s in x.node.expr.free_symbols:
+                    if str(s) in frame_symbols:  # type: ignore[operator]
+                        continue
+                    if s in self.var_to_sources:
+                        frame_symbols[str(s)] = self.var_to_sources[s][0].name  # type: ignore[assignment]
+                return str(x)
+            return None
+
+        # go through instructions, seeing linenos & involved locals
+        last_lineno = frame.f_lineno
+        for instr in instructions[start : end + 1]:
+            if (lineno := instr.starts_line) is not None:
+                last_lineno = max(last_lineno, lineno)
+            if isinstance(instr.argval, str) and instr.argval in frame.f_locals:
+                flat_locals = pytree.tree_flatten(frame.f_locals[instr.argval])[0]
+                frame_locals[instr.argval] = [
+                    go(flat_local) for flat_local in flat_locals
+                ]
+
+        # store LOC
+        locs = co_lines[frame.f_lineno - offset : last_lineno + 1 - offset]
+        if not locs:
+            return _FrameLocalResult()
+
+        indent = len(locs[0]) - len(locs[0].lstrip())
+        frame_loc = "".join([loc[indent:] for loc in locs]).strip()  # type: ignore[assignment]
+        return _FrameLocalResult(
+            loc=frame_loc, locals=frame_locals, symbols=frame_symbols
+        )
+
+    def _log_guard(self, prefix: str, g: SympyBoolean, forcing_spec: bool) -> None:
+        dtrace_structured(
+            "guard_added",
+            metadata_fn=lambda: {
+                "expr": str(g),
+                "prefix": prefix,
+                "expr_node_id": self._expr_sym_node_id,
+                "user_stack": structured.get_user_stack(3),
+                "stack": structured.get_framework_stack(3),
+                "symbol_to_sources": {
+                    str(v): k
+                    for k, v in self.source_to_var.items()
+                    if v in g.free_symbols
+                },
+                "frame_locals": asdict(self._find_frame_locals()),
+            },
+        )
+        trace_structured(
+            "guard_added_fast",
+            metadata_fn=lambda: {
+                "expr": str(g),
+                "user_stack": structured.from_traceback(TracingContext.extract_stack()),
+                "stack": structured.from_traceback(
+                    CapturedTraceback.extract(skip=1).summary()
+                ),
+            },
+        )
+        if self.log.isEnabledFor(logging.INFO):
+            str_g = str(g)
+            is_debug = (
+                config.extended_debug_guard_added is not None
+                and str_g == config.extended_debug_guard_added
+            )
+            sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
+            maybe_more_info = ""
+            if not is_debug:
+                maybe_more_info = (
+                    ", for more info run with "
+                    f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"'
+                )
+            self.log.info(
+                "%s %s [guard added] %s%s%s",
+                prefix if not forcing_spec else f"{prefix} (forcing_spec)",
+                str_g,
+                sloc,
+                maybe_more_info,
+                maybe_extra_debug,
+                stack_info=is_debug,
+            )
+
+    # A local variable to evaluate_expr stored in the class to avoid
+    # using it for the lru_cache that is on top of it since it does
+    # not effect the results. When needed its read directly.
+    _expr_sym_node_id: Optional[int] = None
+
+    def evaluate_sym_node(
+        self,
+        sym_node: SymNode,
+        size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
+    ) -> sympy.Basic:
+        """
+        Given a a SymNode, evaluates sym_node.expr, adding guards if necessary.
+        """
+
+        self._expr_sym_node_id = id(sym_node)
+        return self.evaluate_expr(
+            sym_node.expr,
+            sym_node.hint,
+            sym_node.fx_node,
+            size_oblivious,
+            fallback_value=fallback_value,
+        )
+
+    def _is_python_assert(self) -> bool:
+        # Check if this boolean is used in an assertion, bytecode pattern for
+        # assertions is pretty stable for Python 3.7--3.13, ported with minimal
+        # changes from torch/fx/proxy.py
+        # Bytecode pattern for `assert` statements:
+        #     TO_BOOL / COMPARE_OP  # Only for Python >= 3.13
+        #     POP_JUMP_IF_TRUE
+        #     LOAD_ASSERTION_ERROR
+        #     RAISE_VARARGS
+        frame = self._get_user_frame()
+        assert frame is not None
+
+        insts = list(dis.get_instructions(frame.f_code))
+        if sys.version_info >= (3, 11):
+            # For Python >= 3.11, instructions can be 2-4 bytes long.
+            from bisect import bisect_left
+
+            cur = bisect_left(insts, frame.f_lasti, key=lambda x: x.offset)
+        else:
+            # For Python <= 3.10, instructions are always 2 bytes.
+            cur = frame.f_lasti // 2
+
+        if sys.version_info >= (3, 13):
+            if insts[cur].opname in ("TO_BOOL", "COMPARE_OP"):
+                # Peek 1 instruction further.
+                cur += 1
+
+        assert_insts = torch._dynamo.symbolic_convert.get_assert_bytecode_sequence(
+            False
+        )
+
+        cur_insts = insts[cur + 1 : cur + 1 + len(assert_insts)]
+        cur_insts = [inst.opname for inst in cur_insts]
+        return cur_insts == assert_insts
+
+    def _log_real_tensor_propagation(
+        self, orig_expr: sympy.Basic, unsound_result: sympy.Basic
+    ) -> None:
+        log.warning(
+            "propagate_real_tensors evaluate_expr(%s) -> %s",
+            orig_expr,
+            unsound_result,
+        )
+        trace_structured(
+            "propagate_real_tensors",
+            metadata_fn=lambda: {
+                "expr": repr(orig_expr),
+                "result": repr(unsound_result),
+                "stack": structured.from_traceback(
+                    CapturedTraceback.extract(skip=1).summary()
+                ),
+            },
+        )
+        dtrace_structured(
+            "propagate_real_tensors_provenance",
+            metadata_fn=lambda: {
+                "expr": repr(orig_expr),
+                "result": repr(unsound_result),
+                "expr_node_id": self._expr_sym_node_id,
+                "user_stack": structured.get_user_stack(3),
+                "stack": structured.get_framework_stack(3),
+                "symbol_to_sources": {
+                    str(v): k
+                    for k, v in self.source_to_var.items()
+                    if v in orig_expr.free_symbols
+                },
+                "frame_locals": asdict(self._find_frame_locals()),
+            },
+        )
+
+    def evaluate_expr(
+        self,
+        orig_expr: sympy.Basic,
+        hint: Optional[Union[int, bool, float]] = None,
+        fx_node: Optional[torch.fx.Node] = None,
+        size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
+        *,
+        forcing_spec: bool = False,
+    ) -> sympy.Basic:
+        """
+        Given an expression, evaluates it, adding guards if necessary
+        When fallback_value is not None the function return fallback_value instead of failing with data dependent error.
+        """
+
+        # Add extra state that evaluate_expr() depends on.
+        suppress_guards_tls = ShapeEnv._suppress_guards_tls()
+        return self._inner_evaluate_expr(
+            orig_expr,
+            hint,
+            fx_node,
+            size_oblivious,
+            forcing_spec,
+            suppress_guards_tls,
+            fallback_value,
+        )
+
+    @lru_cache(256)
+    @record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr")
+    def _inner_evaluate_expr(
+        self,
+        orig_expr: sympy.Basic,
+        hint: Optional[Union[int, bool, float]],
+        fx_node: Optional[torch.fx.Node],
+        size_oblivious: bool,
+        forcing_spec: bool,
+        _suppress_guards_tls: bool,
+        fallback_value: Optional[bool] = None,
+    ) -> sympy.Basic:
+        try:
+            return self._evaluate_expr(
+                orig_expr,
+                hint,
+                fx_node,
+                size_oblivious,
+                fallback_value,
+                forcing_spec=forcing_spec,
+            )
+        except Exception as e:
+            if isinstance(e, GuardOnDataDependentSymNode):
+                pass
+            else:
+                self.log.warning(
+                    "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s",
+                    orig_expr,
+                    hint,
+                    size_oblivious,
+                    forcing_spec,
+                )
+            raise
+
+    def _log_suppressed_dde(self, a: SymBool, assumed_value: bool) -> None:
+        sloc, extra = self._get_stack_summary(True)
+        log.info(
+            "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s",
+            a,
+            assumed_value,
+            sloc,
+            extra,
+        )
+
+    def _evaluate_expr(
+        self,
+        orig_expr: sympy.Basic,
+        hint: Optional[Union[bool, int, float]] = None,
+        fx_node: Optional[torch.fx.Node] = None,
+        size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
+        *,
+        forcing_spec: bool = False,
+    ) -> sympy.Basic:
+        # TODO: split conjunctions and evaluate them separately
+        if isinstance(
+            orig_expr,
+            (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
+        ):
+            return orig_expr
+
+        # Don't track this one. (Because this cache is inside this function the
+        # cache only lasts for the invocation of this function call)
+        @functools.cache
+        def compute_concrete_val() -> sympy.Basic:
+            if hint is None:
+                # This is only ever called for expressions WITHOUT unbacked
+                # symbols
+                r = self.size_hint(orig_expr)
+                assert r is not None
+                return r
+            else:
+                return sympy.sympify(hint)
+
+        concrete_val: Optional[sympy.Basic]
+
+        # Check if:
+        #   1. 'translation_validation' is set
+        #   2. the corresponding 'fx_node' is not 'None'
+        #   3. the guard should not be suppressed
+        #   4. the guard doesn't contain backed symfloat symbols
+        #      since z3 can't handle floats
+        #   5. fallback_value is none.
+        # If all of the above check, we create an FX node representing the
+        # actual expression to be guarded.
+        node = None
+        fresh = False
+        if (
+            self._translation_validation_enabled
+            and fx_node is not None
+            and not self._suppress_guards_tls()
+            and not size_oblivious
+            and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
+            and fallback_value is None
+        ):
+            # TODO: does this even worked with unbacked :think:
+            concrete_val = compute_concrete_val()
+            if concrete_val is sympy.true:
+                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            elif concrete_val is sympy.false:
+                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
+                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
+            else:
+                eql, _ = self._create_fx_call_function(
+                    operator.eq, (fx_node, concrete_val)
+                )
+                node, fresh = self._create_fx_call_function(torch._assert, (eql,))
+
+            assert node is not None
+            # If this is a fresh node, we have to remember the event index that
+            # corresponds to this assertion node.
+            # Reason: so that, given an assertion node, we can replay the ShapeEnv
+            # events until the point where this assertion node was freshly created.
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        # After creating the FX node corresponding to orig_expr, we must make sure that
+        # no error will be raised until the end of this function.
+        #
+        # Reason: the translation validation may become invalid otherwise.
+        #
+        # If an error is raised before the end of this function, we remove the FX node
+        # inserted, and re-raise the error.
+        guard = None
+
+        try:
+            if orig_expr.is_number:
+                self.log.debug("eval %s [trivial]", orig_expr)
+                if hint is not None:
+                    if isinstance(hint, bool):
+                        assert orig_expr == hint, f"{orig_expr} != {hint}"
+                    else:
+                        assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
+                return orig_expr
+
+            expr = orig_expr
+
+            static_expr = self._maybe_evaluate_static(
+                expr, size_oblivious=size_oblivious
+            )
+            if static_expr is not None:
+                self.log.debug(
+                    "eval %s == %s [statically known]",
+                    (
+                        f"size_oblivious({orig_expr})"
+                        if size_oblivious
+                        else size_oblivious
+                    ),
+                    static_expr,
+                )
+                if (
+                    not size_oblivious
+                    and config.backed_size_oblivious
+                    and hint is not None
+                ):
+                    # TODO: maybe reconcile this with use of counterfactual hints
+                    # in unbacked case
+                    assert static_expr == hint, f"{static_expr} != {hint}"
+                return static_expr
+
+            transmute_into_runtime_assert = False
+
+            concrete_val = None
+            if not (expr.free_symbols <= self.var_to_val.keys()):
+                # TODO: dedupe this with _maybe_evaluate_static
+                # Attempt to eliminate the unbacked SymInt
+                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+                assert new_expr is not None
+                if not (new_expr.free_symbols <= self.var_to_val.keys()):
+                    ok = False
+
+                    # fallback_value is set when guard_or_true or guard_or_false are used.
+                    if not ok and fallback_value is not None:
+                        self._log_suppressed_dde(orig_expr, fallback_value)
+                        return fallback_value
+
+                    # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
+                    # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+                    if (
+                        self.oblivious_var_to_val
+                        and not (
+                            correct_hint := orig_expr.xreplace(
+                                self.oblivious_var_to_val
+                            )
+                        ).free_symbols
+                        and not (
+                            counterfactual_hint := orig_expr.xreplace(
+                                {
+                                    k: max(2, v)
+                                    for k, v in self.oblivious_var_to_val.items()
+                                }
+                            )
+                        ).free_symbols
+                        and correct_hint == counterfactual_hint
+                    ):
+                        # TODO: better logging
+                        log.info(
+                            "oblivious_size %s -> %s (passed counterfactual)",
+                            orig_expr,
+                            correct_hint,
+                        )
+
+                        concrete_val = correct_hint
+                        # NB: do NOT transmute into runtime assert
+                        ok = True
+
+                    # unbacked_var_to_val is not None iff propagate_real_tensors is on.
+                    # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
+                    # and if they pass we add a runtime assertions and continue.
+                    if (
+                        not ok
+                        and self.unbacked_var_to_val
+                        and not (
+                            unsound_result := orig_expr.xreplace(
+                                self.unbacked_var_to_val
+                            ).xreplace(self.var_to_val)
+                        ).free_symbols
+                    ):
+                        self._log_real_tensor_propagation(orig_expr, unsound_result)
+                        transmute_into_runtime_assert = True
+
+                        concrete_val = unsound_result
+                        ok = True
+
+                    # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
+                    # instead of failing.
+                    if not ok and self.trace_asserts and self._is_python_assert():
+                        concrete_val = sympy.true
+                        transmute_into_runtime_assert = True
+                        ok = True
+
+                    if not ok:
+                        raise self._make_data_dependent_error(
+                            expr.xreplace(self.var_to_val),
+                            expr,
+                            expr_sym_node_id=self._expr_sym_node_id,
+                        )
+                else:
+                    expr = new_expr
+
+            if concrete_val is None:
+                concrete_val = compute_concrete_val()
+            self._check_frozen(expr, concrete_val)
+
+            if (
+                config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+                and isinstance(hint, bool)
+                and isinstance(expr, (sympy.Eq, sympy.Ne))
+            ):
+                expr = sympy.Not(expr)
+
+            # Turn this into a boolean expression, no longer need to consult
+            # concrete_val
+            if concrete_val is sympy.true:
+                g = cast(SympyBoolean, expr)
+            elif concrete_val is sympy.false:
+                g = sympy.Not(expr)
+            else:
+                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
+
+            if transmute_into_runtime_assert:
+                self.guard_or_defer_runtime_assert(
+                    g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
+                )
+                return concrete_val
+
+            if not self._suppress_guards_tls():
+                self._log_guard("eval", g, forcing_spec=forcing_spec)
+
+                # TODO: If we successfully eliminate a symbol via equality, it
+                # is not actually necessary to save a guard for the equality,
+                # as we will implicitly generate a guard when we match that
+                # input against the symbol.  Probably the easiest way to
+                # implement this is to have maybe_guard_rel return a bool
+                # saying if it "subsumed" the guard (and therefore the guard
+                # is no longer necessary)
+                self._maybe_guard_rel(g)
+
+                if (
+                    torch.compiler.is_exporting()
+                    and self.prefer_deferred_runtime_asserts_over_guards
+                ):
+                    # it's fine to defer simple guards here without checking,
+                    # the _maybe_guard_rel() call above will set replacements if possible,
+                    # and so the result here will be statically known
+                    self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+                else:
+                    # at this point, we've evaluated the concrete expr value, and have
+                    # flipped/negated the guard if necessary. Now we know what to guard
+                    # or defer to runtime assert on.
+                    guard = ShapeGuard(
+                        g, self._get_sloc(), size_oblivious=size_oblivious
+                    )
+                    self.guards.append(guard)
+                    self.axioms.update(dict(self.get_implications(self.simplify(g))))
+            else:
+                self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
+
+        except Exception:
+            if fresh:
+                self._remove_fx_node(node)
+            raise
+
+        if not self._suppress_guards_tls():
+            if guard is not None:  # we might have deferred this to runtime assert
+                for s in g.free_symbols:
+                    self.symbol_guard_counter[s] += 1
+                    # Forcing_spec to avoid infinite recursion
+                    if (
+                        not forcing_spec
+                        and config.symbol_guard_limit_before_specialize is not None
+                        and self.symbol_guard_counter[s]
+                        > config.symbol_guard_limit_before_specialize
+                    ):
+                        # Force specialization
+                        self.log.info(
+                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
+                            config.symbol_guard_limit_before_specialize,
+                            s,
+                        )
+                        self.evaluate_expr(s, forcing_spec=True)
+
+        return concrete_val
+
+    def cleanup(self) -> None:
+        """
+        Break reference cycles.
+
+        This destroys the stacks. If you really want to keep them, we
+        just need some way to break references on code objects.
+        """
+        for s in self.var_to_stack.values():
+            s.cleanup()
+        for ras in self.deferred_runtime_asserts.values():
+            for ra in ras:
+                ra.stack.cleanup()
+
+    @lru_cache(256)
+    @record_shapeenv_event(save_tracked_fakes=True)
+    def guard_or_defer_runtime_assert(
+        self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None
+    ) -> bool:
+        """
+        Adds a guard that orig_expr is True if we can or fall back to adding an assert
+        that is checked at runtime.
+
+        Args:
+            orig_expr (sympy.Expr): Boolean expression to assert is true
+            msg (str): Message to display on assertion failure
+            fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding
+                to the expression, if applicable
+        """
+        expr = orig_expr
+
+        # TODO: split conjunctions and evaluate them separately
+
+        static_expr = self._maybe_evaluate_static(expr)
+        if static_expr is not None:
+            self.log.debug(
+                "runtime_assert %s == %s [statically known]", orig_expr, static_expr
+            )
+            # TODO: assert bool(static_expr)
+            return bool(static_expr)
+
+        # Attempt to eliminate the unbacked SymInt
+        new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+        assert new_expr is not None
+        if (
+            not self.prefer_deferred_runtime_asserts_over_guards
+            and new_expr.free_symbols <= self.var_to_val.keys()
+        ):
+            # Do a normal guard
+            return self.evaluate_expr(new_expr, fx_node=fx_node)
+        # NB: Don't use new_expr as expr; it could contain gunk like shape0
+        # which we don't want to guard on
+
+        if (
+            self._translation_validation_enabled
+            and fx_node is not None
+            and not self._suppress_guards_tls()
+        ):
+            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            assert node is not None
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        if not self._suppress_guards_tls():
+            self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
+            # If you're here because of this assert, read Note [Backwards runtime asserts]
+            # in torch/_inductor/graph.py
+            if self.runtime_asserts_frozen:
+                log.debug("runtime_asserts_frozen but then got %s", expr)
+            self._check_frozen(expr, sympy.true)
+            # eliminate symbols on equality tests / refine ranges
+            self._maybe_guard_rel(expr)
+
+            # canonicalise to remove equations that are trivially equal
+            orig_expr = expr
+            expr = canonicalize_bool_expr(expr)
+            stack = CapturedTraceback.extract(skip=1)
+            ra = RuntimeAssert(expr, msg, stack)
+
+            # TODO: Do this in a way that is less janky than int(s.name[1:])
+            cands = sorted(
+                (s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)),
+                key=lambda s: int(s.name[1:]),
+            )
+            # Is None when prefer_deferred_runtime_asserts_over_guards=True
+            # and the guard in question has no unbacked SymInts in front
+            ix = cands[-1] if cands else None
+            self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
+            self.axioms.update(dict(self.get_implications(self.simplify(expr))))
+            self.num_deferred_runtime_asserts += 1
+            self._update_version_counter()
+        else:
+            self._log_guard(
+                "runtime_assert [guard suppressed]", orig_expr, forcing_spec=False
+            )
+
+        return True
+
+    # Refines the ranges of the variables present in 'guard'.
+    #
+    # This function tries to refine the range of the variables inside
+    # 'guard' by reasoning about it. Specifically, when 'guard' is a
+    # 'sympy.Relational' operation.
+    #
+    # It does mainly 3 things:
+    #   1. Tries to isolate a variable in the left-hand side
+    #   2. Compute the value range of the right-hand side
+    #   3. Update the value range of the variable, if better
+    def _refine_ranges(self, expr: SympyBoolean) -> None:
+        expr = self.simplify(expr)
+
+        for symbol in expr.free_symbols:
+            assert isinstance(symbol, sympy.Symbol)
+
+            if isinstance(self.var_to_val.get(symbol, None), SingletonInt):
+                # Skip var_to_range logic for SingletonInt which is only used
+                # for jagged layout NestedTensors today
+                continue
+
+            r = try_solve(expr, symbol)
+
+            if r is None or not (symbol.is_integer and r[1].is_integer):
+                # Range refinement only supports integer symbols for now.
+                # There are lots of SymPy bugs when it comes to comparing
+                # reals and integers, so we skip that for now.
+                continue
+
+            r_expr, rhs = r
+            vr = self.var_to_range[symbol]
+            lower, upper = vr.lower, vr.upper
+
+            rhs_vr = bound_sympy(rhs, self.var_to_range)
+
+            # Let's suppose that we have a preexisting range for x [0, 100].
+            # Now, we issue a guard x > y, where the range for y is [50, 150].
+            # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen,
+            # refining x to [51, 100], since x must be greater than y, but the lowest
+            # y could be is 50.
+            #
+            # sympy.Eq may update both lower and upper bounds.
+            # sympy.G{t,e} may update the lower bound, only.
+            # sympy.L{t,e} may update the upper bound, only.
+            if lower <= rhs_vr.lower and isinstance(
+                r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)
+            ):
+                # Strictly greater relations allow us to refine a bit more, since
+                # x < y implies that the lower bound for x is: y + 1.
+                lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
+            if upper >= rhs_vr.upper and isinstance(
+                r_expr, (sympy.Eq, sympy.Le, sympy.Lt)
+            ):
+                upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))
+
+            # Do nothing if the new value range is no better than what we already have.
+            if vr == ValueRanges(lower, upper):
+                continue
+
+            # Updates the range and the guards corresponding to each bound of the symbol.
+            self._update_var_to_range(symbol, ValueRanges(lower, upper))
+            # If the range is refined to singleton, set replacement
+            if self.var_to_range[symbol].is_singleton():
+                self._set_replacement(
+                    symbol,
+                    self.var_to_range[symbol].lower,
+                    "range_refined_to_singleton",
+                )
+
+            # Clears the cache, since this update can change the result.
+            self._maybe_evaluate_static.cache_clear()
+
+    @lru_cache(maxsize=None)
+    @record_shapeenv_event()
+    def constrain_symbol_range(
+        self, s: sympy.Symbol, compiler_min: int, compiler_max: int
+    ) -> None:
+        upd_vr = ValueRanges(compiler_min, compiler_max)
+        old_vr = self.var_to_range.get(s, ValueRanges.unknown())
+        self._update_var_to_range(s, upd_vr)
+        if (new_vr := self.var_to_range[s]) != old_vr:
+            log.info(
+                "constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper
+            )
+
+
+def _is_int(expr: object) -> bool:
+    return isinstance(expr, SymInt) and expr.node.expr.is_number
+
+
+# WARNING: This is legacy, DO NOT USE
+def _is_dim_dynamic(t: torch.Tensor, d: int) -> bool:
+    return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices
+
+
+class PropagateUnbackedSymInts(torch.fx.Interpreter):
+    def run_node(self, n: torch.fx.Node) -> Result:
+        """
+        Run an FX node, propagating unbacked Symbol bindings to the new fake tensor
+        """
+        from torch._guards import detect_fake_mode
+
+        result = super().run_node(n)
+        fake_mode = detect_fake_mode()
+        assert fake_mode is not None
+        rebind_unbacked(fake_mode.shape_env, n, result)
+        return result
+
+
+def _find_user_code_frame() -> Optional[types.FrameType]:
+    frame = inspect.currentframe()
+    while frame is not None:
+        if not frame.f_code.co_filename.startswith(
+            os.path.dirname(inspect.getfile(torch)) + os.path.sep
+        ):
+            break
+        frame = frame.f_back
+    return frame
+
+
+def _blame_user_code(e: Exception, frame: types.FrameType) -> None:
+    frame_summary = traceback.FrameSummary(
+        frame.f_code.co_filename,
+        frame.f_lineno,
+        frame.f_code.co_name,
+    )
+    msg = e.args[0]
+    msg += "\n\nThe following call raised this error:\n" + "".join(
+        traceback.StackSummary.from_list([frame_summary]).format()
+    )
+    e.args = (msg,)
+
+
+class _PythonMsgPrinter(PythonPrinter):
+    """
+    Util printer that replaces sympy symbols with their source-level names
+    and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
+    (i.e., as ==, !=, >, <).
+    """
+
+    def __init__(self, src_map: dict[str, list[str]]) -> None:
+        super().__init__()
+        self.src_map = src_map
+
+    def _print_Symbol(self, sym: sympy.Symbol) -> str:
+        return self.src_map[sym.name][0]
+
+
+def _suggest_torch_checks(
+    e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]]
+) -> None:
+    """
+    Enhances a GuardOnDataDependentSymNode error with suggested fixes using torch._check.
+
+    This function analyzes the condition that caused the data-dependent error and generates
+    user-friendly suggestions for fixing it by adding appropriate torch._check calls.
+    It handles special cases like non-negative checks with specific recommendations.
+
+    Args:
+        e: The GuardOnDataDependentSymNode error to enhance with suggestions
+        src_map: A mapping from symbol names to their corresponding source-level variable names
+
+    Returns:
+        None. Modifies the error message in-place by updating e.args[0].
+    """
+    # extract the unresolved condition on unbacked symints in the error
+    cond = e.cond
+    diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map)
+    if diff:
+        log.warning("Unable to find user code corresponding to {%s}", diff)
+        return
+    printer = _PythonMsgPrinter(src_map)
+    msg = e.args[0]
+    msg += "\nTo fix the error, insert one of the following checks before this call:"
+
+    not_cond_str = printer.doprint(sympy.Not(cond))
+
+    # suggested fixes to resolve `cond` are to tell the compiler to assume
+    # either `cond` or its negation (the user will need to select which)
+    suggested_fixes = [
+        f"torch._check({printer.doprint(cond)})",
+        f"torch._check({not_cond_str})",
+    ]
+
+    for i, fix in enumerate(suggested_fixes):
+        msg += f"\n  {i + 1}. {fix}"
+    src_mapped = ", ".join(
+        f"`{s}` with {' or '.join(src_map[s])}"
+        for s in sorted(s.name for s in cond.free_symbols)
+    )
+    msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)"
+    e.args = (msg,)
+
+
+def _suggest_fixes_for_data_dependent_error_non_strict(
+    e: GuardOnDataDependentSymNode,
+) -> None:
+    """
+    Given a raised data-dependent error, add the following to the error message:
+    1. the closest user code location that raised the error;
+    2. suggested fixes for the error in terms of live variables at that location.
+    """
+
+    # walk the stack up from the data-dependent error until a non-torch frame is found
+    frame = _find_user_code_frame()
+    if frame is not None:
+        # add frame info to error message
+        _blame_user_code(e, frame)
+
+        # map symbol names reachable via frame locals to their source-level names
+        src_map = defaultdict(list)
+        for var, val in frame.f_locals.items():
+            try:
+                tree_leaves_with_path = pytree.tree_leaves_with_path(val)
+            except ValueError:
+                log.warning(
+                    "pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}",
+                    type(val),
+                    var,
+                )
+                continue
+            # figure out how to access any symbol inside `val` through `var`
+            for path, leaf in tree_leaves_with_path:
+                name = var + pytree.keystr(path)
+                if isinstance(leaf, torch.SymInt):
+                    src_map[str(leaf.node.expr)].append(name)
+                elif isinstance(leaf, torch.Tensor):
+                    for i, dim in enumerate(leaf.shape):
+                        if isinstance(dim, torch.SymInt):
+                            src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")
+
+        # add suggested torch.check()s based on `src_map` to the error message
+        # replacing unbacked symints in the unresolved condition in the error
+        if isinstance(e.cond, sympy.logic.boolalg.Boolean):
+            _suggest_torch_checks(e, src_map)
+
+
+@contextmanager
+def _remove_effect_token_unbacked_bindings(
+    node: torch.fx.Node,
+) -> Generator[None, None, None]:
+    """
+    Temporarily modifies unbacked_bindings in a node's metadata by removing the first element
+    of each path, which corresponds to an effect token.
+
+    This is used when processing nodes that have effect tokens as the first element in their
+    unbacked_bindings paths. The context manager ensures that the original bindings are
+    restored after the operation is complete.
+
+    Args:
+        node: The FX node whose unbacked_bindings will be temporarily modified
+
+    Yields:
+        None
+    """
+    old_bindings = node.meta.get("unbacked_bindings", {})
+
+    # Remove the extra layer for effect token
+    new_bindings = {k: path[1:] if path else path for k, path in old_bindings.items()}
+
+    node.meta["unbacked_bindings"] = new_bindings
+
+    try:
+        yield
+    finally:
+        node.meta["unbacked_bindings"] = old_bindings
+
+
+# This helper function is used in passes that insert runtime assertions in the graph.
+# When accessing expressions representing input placeholders, we do not apply replacements
+# since those inputs should be seen by assertions that use them to be inserted. The only replacement
+# that we apply is unbacked renaming.
+def _get_placeholder_expr(sym_node: SymNode) -> sympy.Expr:
+    shape_env = sym_node.shape_env
+    result = sym_node._expr
+    if result in shape_env.unbacked_renamings:
+        return shape_env.unbacked_renamings[result]
+    return result
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7db0e29d1d4f75c770562c65013c03817643f6b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py
@@ -0,0 +1,4 @@
+# mypy: disable-error-code=attr-defined
+from .core import reify, unify  # noqa: F403
+from .more import unifiable  # noqa: F403
+from .variable import isvar, Var, var, variables, vars  # noqa: F403
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edb1e1cfede7d1871a47ab713d852ff44dac9423
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99192937be7dc985b37a40dddf7cac32794402fc
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e11cb13af76b9a06fce4f4618b5a8938e996f910
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3109d4a15ca8e9552cd77d835c0a73b60522879c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e49c31a6c7e1c15468dc0dd31f1e229ac281f314
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..991a9e52699369db3f19caaf9283d08ad237c9e1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad31543eeda7bedc82adb207d18ea757abc53294
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..150714a0020a841ee3eb094f90a3a74c9c62e9a7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d8071c847ae5da144d7ab57b5d24e7968b5daf6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py
@@ -0,0 +1,141 @@
+# mypy: allow-untyped-defs
+from collections.abc import Iterator  # type: ignore[import]
+from functools import partial
+
+from .dispatch import dispatch
+from .unification_tools import assoc  # type: ignore[import]
+from .utils import transitive_get as walk
+from .variable import isvar
+
+
+__all__ = ["reify", "unify"]
+
+###############
+# Reification #
+###############
+
+
+@dispatch(Iterator, dict)
+def _reify(t, s):
+    return map(partial(reify, s=s), t)
+    # return (reify(arg, s) for arg in t)
+
+
+_reify
+
+
+@dispatch(tuple, dict)  # type: ignore[no-redef]
+def _reify(t, s):
+    return tuple(reify(iter(t), s))
+
+
+_reify
+
+
+@dispatch(list, dict)  # type: ignore[no-redef]
+def _reify(t, s):
+    return list(reify(iter(t), s))
+
+
+_reify
+
+
+@dispatch(dict, dict)  # type: ignore[no-redef]
+def _reify(d, s):
+    return {k: reify(v, s) for k, v in d.items()}
+
+
+_reify
+
+
+@dispatch(object, dict)  # type: ignore[no-redef]
+def _reify(o, s):
+    return o  # catch all, just return the object
+
+
+def reify(e, s):
+    """Replace variables of expression with substitution
+    >>> # xdoctest: +SKIP
+    >>> x, y = var(), var()
+    >>> e = (1, x, (3, y))
+    >>> s = {x: 2, y: 4}
+    >>> reify(e, s)
+    (1, 2, (3, 4))
+    >>> e = {1: x, 3: (y, 5)}
+    >>> reify(e, s)
+    {1: 2, 3: (4, 5)}
+    """
+    if isvar(e):
+        return reify(s[e], s) if e in s else e
+    return _reify(e, s)
+
+
+###############
+# Unification #
+###############
+
+seq = tuple, list, Iterator
+
+
+@dispatch(seq, seq, dict)  # type: ignore[arg-type]
+def _unify(u, v, s):
+    if len(u) != len(v):
+        return False
+    for uu, vv in zip(u, v):  # avoiding recursion
+        s = unify(uu, vv, s)
+        if s is False:
+            return False
+    return s
+
+
+#
+# @dispatch((set, frozenset), (set, frozenset), dict)
+# def _unify(u, v, s):
+#     i = u & v
+#     u = u - i
+#     v = v - i
+#     return _unify(sorted(u), sorted(v), s)
+#
+#
+# @dispatch(dict, dict, dict)
+# def _unify(u, v, s):
+#     if len(u) != len(v):
+#         return False
+#     for key, uval in iteritems(u):
+#         if key not in v:
+#             return False
+#         s = unify(uval, v[key], s)
+#         if s is False:
+#             return False
+#     return s
+#
+#
+# @dispatch(object, object, dict)
+# def _unify(u, v, s):
+#     return False  # catch all
+
+
+@dispatch(object, object, dict)
+def unify(u, v, s):  # no check at the moment
+    """Find substitution so that u == v while satisfying s
+    >>> x = var("x")
+    >>> unify((1, x), (1, 2), {})
+    {~x: 2}
+    """
+    u = walk(u, s)
+    v = walk(v, s)
+    if u == v:
+        return s
+    if isvar(u):
+        return assoc(s, u, v)
+    if isvar(v):
+        return assoc(s, v, u)
+    return _unify(u, v, s)
+
+
+unify
+
+
+@dispatch(object, object)  # type: ignore[no-redef]
+def unify(u, v):
+    return unify(u, v, {})
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b950c5b36d67f34cca322ffbbf6851b151de36
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py
@@ -0,0 +1,8 @@
+from functools import partial
+
+from .multipledispatch import dispatch as _dispatch  # type: ignore[import]
+
+
+namespace = {}  # type: ignore[var-annotated]
+
+dispatch = partial(_dispatch, namespace=namespace)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py
new file mode 100644
index 0000000000000000000000000000000000000000..01861a086f64b6121aa9e174d16176533cd0e1a5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py
@@ -0,0 +1,129 @@
+# mypy: allow-untyped-defs
+from .core import reify, unify  # type: ignore[attr-defined]
+from .unification_tools import first, groupby  # type: ignore[import]
+from .utils import _toposort, freeze
+from .variable import isvar
+
+
+class Dispatcher:
+    def __init__(self, name):
+        self.name = name
+        self.funcs = {}
+        self.ordering = []
+
+    def add(self, signature, func):
+        self.funcs[freeze(signature)] = func
+        self.ordering = ordering(self.funcs)
+
+    def __call__(self, *args, **kwargs):
+        func, _ = self.resolve(args)
+        return func(*args, **kwargs)
+
+    def resolve(self, args):
+        n = len(args)
+        for signature in self.ordering:
+            if len(signature) != n:
+                continue
+            s = unify(freeze(args), signature)
+            if s is not False:
+                result = self.funcs[signature]
+                return result, s
+        raise NotImplementedError(
+            "No match found. \nKnown matches: "
+            + str(self.ordering)
+            + "\nInput: "
+            + str(args)
+        )
+
+    def register(self, *signature):
+        def _(func):
+            self.add(signature, func)
+            return self
+
+        return _
+
+
+class VarDispatcher(Dispatcher):
+    """A dispatcher that calls functions with variable names
+    >>> # xdoctest: +SKIP
+    >>> d = VarDispatcher("d")
+    >>> x = var("x")
+    >>> @d.register("inc", x)
+    ... def f(x):
+    ...     return x + 1
+    >>> @d.register("double", x)
+    ... def f(x):
+    ...     return x * 2
+    >>> d("inc", 10)
+    11
+    >>> d("double", 10)
+    20
+    """
+
+    def __call__(self, *args, **kwargs):
+        func, s = self.resolve(args)
+        d = {k.token: v for k, v in s.items()}
+        return func(**d)
+
+
+global_namespace = {}  # type: ignore[var-annotated]
+
+
+def match(*signature, **kwargs):
+    namespace = kwargs.get("namespace", global_namespace)
+    dispatcher = kwargs.get("Dispatcher", Dispatcher)
+
+    def _(func):
+        name = func.__name__
+
+        if name not in namespace:
+            namespace[name] = dispatcher(name)
+        d = namespace[name]
+
+        d.add(signature, func)
+
+        return d
+
+    return _
+
+
+def supercedes(a, b):
+    """``a`` is a more specific match than ``b``"""
+    if isvar(b) and not isvar(a):
+        return True
+    s = unify(a, b)
+    if s is False:
+        return False
+    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
+    if reify(a, s) == a:
+        return True
+    if reify(b, s) == b:
+        return False
+
+
+# Taken from multipledispatch
+def edge(a, b, tie_breaker=hash):
+    """A should be checked before B
+    Tie broken by tie_breaker, defaults to ``hash``
+    """
+    if supercedes(a, b):
+        if supercedes(b, a):
+            return tie_breaker(a) > tie_breaker(b)
+        else:
+            return True
+    return False
+
+
+# Taken from multipledispatch
+def ordering(signatures):
+    """A sane ordering of signatures to check, first to last
+    Topological sort of edges as given by ``edge`` and ``supercedes``
+    """
+    signatures = list(map(tuple, signatures))
+    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
+    edges = groupby(first, edges)
+    for s in signatures:
+        if s not in edges:
+            edges[s] = []
+    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[attr-defined, assignment]
+    return _toposort(edges)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py
new file mode 100644
index 0000000000000000000000000000000000000000..42074a46a4202cface9799af4b81743c292e766d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py
@@ -0,0 +1,131 @@
+# mypy: allow-untyped-defs
+from .core import (  # type: ignore[attr-defined]
+    _reify as core_reify,
+    _unify as core_unify,
+    reify,
+    unify,
+)
+from .dispatch import dispatch
+
+
+__all__ = ["unifiable", "reify_object", "unify_object"]
+
+
+def unifiable(cls):
+    """Register standard unify and reify operations on class
+    This uses the type and __dict__ or __slots__ attributes to define the
+    nature of the term
+    See Also:
+    >>> # xdoctest: +SKIP
+    >>> class A(object):
+    ...     def __init__(self, a, b):
+    ...         self.a = a
+    ...         self.b = b
+    >>> unifiable(A)
+    
+    >>> x = var("x")
+    >>> a = A(1, 2)
+    >>> b = A(1, x)
+    >>> unify(a, b, {})
+    {~x: 2}
+    """
+    core_unify.add((cls, cls, dict), unify_object)  # type: ignore[attr-defined]
+    core_reify.add((cls, dict), reify_object)  # type: ignore[attr-defined]
+
+    return cls
+
+
+#########
+# Reify #
+#########
+
+
+def reify_object(o, s):
+    """Reify a Python object with a substitution
+    >>> # xdoctest: +SKIP
+    >>> class Foo(object):
+    ...     def __init__(self, a, b):
+    ...         self.a = a
+    ...         self.b = b
+    ...
+    ...     def __str__(self):
+    ...         return "Foo(%s, %s)" % (str(self.a), str(self.b))
+    >>> x = var("x")
+    >>> f = Foo(1, x)
+    >>> print(f)
+    Foo(1, ~x)
+    >>> print(reify_object(f, {x: 2}))
+    Foo(1, 2)
+    """
+    if hasattr(o, "__slots__"):
+        return _reify_object_slots(o, s)
+    else:
+        return _reify_object_dict(o, s)
+
+
+def _reify_object_dict(o, s):
+    obj = object.__new__(type(o))
+    d = reify(o.__dict__, s)
+    if d == o.__dict__:
+        return o
+    obj.__dict__.update(d)
+    return obj
+
+
+def _reify_object_slots(o, s):
+    attrs = [getattr(o, attr) for attr in o.__slots__]
+    new_attrs = reify(attrs, s)
+    if attrs == new_attrs:
+        return o
+    else:
+        newobj = object.__new__(type(o))
+        for slot, attr in zip(o.__slots__, new_attrs):
+            setattr(newobj, slot, attr)
+        return newobj
+
+
+@dispatch(slice, dict)
+def _reify(o, s):
+    """Reify a Python ``slice`` object"""
+    # pyrefly: ignore [not-iterable]
+    return slice(*reify((o.start, o.stop, o.step), s))
+
+
+#########
+# Unify #
+#########
+
+
+def unify_object(u, v, s):
+    """Unify two Python objects
+    Unifies their type and ``__dict__`` attributes
+    >>> # xdoctest: +SKIP
+    >>> class Foo(object):
+    ...     def __init__(self, a, b):
+    ...         self.a = a
+    ...         self.b = b
+    ...
+    ...     def __str__(self):
+    ...         return "Foo(%s, %s)" % (str(self.a), str(self.b))
+    >>> x = var("x")
+    >>> f = Foo(1, x)
+    >>> g = Foo(1, 2)
+    >>> unify_object(f, g, {})
+    {~x: 2}
+    """
+    if type(u) is not type(v):
+        return False
+    if hasattr(u, "__slots__"):
+        return unify(
+            [getattr(u, slot) for slot in u.__slots__],
+            [getattr(v, slot) for slot in v.__slots__],
+            s,
+        )
+    else:
+        return unify(u.__dict__, v.__dict__, s)
+
+
+@dispatch(slice, slice, dict)
+def _unify(u, v, s):
+    """Unify a Python ``slice`` object"""
+    return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb7304069243fb45604e165b06b377a5db233a7d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
@@ -0,0 +1,7 @@
+from .core import dispatch
+from .dispatcher import (
+    Dispatcher,
+    halt_ordering,
+    MDNotImplementedError,
+    restart_ordering,
+)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6810a52e8e2cda04bcf0c25a7c6608e19ffe6b7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e840ddc33394902f1bdad656e7ce67087c37d83
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e5cc39ac116719f144bdaecf7edaf9295405a0f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e912c9e216afb20d151d0d2f87948647d6583717
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8fb3488b7a2656c712cc05d53365d69de83c8a49
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9962036d8f3bfc9d14d76ea5739eed99e30dab00
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
new file mode 100644
index 0000000000000000000000000000000000000000..181e0e8dd167ac8b15d58f612308cdfeca1547e1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
@@ -0,0 +1,139 @@
+# mypy: allow-untyped-defs
+import operator
+
+from .utils import _toposort, groupby
+from .variadic import isvariadic
+
+
+__all__ = [
+    "AmbiguityWarning",
+    "supercedes",
+    "consistent",
+    "ambiguous",
+    "ambiguities",
+    "super_signature",
+    "edge",
+    "ordering",
+]
+
+
+class AmbiguityWarning(Warning):
+    pass
+
+
+def supercedes(a, b):
+    """A is consistent and strictly more specific than B"""
+    if len(a) < len(b):
+        # only case is if a is empty and b is variadic
+        return not a and len(b) == 1 and isvariadic(b[-1])
+    elif len(a) == len(b):
+        return all(map(issubclass, a, b))
+    else:
+        # len(a) > len(b)
+        p1 = 0
+        p2 = 0
+        while p1 < len(a) and p2 < len(b):
+            cur_a = a[p1]
+            cur_b = b[p2]
+            if not (isvariadic(cur_a) or isvariadic(cur_b)):
+                if not issubclass(cur_a, cur_b):
+                    return False
+                p1 += 1
+                p2 += 1
+            elif isvariadic(cur_a):
+                assert p1 == len(a) - 1
+                return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
+            elif isvariadic(cur_b):
+                assert p2 == len(b) - 1
+                if not issubclass(cur_a, cur_b):
+                    return False
+                p1 += 1
+        return p2 == len(b) - 1 and p1 == len(a)
+
+
+def consistent(a, b):
+    """It is possible for an argument list to satisfy both A and B"""
+
+    # Need to check for empty args
+    if not a:
+        return not b or isvariadic(b[0])
+    if not b:
+        return not a or isvariadic(a[0])
+
+    # Non-empty args check for mutual subclasses
+    if len(a) == len(b):
+        return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b))
+    else:
+        p1 = 0
+        p2 = 0
+        while p1 < len(a) and p2 < len(b):
+            cur_a = a[p1]
+            cur_b = b[p2]
+            if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
+                return False
+            if not (isvariadic(cur_a) or isvariadic(cur_b)):
+                p1 += 1
+                p2 += 1
+            elif isvariadic(cur_a):
+                p2 += 1
+            elif isvariadic(cur_b):
+                p1 += 1
+        # We only need to check for variadic ends
+        # Variadic types are guaranteed to be the last element
+        return (
+            isvariadic(cur_a)  # type: ignore[possibly-undefined]
+            and p2 == len(b)
+            or isvariadic(cur_b)  # type: ignore[possibly-undefined]
+            and p1 == len(a)
+        )
+
+
+def ambiguous(a, b):
+    """A is consistent with B but neither is strictly more specific"""
+    return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
+
+
+def ambiguities(signatures):
+    """All signature pairs such that A is ambiguous with B"""
+    signatures = list(map(tuple, signatures))
+    return {
+        (a, b)
+        for a in signatures
+        for b in signatures
+        if hash(a) < hash(b)
+        and ambiguous(a, b)
+        and not any(supercedes(c, a) and supercedes(c, b) for c in signatures)
+    }
+
+
+def super_signature(signatures):
+    """A signature that would break ambiguities"""
+    n = len(signatures[0])
+    assert all(len(s) == n for s in signatures)
+
+    return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)]
+
+
+def edge(a, b, tie_breaker=hash):
+    """A should be checked before B
+    Tie broken by tie_breaker, defaults to ``hash``
+    """
+    # A either supersedes B and B does not supersede A or if B does then call
+    # tie_breaker
+    return supercedes(a, b) and (
+        not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
+    )
+
+
+def ordering(signatures):
+    """A sane ordering of signatures to check, first to last
+    Topological sort of edges as given by ``edge`` and ``supercedes``
+    """
+    signatures = list(map(tuple, signatures))
+    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
+    edges = groupby(operator.itemgetter(0), edges)
+    for s in signatures:
+        if s not in edges:
+            edges[s] = []
+    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[assignment, attr-defined]
+    return _toposort(edges)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..69b9f3b2b5a2cb8e9df9d502b4254abffff2dd18
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
@@ -0,0 +1,92 @@
+# mypy: allow-untyped-defs
+import inspect
+from collections.abc import Callable
+from typing import Any, TypeVar
+from typing_extensions import TypeVarTuple, Unpack
+
+from .dispatcher import Dispatcher, MethodDispatcher
+
+
+global_namespace = {}  # type: ignore[var-annotated]
+
+__all__ = ["dispatch", "ismethod"]
+
+T = TypeVar("T")
+Ts = TypeVarTuple("Ts")
+
+
+def dispatch(
+    *types: Unpack[Ts], **kwargs: Any
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """Dispatch function on the types of the inputs
+    Supports dispatch on all non-keyword arguments.
+    Collects implementations based on the function name.  Ignores namespaces.
+    If ambiguous type signatures occur a warning is raised when the function is
+    defined suggesting the additional method to break the ambiguity.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> @dispatch(int)
+        ... def f(x):
+        ...     return x + 1
+        >>> @dispatch(float)
+        ... def f(x):
+        ...     return x - 1
+        >>> # xdoctest: +SKIP
+        >>> f(3)
+        4
+        >>> f(3.0)
+        2.0
+        >>> # Specify an isolated namespace with the namespace keyword argument
+        >>> my_namespace = {}
+        >>> @dispatch(int, namespace=my_namespace)
+        ... def foo(x):
+        ...     return x + 1
+        >>> # Dispatch on instance methods within classes
+        >>> class MyClass(object):
+        ...     @dispatch(list)
+        ...     def __init__(self, data):
+        ...         self.data = data
+        ...
+        ...     @dispatch(int)
+        ...     def __init__(self, datum):
+        ...         self.data = [datum]
+        >>> MyClass([1, 2, 3]).data
+        [1, 2, 3]
+        >>> MyClass(3).data
+        [3]
+    """
+    namespace = kwargs.get("namespace", global_namespace)
+
+    types_tuple: tuple[type, ...] = tuple(types)  # type: ignore[arg-type]
+
+    def _df(func):
+        name = func.__name__
+
+        if ismethod(func):
+            dispatcher = inspect.currentframe().f_back.f_locals.get(  # type: ignore[union-attr]
+                name,  # type: ignore[union-attr]
+                MethodDispatcher(name),
+            )
+        else:
+            if name not in namespace:
+                namespace[name] = Dispatcher(name)
+            dispatcher = namespace[name]
+
+        dispatcher.add(types_tuple, func)
+        return dispatcher
+
+    return _df
+
+
+def ismethod(func):
+    """Is func a method?
+    Note that this has to work as the method is defined but before the class is
+    defined.  At this stage methods look like functions.
+    """
+    if hasattr(inspect, "signature"):
+        signature = inspect.signature(func)
+        return signature.parameters.get("self", None) is not None
+    else:
+        spec = inspect.getfullargspec(func)  # type: ignore[union-attr, assignment]
+        return spec and spec.args and spec.args[0] == "self"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2459b82247bce59cd13ed040722c7278bf36ea0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
@@ -0,0 +1,455 @@
+# mypy: allow-untyped-defs
+import inspect
+import itertools as itl
+from typing_extensions import deprecated
+from warnings import warn
+
+from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature
+from .utils import expand_tuples
+from .variadic import isvariadic, Variadic
+
+
+__all__ = [
+    "MDNotImplementedError",
+    "ambiguity_warn",
+    "halt_ordering",
+    "restart_ordering",
+    "variadic_signature_matches_iter",
+    "variadic_signature_matches",
+    "Dispatcher",
+    "source",
+    "MethodDispatcher",
+    "str_signature",
+    "warning_text",
+]
+
+
+class MDNotImplementedError(NotImplementedError):
+    """A NotImplementedError for multiple dispatch"""
+
+
+def ambiguity_warn(dispatcher, ambiguities):
+    """Raise warning when ambiguity is detected
+    Parameters
+    ----------
+    dispatcher : Dispatcher
+        The dispatcher on which the ambiguity was detected
+    ambiguities : set
+        Set of type signature pairs that are ambiguous within this dispatcher
+    See Also:
+        Dispatcher.add
+        warning_text
+    """
+    warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
+
+
+@deprecated(
+    "`halt_ordering` is deprecated, you can safely remove this call.",
+    category=FutureWarning,
+)
+def halt_ordering():
+    """Deprecated interface to temporarily disable ordering."""
+
+
+@deprecated(
+    "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
+    "you should call the `reorder()` method on each dispatcher.",
+    category=FutureWarning,
+)
+def restart_ordering(on_ambiguity=ambiguity_warn):
+    """Deprecated interface to temporarily resume ordering."""
+
+
+def variadic_signature_matches_iter(types, full_signature):
+    """Check if a set of input types matches a variadic signature.
+    Notes
+    -----
+    The algorithm is as follows:
+    Initialize the current signature to the first in the sequence
+    For each type in `types`:
+        If the current signature is variadic
+            If the type matches the signature
+                yield True
+            Else
+                Try to get the next signature
+                If no signatures are left we can't possibly have a match
+                    so yield False
+        Else
+            yield True if the type matches the current signature
+            Get the next signature
+    """
+    sigiter = iter(full_signature)
+    sig = next(sigiter)
+    for typ in types:
+        matches = issubclass(typ, sig)
+        yield matches
+        if not isvariadic(sig):
+            # we're not matching a variadic argument, so move to the next
+            # element in the signature
+            sig = next(sigiter)
+    else:
+        try:
+            sig = next(sigiter)
+        except StopIteration:
+            assert isvariadic(sig)
+            yield True
+        else:
+            # We have signature items left over, so all of our arguments
+            # haven't matched
+            yield False
+
+
+def variadic_signature_matches(types, full_signature):
+    # No arguments always matches a variadic signature
+    assert full_signature
+    return all(variadic_signature_matches_iter(types, full_signature))
+
+
+class Dispatcher:
+    """Dispatch methods based on type signature
+    Use ``dispatch`` to add implementations
+    Examples
+    --------
+    >>> # xdoctest: +SKIP("bad import name")
+    >>> from multipledispatch import dispatch
+    >>> @dispatch(int)
+    ... def f(x):
+    ...     return x + 1
+    >>> @dispatch(float)
+    ... def f(x):
+    ...     return x - 1
+    >>> f(3)
+    4
+    >>> f(3.0)
+    2.0
+    """
+
+    __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc"
+
+    def __init__(self, name, doc=None):
+        self.name = self.__name__ = name
+        self.funcs = {}
+        self.doc = doc
+
+        self._cache = {}
+
+    def register(self, *types, **kwargs):
+        """register dispatcher with new implementation
+        >>> # xdoctest: +SKIP
+        >>> f = Dispatcher("f")
+        >>> @f.register(int)
+        ... def inc(x):
+        ...     return x + 1
+        >>> @f.register(float)
+        ... def dec(x):
+        ...     return x - 1
+        >>> @f.register(list)
+        ... @f.register(tuple)
+        ... def reverse(x):
+        ...     return x[::-1]
+        >>> f(1)
+        2
+        >>> f(1.0)
+        0.0
+        >>> f([1, 2, 3])
+        [3, 2, 1]
+        """
+
+        def _df(func):
+            self.add(types, func, **kwargs)  # type: ignore[call-arg]
+            return func
+
+        return _df
+
+    @classmethod
+    def get_func_params(cls, func):
+        if hasattr(inspect, "signature"):
+            sig = inspect.signature(func)
+            return sig.parameters.values()
+
+    @classmethod
+    def get_func_annotations(cls, func):
+        """get annotations of function positional parameters"""
+        params = cls.get_func_params(func)
+        if params:
+            Parameter = inspect.Parameter
+
+            params = (
+                param
+                for param in params
+                if param.kind
+                in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
+            )
+
+            annotations = tuple(param.annotation for param in params)
+
+            if all(ann is not Parameter.empty for ann in annotations):
+                return annotations
+
+    def add(self, signature, func):
+        """Add new types/method pair to dispatcher
+        >>> # xdoctest: +SKIP
+        >>> D = Dispatcher("add")
+        >>> D.add((int, int), lambda x, y: x + y)
+        >>> D.add((float, float), lambda x, y: x + y)
+        >>> D(1, 2)
+        3
+        >>> D(1, 2.0)
+        Traceback (most recent call last):
+        ...
+        NotImplementedError: Could not find signature for add: 
+        >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback
+        >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
+        >>> # as inputs.  See ``ambiguity_warn`` for an example.
+        """
+        # Handle annotations
+        if not signature:
+            annotations = self.get_func_annotations(func)
+            if annotations:
+                signature = annotations
+
+        # Handle union types
+        if any(isinstance(typ, tuple) for typ in signature):
+            for typs in expand_tuples(signature):
+                self.add(typs, func)
+            return
+
+        new_signature = []
+
+        for index, typ in enumerate(signature, start=1):
+            if not isinstance(typ, (type, list)):
+                str_sig = ", ".join(
+                    c.__name__ if isinstance(c, type) else str(c) for c in signature
+                )
+                raise TypeError(
+                    f"Tried to dispatch on non-type: {typ}\n"
+                    f"In signature: <{str_sig}>\n"
+                    f"In function: {self.name}"
+                )
+
+            # handle variadic signatures
+            if isinstance(typ, list):
+                if index != len(signature):
+                    raise TypeError("Variadic signature must be the last element")
+
+                if len(typ) != 1:
+                    raise TypeError(
+                        "Variadic signature must contain exactly one element. "
+                        "To use a variadic union type place the desired types "
+                        "inside of a tuple, e.g., [(int, str)]"
+                    )
+                # pyrefly: ignore [bad-specialization]
+                new_signature.append(Variadic[typ[0]])
+            else:
+                new_signature.append(typ)
+
+        self.funcs[tuple(new_signature)] = func
+        self._cache.clear()
+
+        try:
+            del self._ordering
+        except AttributeError:
+            pass
+
+    @property
+    def ordering(self):
+        try:
+            return self._ordering
+        except AttributeError:
+            return self.reorder()
+
+    def reorder(self, on_ambiguity=ambiguity_warn):
+        self._ordering = od = ordering(self.funcs)
+        amb = ambiguities(self.funcs)
+        if amb:
+            on_ambiguity(self, amb)
+        return od
+
+    def __call__(self, *args, **kwargs):
+        types = tuple(type(arg) for arg in args)
+        try:
+            func = self._cache[types]
+        except KeyError as e:
+            func = self.dispatch(*types)
+            if not func:
+                raise NotImplementedError(
+                    f"Could not find signature for {self.name}: <{str_signature(types)}>"
+                ) from e
+            self._cache[types] = func
+        try:
+            return func(*args, **kwargs)
+
+        except MDNotImplementedError as e:
+            funcs = self.dispatch_iter(*types)
+            next(funcs)  # burn first
+            for func in funcs:
+                try:
+                    return func(*args, **kwargs)
+                except MDNotImplementedError:
+                    pass
+
+            raise NotImplementedError(
+                "Matching functions for "
+                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",
+            ) from e
+
+    def __str__(self):
+        return f""
+
+    __repr__ = __str__
+
+    def dispatch(self, *types):
+        """Determine appropriate implementation for this type signature
+        This method is internal.  Users should call this object as a function.
+        Implementation resolution occurs within the ``__call__`` method.
+        >>> # xdoctest: +SKIP
+        >>> from multipledispatch import dispatch
+        >>> @dispatch(int)
+        ... def inc(x):
+        ...     return x + 1
+        >>> implementation = inc.dispatch(int)
+        >>> implementation(3)
+        4
+        >>> print(inc.dispatch(float))
+        None
+        See Also:
+          ``multipledispatch.conflict`` - module to determine resolution order
+        """
+
+        if types in self.funcs:
+            return self.funcs[types]
+
+        try:
+            return next(self.dispatch_iter(*types))
+        except StopIteration:
+            return None
+
+    def dispatch_iter(self, *types):
+        n = len(types)
+        for signature in self.ordering:
+            if len(signature) == n and all(map(issubclass, types, signature)):
+                result = self.funcs[signature]
+                yield result
+            elif len(signature) and isvariadic(signature[-1]):
+                if variadic_signature_matches(types, signature):
+                    result = self.funcs[signature]
+                    yield result
+
+    @deprecated(
+        "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning
+    )
+    def resolve(self, types):
+        """Determine appropriate implementation for this type signature
+        .. deprecated:: 0.4.4
+            Use ``dispatch(*types)`` instead
+        """
+        return self.dispatch(*types)
+
+    def __getstate__(self):
+        return {"name": self.name, "funcs": self.funcs}
+
+    def __setstate__(self, d):
+        self.name = d["name"]
+        self.funcs = d["funcs"]
+        self._ordering = ordering(self.funcs)
+        self._cache = {}
+
+    @property
+    def __doc__(self):  # type: ignore[override]
+        docs = [f"Multiply dispatched method: {self.name}"]
+
+        if self.doc:
+            docs.append(self.doc)
+
+        other = []
+        for sig in self.ordering[::-1]:
+            func = self.funcs[sig]
+            if func.__doc__:
+                s = f"Inputs: <{str_signature(sig)}>\n"
+                s += "-" * len(s) + "\n"
+                s += func.__doc__.strip()
+                docs.append(s)
+            else:
+                other.append(str_signature(sig))
+
+        if other:
+            docs.append("Other signatures:\n    " + "\n    ".join(other))
+
+        return "\n\n".join(docs)
+
+    def _help(self, *args):
+        return self.dispatch(*map(type, args)).__doc__
+
+    def help(self, *args, **kwargs):
+        """Print docstring for the function corresponding to inputs"""
+        print(self._help(*args))
+
+    def _source(self, *args):
+        func = self.dispatch(*map(type, args))
+        if not func:
+            raise TypeError("No function found")
+        return source(func)
+
+    def source(self, *args, **kwargs):
+        """Print source code for the function corresponding to inputs"""
+        print(self._source(*args))
+
+
+def source(func):
+    s = f"File: {inspect.getsourcefile(func)}\n\n"
+    s = s + inspect.getsource(func)
+    return s
+
+
+class MethodDispatcher(Dispatcher):
+    """Dispatch methods based on type signature
+    See Also:
+        Dispatcher
+    """
+
+    # pyrefly: ignore [bad-override]
+    __slots__ = ("obj", "cls")
+
+    @classmethod
+    def get_func_params(cls, func):
+        if hasattr(inspect, "signature"):
+            sig = inspect.signature(func)
+            return itl.islice(sig.parameters.values(), 1, None)
+
+    def __get__(self, instance, owner):
+        self.obj = instance
+        self.cls = owner
+        return self
+
+    def __call__(self, *args, **kwargs):
+        types = tuple(type(arg) for arg in args)
+        func = self.dispatch(*types)
+        if not func:
+            raise NotImplementedError(
+                f"Could not find signature for {self.name}: <{str_signature(types)}>"
+            )
+        return func(self.obj, *args, **kwargs)
+
+
+def str_signature(sig):
+    """String representation of type signature
+    >>> str_signature((int, float))
+    'int, float'
+    """
+    return ", ".join(cls.__name__ for cls in sig)
+
+
+def warning_text(name, amb):
+    """The text for ambiguity warnings"""
+    text = f"\nAmbiguities exist in dispatched function {name}\n\n"
+    text += "The following signatures may result in ambiguous behavior:\n"
+    for pair in amb:
+        text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n"
+    text += "\n\nConsider making the following additions:\n\n"
+    text += "\n\n".join(
+        [
+            "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)"
+            for s in amb
+        ]
+    )
+    return text
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b21183c40b97a0757fc5c332cb783f39fc85efe
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
@@ -0,0 +1,127 @@
+# mypy: allow-untyped-defs
+from collections import OrderedDict
+
+
+__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
+
+
+def raises(err, lamda):  # codespell:ignore lamda
+    try:
+        lamda()  # codespell:ignore lamda
+        return False
+    except err:
+        return True
+
+
+def expand_tuples(L):
+    """
+    >>> expand_tuples([1, (2, 3)])
+    [(1, 2), (1, 3)]
+    >>> expand_tuples([1, 2])
+    [(1, 2)]
+    """
+    if not L:
+        return [()]
+    elif not isinstance(L[0], tuple):
+        rest = expand_tuples(L[1:])
+        return [(L[0],) + t for t in rest]
+    else:
+        rest = expand_tuples(L[1:])
+        return [(item,) + t for t in rest for item in L[0]]
+
+
+# Taken from theano/theano/gof/sched.py
+# Avoids licensing issues because this was written by Matthew Rocklin
+def _toposort(edges):
+    """Topological sort algorithm by Kahn [1] - O(nodes + vertices)
+    inputs:
+        edges - a dict of the form {a: {b, c}} where b and c depend on a
+    outputs:
+        L - an ordered list of nodes that satisfy the dependencies of edges
+    >>> _toposort({1: (2, 3), 2: (3,)})
+    [1, 2, 3]
+    >>> # Closely follows the wikipedia page [2]
+    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
+    >>> # Communications of the ACM
+    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
+    """
+    incoming_edges = reverse_dict(edges)
+    incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
+    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
+    L = []
+
+    while S:
+        n, _ = S.popitem()
+        L.append(n)
+        for m in edges.get(n, ()):
+            assert n in incoming_edges[m]
+            incoming_edges[m].remove(n)
+            if not incoming_edges[m]:
+                S[m] = None
+    if any(incoming_edges.get(v, None) for v in edges):
+        raise ValueError("Input has cycles")
+    return L
+
+
+def reverse_dict(d):
+    """Reverses direction of dependence dict
+    >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
+    >>> reverse_dict(d)  # doctest: +SKIP
+    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
+    :note: dict order are not deterministic. As we iterate on the
+        input dict, it make the output of this function depend on the
+        dict order. So this function output order should be considered
+        as undeterministic.
+    """
+    result = OrderedDict()  # type: ignore[var-annotated]
+    for key in d:
+        for val in d[key]:
+            result[val] = result.get(val, ()) + (key,)
+    return result
+
+
+# Taken from toolz
+# Avoids licensing issues because this version was authored by Matthew Rocklin
+def groupby(func, seq):
+    """Group a collection by a key function
+    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
+    >>> groupby(len, names)  # doctest: +SKIP
+    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
+    >>> iseven = lambda x: x % 2 == 0
+    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
+    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
+    See Also:
+        ``countby``
+    """
+
+    d = OrderedDict()  # type: ignore[var-annotated]
+    for item in seq:
+        key = func(item)
+        if key not in d:
+            d[key] = []
+        d[key].append(item)
+    return d
+
+
+def typename(type):
+    """Get the name of `type`.
+    Parameters
+    ----------
+    type : Union[Type, Tuple[Type]]
+    Returns
+    -------
+    str
+        The name of `type` or a tuple of the names of the types in `type`.
+    Examples
+    --------
+    >>> typename(int)
+    'int'
+    >>> typename((int, float))
+    '(int, float)'
+    """
+    try:
+        return type.__name__
+    except AttributeError:
+        if len(type) == 1:
+            return typename(*type)
+        return f"({', '.join(map(typename, type))})"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b5604a152480f83916108cb1b02de3bc9b9adb5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
@@ -0,0 +1,96 @@
+# mypy: allow-untyped-defs
+from .utils import typename
+
+
+__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
+
+
+class VariadicSignatureType(type):
+    # checking if subclass is a subclass of self
+    def __subclasscheck__(cls, subclass):
+        other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,)
+        return subclass is cls or all(
+            issubclass(other, cls.variadic_type)  # type: ignore[attr-defined]
+            for other in other_type
+        )
+
+    def __eq__(cls, other):
+        """
+        Return True if other has the same variadic type
+        Parameters
+        ----------
+        other : object (type)
+            The object (type) to check
+        Returns
+        -------
+        bool
+            Whether or not `other` is equal to `self`
+        """
+        return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type)  # type: ignore[attr-defined]
+
+    def __hash__(cls):
+        return hash((type(cls), frozenset(cls.variadic_type)))  # type: ignore[attr-defined]
+
+
+def isvariadic(obj):
+    """Check whether the type `obj` is variadic.
+    Parameters
+    ----------
+    obj : type
+        The type to check
+    Returns
+    -------
+    bool
+        Whether or not `obj` is variadic
+    Examples
+    --------
+    >>> # xdoctest: +SKIP
+    >>> isvariadic(int)
+    False
+    >>> isvariadic(Variadic[int])
+    True
+    """
+    return isinstance(obj, VariadicSignatureType)
+
+
+class VariadicSignatureMeta(type):
+    """A metaclass that overrides ``__getitem__`` on the class. This is used to
+    generate a new type for Variadic signatures. See the Variadic class for
+    examples of how this behaves.
+    """
+
+    def __getitem__(cls, variadic_type):
+        if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
+            raise ValueError(
+                "Variadic types must be type or tuple of types"
+                " (Variadic[int] or Variadic[(int, float)]"
+            )
+
+        if not isinstance(variadic_type, tuple):
+            variadic_type = (variadic_type,)
+        return VariadicSignatureType(
+            f"Variadic[{typename(variadic_type)}]",
+            (),
+            dict(variadic_type=variadic_type, __slots__=()),
+        )
+
+
+class Variadic(metaclass=VariadicSignatureMeta):
+    """A class whose getitem method can be used to generate a new type
+    representing a specific variadic signature.
+    Examples
+    --------
+    >>> # xdoctest: +SKIP
+    >>> Variadic[int]  # any number of int arguments
+    
+    >>> Variadic[(int, str)]  # any number of one of int or str arguments
+    
+    >>> issubclass(int, Variadic[int])
+    True
+    >>> issubclass(int, Variadic[(int, str)])
+    True
+    >>> issubclass(str, Variadic[(int, str)])
+    True
+    >>> issubclass(float, Variadic[(int, str)])
+    False
+    """
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b4216a79ad0351cc6fedba64c06810fdc894426
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py
@@ -0,0 +1,420 @@
+# mypy: allow-untyped-defs
+import collections
+import operator
+from collections.abc import Mapping
+from functools import reduce
+
+
+__all__ = [
+    "merge",
+    "merge_with",
+    "valmap",
+    "keymap",
+    "itemmap",
+    "valfilter",
+    "keyfilter",
+    "itemfilter",
+    "assoc",
+    "dissoc",
+    "assoc_in",
+    "update_in",
+    "get_in",
+]
+
+
+def _get_factory(f, kwargs):
+    factory = kwargs.pop("factory", dict)
+    if kwargs:
+        raise TypeError(
+            f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'"
+        )
+    return factory
+
+
+def merge(*dicts, **kwargs):
+    """Merge a collection of dictionaries
+
+    >>> merge({1: "one"}, {2: "two"})
+    {1: 'one', 2: 'two'}
+
+    Later dictionaries have precedence
+
+    >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
+    {1: 2, 3: 3, 4: 4}
+
+    See Also:
+        merge_with
+    """
+    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
+        dicts = dicts[0]
+    factory = _get_factory(merge, kwargs)
+
+    rv = factory()
+    for d in dicts:
+        rv.update(d)
+    return rv
+
+
+def merge_with(func, *dicts, **kwargs):
+    """Merge dictionaries and apply function to combined values
+
+    A key may occur in more than one dict, and all values mapped from the key
+    will be passed to the function as a list, such as func([val1, val2, ...]).
+
+    >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
+    {1: 11, 2: 22}
+
+    >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30})  # doctest: +SKIP
+    {1: 1, 2: 2, 3: 30}
+
+    See Also:
+        merge
+    """
+    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
+        dicts = dicts[0]
+    factory = _get_factory(merge_with, kwargs)
+
+    result = factory()
+    for d in dicts:
+        for k, v in d.items():
+            if k not in result:
+                result[k] = [v]
+            else:
+                result[k].append(v)
+    return valmap(func, result, factory)
+
+
+def valmap(func, d, factory=dict):
+    """Apply function to values of dictionary
+
+    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
+    >>> valmap(sum, bills)  # doctest: +SKIP
+    {'Alice': 65, 'Bob': 45}
+
+    See Also:
+        keymap
+        itemmap
+    """
+    rv = factory()
+    rv.update(zip(d.keys(), map(func, d.values())))
+    return rv
+
+
+def keymap(func, d, factory=dict):
+    """Apply function to keys of dictionary
+
+    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
+    >>> keymap(str.lower, bills)  # doctest: +SKIP
+    {'alice': [20, 15, 30], 'bob': [10, 35]}
+
+    See Also:
+        valmap
+        itemmap
+    """
+    rv = factory()
+    rv.update(zip(map(func, d.keys()), d.values()))
+    return rv
+
+
+def itemmap(func, d, factory=dict):
+    """Apply function to items of dictionary
+
+    >>> accountids = {"Alice": 10, "Bob": 20}
+    >>> itemmap(reversed, accountids)  # doctest: +SKIP
+    {10: "Alice", 20: "Bob"}
+
+    See Also:
+        keymap
+        valmap
+    """
+    rv = factory()
+    rv.update(map(func, d.items()))
+    return rv
+
+
+def valfilter(predicate, d, factory=dict):
+    """Filter items in dictionary by value
+
+    >>> iseven = lambda x: x % 2 == 0
+    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
+    >>> valfilter(iseven, d)
+    {1: 2, 3: 4}
+
+    See Also:
+        keyfilter
+        itemfilter
+        valmap
+    """
+    rv = factory()
+    for k, v in d.items():
+        if predicate(v):
+            rv[k] = v
+    return rv
+
+
+def keyfilter(predicate, d, factory=dict):
+    """Filter items in dictionary by key
+
+    >>> iseven = lambda x: x % 2 == 0
+    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
+    >>> keyfilter(iseven, d)
+    {2: 3, 4: 5}
+
+    See Also:
+        valfilter
+        itemfilter
+        keymap
+    """
+    rv = factory()
+    for k, v in d.items():
+        if predicate(k):
+            rv[k] = v
+    return rv
+
+
+def itemfilter(predicate, d, factory=dict):
+    """Filter items in dictionary by item
+
+    >>> def isvalid(item):
+    ...     k, v = item
+    ...     return k % 2 == 0 and v < 4
+
+    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
+    >>> itemfilter(isvalid, d)
+    {2: 3}
+
+    See Also:
+        keyfilter
+        valfilter
+        itemmap
+    """
+    rv = factory()
+    for item in d.items():
+        if predicate(item):
+            k, v = item
+            rv[k] = v
+    return rv
+
+
+def assoc(d, key, value, factory=dict):
+    """Return a new dict with new key value pair
+
+    New dict has d[key] set to value. Does not modify the initial dictionary.
+
+    >>> assoc({"x": 1}, "x", 2)
+    {'x': 2}
+    >>> assoc({"x": 1}, "y", 3)  # doctest: +SKIP
+    {'x': 1, 'y': 3}
+    """
+    d2 = factory()
+    d2.update(d)
+    d2[key] = value
+    return d2
+
+
+def dissoc(d, *keys, **kwargs):
+    """Return a new dict with the given key(s) removed.
+
+    New dict has d[key] deleted for each supplied key.
+    Does not modify the initial dictionary.
+
+    >>> dissoc({"x": 1, "y": 2}, "y")
+    {'x': 1}
+    >>> dissoc({"x": 1, "y": 2}, "y", "x")
+    {}
+    >>> dissoc({"x": 1}, "y")  # Ignores missing keys
+    {'x': 1}
+    """
+    factory = _get_factory(dissoc, kwargs)
+    d2 = factory()
+
+    if len(keys) < len(d) * 0.6:
+        d2.update(d)
+        for key in keys:
+            if key in d2:
+                del d2[key]
+    else:
+        remaining = set(d)
+        remaining.difference_update(keys)
+        for k in remaining:
+            d2[k] = d[k]
+    return d2
+
+
+def assoc_in(d, keys, value, factory=dict):
+    """Return a new dict with new, potentially nested, key value pair
+
+    >>> purchase = {
+    ...     "name": "Alice",
+    ...     "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
+    ...     "credit card": "5555-1234-1234-1234",
+    ... }
+    >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00])  # doctest: +SKIP
+    {'credit card': '5555-1234-1234-1234',
+     'name': 'Alice',
+     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
+    """
+    return update_in(d, keys, lambda x: value, value, factory)
+
+
+def update_in(d, keys, func, default=None, factory=dict):
+    """Update value in a (potentially) nested dictionary
+
+    inputs:
+    d - dictionary on which to operate
+    keys - list or tuple giving the location of the value to be changed in d
+    func - function to operate on that value
+
+    If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
+    original dictionary with v replaced by func(v), but does not mutate the
+    original dictionary.
+
+    If k0 is not a key in d, update_in creates nested dictionaries to the depth
+    specified by the keys, with the innermost value set to func(default).
+
+    >>> inc = lambda x: x + 1
+    >>> update_in({"a": 0}, ["a"], inc)
+    {'a': 1}
+
+    >>> transaction = {
+    ...     "name": "Alice",
+    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
+    ...     "credit card": "5555-1234-1234-1234",
+    ... }
+    >>> update_in(transaction, ["purchase", "costs"], sum)  # doctest: +SKIP
+    {'credit card': '5555-1234-1234-1234',
+     'name': 'Alice',
+     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
+
+    >>> # updating a value when k0 is not in d
+    >>> update_in({}, [1, 2, 3], str, default="bar")
+    {1: {2: {3: 'bar'}}}
+    >>> update_in({1: "foo"}, [2, 3, 4], inc, 0)
+    {1: 'foo', 2: {3: {4: 1}}}
+    """
+    ks = iter(keys)
+    k = next(ks)
+
+    rv = inner = factory()
+    rv.update(d)
+
+    # pyrefly: ignore [not-iterable]
+    for key in ks:
+        if k in d:
+            d = d[k]
+            dtemp = factory()
+            dtemp.update(d)
+        else:
+            d = dtemp = factory()
+
+        inner[k] = inner = dtemp
+        k = key
+
+    if k in d:
+        inner[k] = func(d[k])
+    else:
+        inner[k] = func(default)
+    return rv
+
+
+def get_in(keys, coll, default=None, no_default=False):
+    """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
+
+    If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
+    ``no_default`` is specified, then it raises KeyError or IndexError.
+
+    ``get_in`` is a generalization of ``operator.getitem`` for nested data
+    structures such as dictionaries and lists.
+
+    >>> transaction = {
+    ...     "name": "Alice",
+    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
+    ...     "credit card": "5555-1234-1234-1234",
+    ... }
+    >>> get_in(["purchase", "items", 0], transaction)
+    'Apple'
+    >>> get_in(["name"], transaction)
+    'Alice'
+    >>> get_in(["purchase", "total"], transaction)
+    >>> get_in(["purchase", "items", "apple"], transaction)
+    >>> get_in(["purchase", "items", 10], transaction)
+    >>> get_in(["purchase", "total"], transaction, 0)
+    0
+    >>> get_in(["y"], {}, no_default=True)
+    Traceback (most recent call last):
+        ...
+    KeyError: 'y'
+
+    See Also:
+        itertoolz.get
+        operator.getitem
+    """
+    try:
+        return reduce(operator.getitem, keys, coll)
+    except (KeyError, IndexError, TypeError):
+        if no_default:
+            raise
+        return default
+
+
+def getter(index):
+    if isinstance(index, list):
+        if len(index) == 1:
+            index = index[0]
+            return lambda x: (x[index],)
+        elif index:
+            return operator.itemgetter(*index)
+        else:
+            return lambda x: ()
+    else:
+        return operator.itemgetter(index)
+
+
+def groupby(key, seq):
+    """Group a collection by a key function
+
+    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
+    >>> groupby(len, names)  # doctest: +SKIP
+    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
+
+    >>> iseven = lambda x: x % 2 == 0
+    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
+    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
+
+    Non-callable keys imply grouping on a member.
+
+    >>> groupby(
+    ...     "gender",
+    ...     [
+    ...         {"name": "Alice", "gender": "F"},
+    ...         {"name": "Bob", "gender": "M"},
+    ...         {"name": "Charlie", "gender": "M"},
+    ...     ],
+    ... )  # doctest:+SKIP
+    {'F': [{'gender': 'F', 'name': 'Alice'}],
+     'M': [{'gender': 'M', 'name': 'Bob'},
+           {'gender': 'M', 'name': 'Charlie'}]}
+
+    Not to be confused with ``itertools.groupby``
+
+    See Also:
+        countby
+    """
+    if not callable(key):
+        key = getter(key)
+    d = collections.defaultdict(lambda: [].append)  # type: ignore[var-annotated]
+    for item in seq:
+        d[key(item)](item)
+    rv = {}
+    for k, v in d.items():
+        rv[k] = v.__self__  # type: ignore[var-annotated, attr-defined]
+    return rv
+
+
+def first(seq):
+    """The first element in a sequence
+
+    >>> first("ABC")
+    'A'
+    """
+    return next(iter(seq))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab99ad1b4f0d495067cb33b8464c7c80777f7d8d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py
@@ -0,0 +1,108 @@
+# mypy: allow-untyped-defs
+__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
+
+
+def hashable(x):
+    try:
+        hash(x)
+        return True
+    except TypeError:
+        return False
+
+
+def transitive_get(key, d):
+    """Transitive dict.get
+    >>> d = {1: 2, 2: 3, 3: 4}
+    >>> d.get(1)
+    2
+    >>> transitive_get(1, d)
+    4
+    """
+    while hashable(key) and key in d:
+        key = d[key]
+    return key
+
+
+def raises(err, lamda):  # codespell:ignore lamda
+    try:
+        lamda()  # codespell:ignore lamda
+        return False
+    except err:
+        return True
+
+
+# Taken from theano/theano/gof/sched.py
+# Avoids licensing issues because this was written by Matthew Rocklin
+def _toposort(edges):
+    """Topological sort algorithm by Kahn [1] - O(nodes + vertices)
+    inputs:
+        edges - a dict of the form {a: {b, c}} where b and c depend on a
+    outputs:
+        L - an ordered list of nodes that satisfy the dependencies of edges
+    >>> # xdoctest: +SKIP
+    >>> _toposort({1: (2, 3), 2: (3,)})
+    [1, 2, 3]
+    Closely follows the wikipedia page [2]
+    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
+    Communications of the ACM
+    [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
+    """
+    incoming_edges = reverse_dict(edges)
+    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
+    S = {v for v in edges if v not in incoming_edges}
+    L = []
+
+    while S:
+        n = S.pop()
+        L.append(n)
+        for m in edges.get(n, ()):
+            assert n in incoming_edges[m]
+            incoming_edges[m].remove(n)
+            if not incoming_edges[m]:
+                S.add(m)
+    if any(incoming_edges.get(v) for v in edges):
+        raise ValueError("Input has cycles")
+    return L
+
+
+def reverse_dict(d):
+    """Reverses direction of dependence dict
+    >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
+    >>> reverse_dict(d)  # doctest: +SKIP
+    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
+    :note: dict order are not deterministic. As we iterate on the
+        input dict, it make the output of this function depend on the
+        dict order. So this function output order should be considered
+        as undeterministic.
+    """
+    result = {}  # type: ignore[var-annotated]
+    for key in d:
+        for val in d[key]:
+            result[val] = result.get(val, ()) + (key,)
+    return result
+
+
+def xfail(func):
+    try:
+        func()
+        raise Exception("XFailed test passed")  # pragma:nocover  # noqa: TRY002
+    except Exception:
+        pass
+
+
+def freeze(d):
+    """Freeze container to hashable form
+    >>> freeze(1)
+    1
+    >>> freeze([1, 2])
+    (1, 2)
+    >>> freeze({1: 2})  # doctest: +SKIP
+    frozenset([(1, 2)])
+    """
+    if isinstance(d, dict):
+        return frozenset(map(freeze, d.items()))
+    if isinstance(d, set):
+        return frozenset(map(freeze, d))
+    if isinstance(d, (tuple, list)):
+        return tuple(map(freeze, d))
+    return d
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b5b51aaf99a5dc9864f5aa22fa9c50571f95797
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py
@@ -0,0 +1,90 @@
+# mypy: allow-untyped-defs
+from contextlib import contextmanager
+
+from .dispatch import dispatch
+from .utils import hashable
+
+
+_global_logic_variables = set()  # type: ignore[var-annotated]
+_glv = _global_logic_variables
+
+
+class Var:
+    """Logic Variable"""
+
+    _id = 1
+
+    def __new__(cls, *token):
+        if len(token) == 0:
+            token = f"_{Var._id}"  # type: ignore[assignment]
+            Var._id += 1
+        elif len(token) == 1:
+            token = token[0]
+
+        obj = object.__new__(cls)
+        obj.token = token  # type: ignore[attr-defined]
+        return obj
+
+    def __str__(self):
+        return "~" + str(self.token)  # type: ignore[attr-defined]
+
+    __repr__ = __str__
+
+    def __eq__(self, other):
+        return type(self) is type(other) and self.token == other.token  # type: ignore[attr-defined]
+
+    def __hash__(self):
+        return hash((type(self), self.token))  # type: ignore[attr-defined]
+
+
+def var():
+    return lambda *args: Var(*args)
+
+
+def vars():
+    return lambda n: [var() for i in range(n)]
+
+
+@dispatch(Var)
+def isvar(v):
+    return True
+
+
+isvar
+
+
+@dispatch(object)  # type: ignore[no-redef]
+def isvar(o):
+    return _glv and hashable(o) and o in _glv
+
+
+@contextmanager
+def variables(*variables):
+    """
+    Context manager for logic variables
+
+    Example:
+        >>> # xdoctest: +SKIP("undefined vars")
+        >>> from __future__ import with_statement
+        >>> with variables(1):
+        ...     print(isvar(1))
+        True
+        >>> print(isvar(1))
+        False
+        >>> # Normal approach
+        >>> from unification import unify
+        >>> x = var("x")
+        >>> unify(x, 1)
+        {~x: 1}
+        >>> # Context Manager approach
+        >>> with variables("x"):
+        ...     print(unify("x", 1))
+        {'x': 1}
+    """
+    old_global_logic_variables = _global_logic_variables.copy()
+    _global_logic_variables.update(set(variables))
+    try:
+        yield
+    finally:
+        _global_logic_variables.clear()
+        _global_logic_variables.update(old_global_logic_variables)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py
new file mode 100644
index 0000000000000000000000000000000000000000..efafb146179a6c35e0a5ccb9a29893aa3a379a87
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py
@@ -0,0 +1,124 @@
+# mypy: allow-untyped-defs
+from torch.fx.experimental.graph_gradual_typechecker import Refine
+from torch.fx.experimental.unification import unify, Var  # type: ignore[attr-defined]
+from torch.fx.tensor_type import TensorType
+
+
+def infer_symbolic_types_single_pass(traced):
+    """
+    Calls our symbolic inferencer once.
+    """
+    r = Refine(traced)
+    r.refine()
+    mgu = unify_eq(r.constraints)
+    substitute_all_types(traced.graph, mgu)
+
+
+def infer_symbolic_types(traced):
+    """
+    Calls our symbolic inferencer twice.
+    This is useful when one pass is not enough
+    to infer all the information such as the case
+    for braodcasting.
+    """
+    r = Refine(traced)
+    r.refine()
+    mgu = unify_eq(r.constraints)
+    substitute_all_types(traced.graph, mgu)
+
+    r = Refine(traced)
+    r.refine()
+    mgu = unify_eq(r.constraints)
+    substitute_all_types(traced.graph, mgu)
+
+    r.symbolic_relations()
+
+
+def convert_eq(list_of_eq):
+    """
+    Convert equality constraints in the right format
+    to be used by unification library.
+    """
+    lhs = []
+    rhs = []
+    for eq in list_of_eq:
+        lhs.append(eq.lhs)
+        rhs.append(eq.rhs)
+    return tuple(lhs), tuple(rhs)
+
+
+def unify_eq(list_of_eq):
+    """
+    Apply unification to a set of
+    equality constraints
+    """
+    lhs, rhs = convert_eq(list_of_eq)
+    return unify(lhs, rhs)
+
+
+def substitute_solution_one_type(mapping, t):
+    """
+    Apply the most general unifier to a type
+    """
+    if isinstance(t, Var):
+        if t in mapping:
+            return mapping[t]
+        else:
+            return t
+
+    elif isinstance(t, TensorType):
+        new_type = []
+        for typ in t.__args__:
+            if typ in mapping:
+                new_type.append(mapping[typ])
+            else:
+                new_type.append(typ)
+        return TensorType(tuple(new_type))
+
+    elif isinstance(t, list):
+        new_type = []
+        for typ in t:
+            new_type.append(substitute_solution_one_type(mapping, typ))
+        return new_type
+
+    elif isinstance(t, tuple):
+        new_type = []
+        for typ in t:
+            new_type.append(substitute_solution_one_type(mapping, typ))
+        return tuple(new_type)
+
+    else:
+        return t
+
+
+def substitute_all_types(graph, mapping):
+    """
+    Apply the most general unifier to all types in a graph
+    till reaching a fixed point. If the input and output graph
+    are the same, we converge.
+    """
+    flag = True
+    while flag:
+        flag = False
+        for k in mapping:
+            old_mapping_val = mapping[k]
+            if mapping[k] in mapping:
+                new_key = mapping[k]
+                mapping[k] = mapping[new_key]
+            if old_mapping_val != mapping[k]:
+                flag = True
+
+    for n in graph.nodes:
+        n.type = substitute_solution_one_type(mapping, n.type)
+
+
+def check_for_type_equality(g1, g2):
+    """
+    A check equality to be used in fixed points.
+    We do not use graph equality but instead type
+    equality.
+    """
+    for n, m in zip(g1.nodes, g2.nodes):
+        if n.type != m.type:
+            return False
+    return True
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/validator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/validator.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b8b871626af81f23f4e88cc5f57161ed1287ad
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/experimental/validator.py
@@ -0,0 +1,874 @@
+# mypy: allow-untyped-defs
+import builtins
+import functools
+import logging
+import math
+import operator
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import sympy
+
+import torch
+import torch.fx
+import torch.fx.traceback as fx_traceback
+from torch._dynamo.exc import TorchDynamoException
+from torch._dynamo.utils import dynamo_timed
+from torch.fx.node import Argument, Target
+from torch.utils._sympy.interp import sympy_interp
+
+
+log = logging.getLogger(__name__)
+
+try:
+    import z3  # type: ignore[import]
+
+    # Translation Validation for Dynamo guards
+    # ========================================
+    #
+    # Checks whether optimizations applied to the collected guards are
+    # valid. In other words, whether the guard function we actually run
+    # does not have false positives (unsound).
+    #
+    # In order to do so, we build the guards using 2 different information
+    # attached to each 'SymNode':
+    #   1. SymPy expressions
+    #   2. FX nodes
+    #
+    # SymPy expressions have implicit optimizations baked within itself,
+    # which may have a few bugs. On the other hand, we build the FX graph
+    # manually, with no optimizations enabled. This gives us access to
+    # the "ground truth".
+    #
+    # We then convert into Z3 expressions both the SymPy expressions
+    # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function
+    # and the FX nodes (see [Note: PopulateValidator]) that go through
+    # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
+    # (see [Note: TranslationValidator])
+    # Better Z3 to string implementation (for a small fraction of Z3).
+    #
+    # Here are the things we clean before showing the Z3 expression:
+    #   - Rename a few ops (e.g. "Distinct" ==> "!=")
+    #
+    #   - Ignore ToInt and ToReal operations:
+    #     usually they don't really matter
+    #
+    #   - Transform (ToInt (/ ...)) into (idiv ...):
+    #     this is the pattern for floor division
+    #
+    #   - Collect a chain of the same operations into one
+    def z3str(e: z3.ExprRef) -> str:
+        assert z3.is_expr(e), f"unsupported expression type: {e}"
+
+        def get_args_str(e: z3.ExprRef) -> list[str]:
+            return [z3str(e.arg(i)) for i in range(e.num_args())]
+
+        # First, we simplify the given expression.
+        # This is done using rewriting rules, so shouldn't take long.
+        e = z3.simplify(e)
+
+        # Only support function applications.
+        # Even Z3 "variables" are, in fact, function applications.
+        if not z3.is_app(e):
+            raise ValueError(f"can't print Z3 expression: {e}")
+
+        if z3.is_int_value(e) or z3.is_rational_value(e):
+            return e.as_string()  # type: ignore[attr-defined]
+
+        decl = e.decl()
+        kind = decl.kind()
+        op = str(decl)
+        args = get_args_str(e)
+
+        if kind == z3.Z3_OP_POWER:
+            op = "pow"
+
+        elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
+            # Collect the arguments of chains of ADD and MUL.
+            # This is safe, since they are associative.
+
+            def collect_str_args(e):
+                if not (z3.is_app(e) and e.decl().kind() == kind):
+                    return [z3str(e)]
+                else:
+                    return [
+                        x
+                        for i in range(e.num_args())
+                        for x in collect_str_args(e.arg(i))
+                    ]
+
+            args = collect_str_args(e)
+
+        elif kind == z3.Z3_OP_NOT:
+            # Revert some conversions that z3.simplify applies:
+            #   - a != b ==> (Not (== a b)) ==> (!= a b)
+            #   - a < b ==> (Not (<= b a)) ==> (> b a)
+            #   - a > b ==> (Not (<= a b)) ==> (> a b)
+
+            assert e.num_args() == 1
+            arg = e.arg(0)
+
+            assert z3.is_app(arg)
+            argkind = arg.decl().kind()
+
+            logic_inverse = {
+                z3.Z3_OP_EQ: "!=",
+                z3.Z3_OP_LE: ">",
+                z3.Z3_OP_GE: "<",
+            }
+
+            if argkind in logic_inverse:
+                op = logic_inverse[argkind]
+                args = get_args_str(arg)
+
+        elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
+            assert e.num_args() == 1
+            argstr = z3str(e.arg(0))
+
+            # Check if it's the floor division pattern.
+            if argstr.startswith("(/"):
+                return "(idiv" + argstr[2:]
+
+            # Otherwise, just ignore it.
+            return argstr
+
+        elif kind == z3.Z3_OP_UNINTERPRETED:
+            assert e.num_args() == 0
+            return str(decl)
+
+        string = op + " " + " ".join(args)
+        return f"({string.rstrip()})"
+
+    # We need to convert to/from BitVec in order to use z3 bitwise ops.
+    # We assume that integers are 64 bit.
+    # If all args are boolean, then use the boolean bitwise op implementation instead, if provided.
+    def _bitwise_op(bitwise_func, bool_func):
+        @functools.wraps(bitwise_func)
+        def wrapper(self, *args):
+            if bool_func is not None and all(
+                isinstance(arg, z3.BoolRef) for arg in args
+            ):
+                return bool_func(*args)
+
+            wrapped_args = tuple(z3.Int2BV(a, 64) for a in args)
+            return z3.BV2Int(bitwise_func(*wrapped_args))
+
+        return wrapper
+
+    # Implementation of Python semantics as Z3 expressions.
+    #
+    # Z3 Real-Int theory has operators with semantics that differ that of
+    # Python. Therefore, in order to get it right, we need to implement
+    # the (Python) semantics we are relying on in Z3.
+    @dataclass
+    class _Z3Ops:
+        # Validator used for adding assertions as needed.
+        # e.g. div(a, b) requires b != 0.
+        validator: "TranslationValidator"
+
+        # The 2 functions below are used for conditionally casting between
+        # integer and reals.
+        #
+        # Returns a real expression from 'x'.
+        @staticmethod
+        def to_real(x: z3.ArithRef) -> z3.ArithRef:
+            return x if x.is_real() else z3.ToReal(x)
+
+        # Returns an integer expression from 'x'.
+        @staticmethod
+        def to_int(x: z3.ArithRef) -> z3.ArithRef:
+            return x if x.is_int() else z3.ToInt(x)
+
+        def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef:
+            # pyrefly: ignore
+            return sum(args)
+
+        # Implements Python division semantics.
+        def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            self.validator.add_assertion(denominator != 0)  # type: ignore[arg-type]
+            return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator)
+
+        def floor(self, number: z3.ArithRef) -> z3.ArithRef:
+            # Z3 ToInt function rounds a real number towards negative infinity.
+            return _Z3Ops.to_int(number)
+
+        # Python semantics for 'FloorDiv' states that before applying the floor
+        # function, the operands are converted to their common type.
+        def floordiv(
+            self, numerator: z3.ArithRef, denominator: z3.ArithRef
+        ) -> z3.ArithRef:
+            cast_result_to_real = numerator.is_real() or denominator.is_real()
+            result = _Z3Ops.to_int(self.div(numerator, denominator))
+            # Since the 'result' is already an integer, we just have to check
+            # whether we should cast it to real.
+            return _Z3Ops.to_real(result) if cast_result_to_real else result
+
+        def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(self.floor(number) < number, self.floor(number + 1), number)  # type: ignore[return-value]
+
+        def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(number >= 0, self.floor(number), self.ceil(number))  # type: ignore[return-value]
+
+        def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(a > b, a, b)  # type: ignore[return-value]
+
+        def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(a < b, a, b)  # type: ignore[return-value]
+
+        # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q
+        # It should work with both integer and reals.
+        def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
+            return p - self.floordiv(p, q) * q
+
+        def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
+            # Z3 can't handle complex numbers very well.
+            self.validator.add_assertion(z3.Or(base != 0, exp > 0))  # type: ignore[arg-type]
+            return base**exp
+
+        def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
+            # Square-root:
+            # 1. Only work with reals
+            number = _Z3Ops.to_real(number)
+            # 2. The number should be positive or zero.
+            #    Otherwise, Z3 returns 'unknown'.
+            self.validator.add_assertion(number >= 0)
+            return number**0.5
+
+        def abs(self, number: z3.ArithRef) -> z3.ArithRef:
+            return z3.Abs(number)
+
+        def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
+            # Pythons builtin 'round' implements the 'round half to even' strategy
+            # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
+            # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
+            # floating point numbers, which is different from real numbers that we are dealing with here.
+            # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and
+            # 'round half down' (ceil(x - 0.5)).
+            # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ...
+            # to round down, i.e. use the 'round half down' strategy
+            return z3.If(
+                self.mod(number, z3.IntVal(2)) == 0.5,
+                self.ceil(number - 0.5),
+                self.floor(number + 0.5),
+            )
+
+        bitwise_and = _bitwise_op(operator.and_, z3.And)
+        bitwise_or = _bitwise_op(operator.or_, z3.Or)
+        lshift = _bitwise_op(operator.lshift, None)
+        rshift = _bitwise_op(operator.rshift, None)
+
+    # Lifts a callable to be used in Z3.
+    #
+    # This function replaces the given 'op' by a function that:
+    #
+    #   1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3)
+    #
+    #   2. Calls an operation that corresponds to 'op', but works with Z3
+    #      inhabitants (left as is if it works as is)
+    def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
+        # Operations that have booleans as their argument.
+        # This is needed because the argument of some FX nodes were
+        # literal integers, instead of booleans. So, whenever this flag
+        # is set, we also convert ints to booleans.
+        boolean_ops = {operator.not_}
+        as_bool = op in boolean_ops
+
+        # Lifts the function into 'z3.ExprRef' domain.
+        def lift(func):
+            def wrap(a) -> z3.ExprRef:
+                if isinstance(a, (z3.ArithRef, z3.BoolRef)):
+                    return a
+                # Convert it into a Z3 value, if it is some of the supported
+                # types below.
+                if isinstance(a, bool) or (as_bool and isinstance(a, int)):
+                    return z3.BoolVal(bool(a))
+                if isinstance(a, (int, sympy.Integer)):
+                    return z3.IntVal(int(a))
+                if isinstance(a, (float, sympy.Float)):
+                    return z3.RealVal(float(a))
+                raise ValueError(f"can't lift type: {type(a)}")
+
+            @functools.wraps(func)
+            def wrapper(*args):
+                # Lifts the arguments into a list of Z3 inhabitants.
+                if len(args) == 1 and isinstance(args[0], (list, tuple)):
+                    wrapped_args = (tuple(wrap(a) for a in args[0]),)
+                else:
+                    wrapped_args = tuple(wrap(a) for a in args)
+                # Run the function on the Z3 expressions.
+                return func(*wrapped_args)
+
+            return wrapper
+
+        ops = _Z3Ops(validator)
+        replacement_map = {
+            # Operator module.
+            operator.not_: lift(z3.Not),
+            operator.and_: lift(ops.bitwise_and),
+            operator.or_: lift(ops.bitwise_or),
+            operator.lshift: lift(ops.lshift),
+            operator.rshift: lift(ops.rshift),
+            operator.floordiv: lift(ops.floordiv),
+            operator.truediv: lift(ops.div),
+            operator.mod: lift(ops.mod),
+            operator.abs: lift(ops.abs),
+            builtins.round: lift(ops.round_to_int),
+            # Math module.
+            math.ceil: lift(ops.ceil),
+            math.floor: lift(ops.floor),
+            math.trunc: lift(ops.trunc),
+            # Torch module.
+            torch.sym_float: lift(ops.to_real),
+            torch.sym_max: lift(ops.max),
+            torch.sym_min: lift(ops.min),
+            torch.sym_sum: lift(ops.sym_sum),
+            torch.sym_ite: lift(lambda b, t, f: t if b else f),
+            torch._sym_sqrt: lift(ops.sqrt),  # type: ignore[attr-defined]
+            # Not lifted because we only use this function as a
+            # marker for adding the expression as validator input.
+            torch._assert: torch._assert,
+        }
+        return replacement_map[op] if op in replacement_map else lift(op)
+
+    # Processes an FX graph, populating the given validator.
+    #
+    # [Note: PopulateValidator]
+    # This class walks through each node in the FX graph, translating
+    # them into the Z3 world.
+    #
+    # Then, whenever it finds an 'torch._assert' call_function operation,
+    # it adds the Z3 expression corresponding to the argument as validator
+    # input.
+    class PopulateValidator(torch.fx.Interpreter):
+        def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
+            # Reference to the translation validator.
+            self.validator = validator
+
+            # Build the graph module and call `Interpreter` constructor.
+            module = torch.fx.GraphModule(root={}, graph=graph)
+            super().__init__(module, garbage_collect_values=True)
+
+        def placeholder(
+            self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+        ) -> Any:
+            symbol = fx_traceback.get_current_meta()["symbol"]
+            return self.validator.z3var(symbol)
+
+        def call_function(
+            self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
+        ) -> Any:
+            if target is not torch._assert:
+                # Lift and runs the node target function
+                return super().call_function(z3op(target, self.validator), args, kwargs)  # type: ignore[arg-type]
+            # Adds the Z3 expression corresponding to the first argument
+            # as a validator input.
+            assert len(args) == 1, (
+                f"expected 1 argument on assertion. Got: {len(args)} "
+            )
+            self.validator.add_source_expr(args[0])  # type: ignore[arg-type]
+
+    # Translates SymPy expressions into Z3 expressions.
+    #
+    # [Note: SympyToZ3]
+    # At the time of the translation, all free variables present in the
+    # SymPy expression being translated must be already mapped to a Z3
+    # integer variable.
+    class SympyToZ3:
+        OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}
+
+        def __init__(
+            self,
+            validator: "TranslationValidator",
+        ) -> None:
+            self._validator = validator
+            self._ops = _Z3Ops(self._validator)
+
+        def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
+            # TODO: Probably OK to relax this and allow lower precision
+            if dtype is torch.int64:
+                return z3.IntVal(int(value))
+            if dtype is torch.double:
+                return z3.RealVal(float(value))
+            if dtype is torch.bool:
+                return z3.BoolVal(bool(value))
+            raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
+
+        def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
+            if dtype == torch.float64:
+                return z3.ToReal(x)
+            raise NotImplementedError(f"to_dtype {dtype} NYI")
+
+        def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
+            return z3.ToInt(x)
+
+        def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
+            return self._ops.round_to_int(x)
+
+        def int_truediv(
+            self, numerator: z3.ArithRef, denominator: z3.ArithRef
+        ) -> z3.ArithRef:
+            return self._ops.div(numerator, denominator)
+
+        def truediv(
+            self, numerator: z3.ArithRef, denominator: z3.ArithRef
+        ) -> z3.ArithRef:
+            return self._ops.div(numerator, denominator)
+
+        def floordiv(
+            self, numerator: z3.ArithRef, denominator: z3.ArithRef
+        ) -> z3.ArithRef:
+            return self._ops.floordiv(numerator, denominator)
+
+        def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.floordiv(numerator, denominator)
+
+        def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.pow(base, exp)
+
+        def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.pow(base, exp)
+
+        def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.mod(p, q)
+
+        def python_mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.mod(p, q)
+
+        def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
+            return self._ops.ceil(x)
+
+        def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
+            return self._ops.floor(x)
+
+        def __getattr__(self, name: str) -> Any:
+            REPLACEMENT = {
+                "and_": z3.And,
+                "or_": z3.Or,
+                "not_": z3.Not,
+                "bitwise_and": self._ops.bitwise_and,
+                "bitwise_or": self._ops.bitwise_or,
+                "lshift": self._ops.lshift,
+                "rshift": self._ops.rshift,
+                "floor": self._ops.floor,
+                "ceil": self._ops.ceil,
+                "minimum": self._ops.min,
+                "maximum": self._ops.max,
+            }
+
+            if name in REPLACEMENT:
+                return REPLACEMENT[name]
+            if name in self.OPERATOR_HANDLES:
+                return getattr(operator, name)
+            raise AttributeError(f"unhandled operator: {name}")
+
+        def run(self, expr: sympy.Basic) -> z3.ExprRef:
+            return sympy_interp(self, self._validator.symbols, expr)  # type: ignore[arg-type]
+
+    # Dynamo guards translation validator.
+    #
+    # [Note: TranslationValidator]
+    # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound.
+    # That is: whether those (target) guards only yield TRUE whenever the original,
+    # unoptimized, (source) guards yield TRUE.
+    #
+    # More concretely, given 'source' and 'target' guard expressions, we wish to
+    # check whether the following expression holds:
+    #
+    # Not(And(source)) AND And(target)
+    #
+    # i.e. whether there is an assignment of the free variables where the opposite
+    # happens: target is TRUE, but source is FALSE.
+    class TranslationValidator:
+        def __init__(self) -> None:
+            log.debug("new instance")
+
+            # Mapping of SymPy symbols to Z3 variables.
+            self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}
+
+            # Set of source Z3 expressions.
+            # They represent the generated guards without any kind of
+            # simplification or transformation.
+            self._source_exprs: set[z3.BoolRef] = set()
+
+            # Set of target Z3 expressions.
+            # They represent the actual checked guards at runtime. They might
+            # be simplified or transformed versions of the source guards.
+            self._target_exprs: set[z3.BoolRef] = set()
+
+            # Set of Z3 expressions representing assertions over both the
+            # source and target expressions.
+            self._assertions: set[z3.BoolRef] = set()
+
+        # Retrieves the corresponding Z3 variable.
+        def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
+            assert symbol in self.symbols, f"Z3 variable not found for: {symbol}"
+            return self.symbols[symbol]
+
+        # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists.
+        def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
+            if symbol in self.symbols:
+                return self.symbols[symbol]
+
+            log.debug("new variable: %s (%s)", symbol.name, type.__name__)
+
+            if type is int:
+                var = z3.Int(symbol.name)
+
+                # If 'symbol' is positive (SymPy assumption), we have to
+                # convey it to Z3 as well.
+                if symbol.is_positive:  # type: ignore[attr-defined]
+                    self._target_exprs.add(var > 0)
+            elif type is float:
+                var = z3.Real(symbol.name)
+            elif type is bool:
+                var = z3.Bool(symbol.name)
+            else:
+                raise RuntimeError(f"unsupported type for Z3 variable: {type}")
+
+            self.symbols[symbol] = var
+            return var
+
+        # Checks whether all symbols were already added.
+        def _check_freesymbols(self, e: sympy.Basic) -> None:
+            for s in e.free_symbols:
+                assert isinstance(s, sympy.Symbol)
+                # Call 'z3var' just to check whether there's already a
+                # Z3 variable corresponding to 's'.
+                self.z3var(s)
+
+        def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
+            z3expr = SympyToZ3(self).run(e)
+            assert isinstance(z3expr, z3.BoolRef), (
+                f"expected boolean expression. Got: {z3expr}"
+            )
+            return z3expr
+
+        def add_source_expr(self, e: z3.BoolRef) -> None:
+            if e not in self._source_exprs:
+                log.debug("add source guard: %s", z3str(e))
+            self._source_exprs.add(e)
+
+        def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None:
+            self._check_freesymbols(e)
+            z3expr = self.to_z3_boolean_expr(e)
+            if e not in self._target_exprs:
+                log.debug("add target guard: %s", z3str(z3expr))
+            self._target_exprs.add(z3expr)
+
+        def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
+            if isinstance(e, sympy.Basic):
+                self._check_freesymbols(e)
+                ref = self.to_z3_boolean_expr(e)
+            else:
+                ref = e
+            assert isinstance(ref, z3.BoolRef)
+            if ref not in self._assertions:
+                log.debug("add assertion: %s", z3str(ref))
+            self._assertions.add(ref)
+
+        def validate(self) -> None:
+            with dynamo_timed("TranslationValidator.validate"):
+                return self._validate()
+
+        def _validate(self) -> None:
+            if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
+                # If there are no source/target expressions, there's nothing we really
+                # wish to prove. So, we just return.
+                return None
+
+            # Here, we use "QF_NRA" logic for the solver:
+            #   "Quantifier-free Non-linear Real Arithmetic".
+            #
+            # Most of the guards expressions have:
+            #   1. arithmetic between integer and reals
+            #   2. no quantifiers
+            #   3. potentially non-linear.
+            #
+            # Although there's also "QF_NIRA" (mixed integer-real arithmetic),
+            # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'.
+            solver = z3.SolverFor("QF_NRA")
+            # Set a timeout for finding a solution.
+            solver.set(timeout=translation_validation_timeout())
+
+            # Add all the assertions to the solver.
+            for assertion in self._assertions:
+                solver.add(assertion)
+
+            # "Is there any case where it's TRUE for the target expressions,
+            #  but FALSE for the source expressions?"
+            solver.add(z3.Not(z3.And(*self._source_exprs)))
+            solver.add(*self._target_exprs)
+
+            log.debug("translation validation: start")
+            r = solver.check()
+            if r == z3.sat:
+                # Target expressions are unsound.
+                # Log the found model and the source expressions that failed.
+                model = solver.model()
+                raise ValidationException(
+                    model,
+                    self._assertions,
+                    self._target_exprs,
+                    failed_source_exprs=[
+                        inp for inp in self._source_exprs if not model.evaluate(inp)
+                    ],
+                )
+            else:
+                if r == z3.unknown:
+                    # Could not find a solution. It didn't fail, but it also
+                    # didn't succeed. Canceling the validation execution (keyboard
+                    # interrupt) also gets to this branch.
+                    log.warning(
+                        "translation validation: could not validate: got z3.unknown"
+                    )
+                else:
+                    # Target expressions are sound.
+                    assert r == z3.unsat
+                    log.debug("translation validation: success")
+
+except ImportError:
+    _HAS_Z3 = False
+
+    __all__ = [
+        "translation_validation_enabled",
+        "translation_validation_timeout",
+        "ValidationException",
+        "BisectValidationException",
+    ]
+
+else:
+    _HAS_Z3 = True
+
+    __all__ = [
+        "z3str",
+        "z3op",
+        "PopulateValidator",
+        "SympyToZ3",
+        "TranslationValidator",
+        "translation_validation_enabled",
+        "translation_validation_timeout",
+        "ValidationException",
+        "BisectValidationException",
+    ]
+
+from torch.fx.experimental import _config as config
+
+
+def translation_validation_enabled() -> bool:
+    # Checks every time this function is called, in case the Dynamo
+    # option is set, but Z3 is not installed.
+    _assert_z3_installed_if_tv_set()
+    return _HAS_Z3 and config.translation_validation
+
+
+def translation_validation_timeout() -> int:
+    return config.translation_validation_timeout
+
+
+def _assert_z3_installed_if_tv_set():
+    assert _HAS_Z3 or not config.translation_validation, (
+        "translation validation requires Z3 package. Please, either install "
+        "z3-solver or disable translation validation."
+    )
+
+
+class ValidationException(TorchDynamoException):
+    def __init__(self, model, assertions, target_exprs, failed_source_exprs):
+        assert _HAS_Z3
+
+        def symbolstr(sym) -> str:
+            return f"{sym}: {model[sym]}"
+
+        def joinlines(xs) -> str:
+            return "\n".join(f"  ==> {x}" for x in xs)
+
+        model_str = joinlines(sorted(map(symbolstr, model)))
+        assertions_str = joinlines(sorted(map(z3str, assertions)))
+        target_exprs_str = joinlines(sorted(map(z3str, target_exprs)))
+        failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs)))
+
+        self.msg = "translation validation failed."
+        self.details = f"""\
+Model:
+{model_str}
+
+Assertions:
+{assertions_str}
+
+Target Expressions:
+{target_exprs_str}
+
+Failed Source Expressions:
+{failed_source_exprs_str}"""
+
+    def __str__(self):
+        return f"{self.msg}\n\n{self.details}"
+
+
+class BisectValidationException(TorchDynamoException):
+    def __init__(self, validation_exc, expr, failed_action, traced_node):
+        self.msg = f"translation validation failed when {failed_action}: {expr}"
+        self.details = f"""\
+Failure occurred while running node:
+    {traced_node.format_node()}
+
+{validation_exc.details}"""
+
+    def __str__(self):
+        return f"{self.msg}\n\n{self.details}"
+
+
+# Checks when this module is loaded.
+_assert_z3_installed_if_tv_set()
+
+
+# Translation validation bisection.
+#
+# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
+# the earliest ValidationException.
+#
+# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors
+# might be silently happening. This function tries to nail down exactly at which
+# point things went wrong from a validation perspective.
+def bisect(shape_env):
+    from torch.fx.experimental.recording import (
+        FakeTensorMeta,
+        replay_shape_env_events,
+        ShapeEnvEvent,
+    )
+    from torch.fx.experimental.symbolic_shapes import (
+        CURRENT_NODE_KEY,
+        ShapeEnv,
+        SHAPEENV_EVENT_KEY,
+    )
+
+    events = shape_env.events
+
+    # Retrieves the ShapeEnvEvent associated with node.
+    def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
+        assert SHAPEENV_EVENT_KEY in node.meta
+        return events[node.meta[SHAPEENV_EVENT_KEY]]
+
+    # Creates a new instance of fake, but updating every symbolic value's ShapeEnv
+    # reference to the one given as argument.
+    #
+    # This is needed so as not to simplify a symbolic expression using a ShapeEnv
+    # "from the future", where it may have a different set of replacements.
+    def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
+        if isinstance(fake, int):
+            return fake
+        if isinstance(fake, torch.SymInt):
+            return torch.SymInt(fake.node.with_shape_env(shape_env))
+        if isinstance(fake, torch.SymFloat):
+            return torch.SymFloat(fake.node.with_shape_env(shape_env))
+        assert isinstance(fake, FakeTensorMeta)
+        return FakeTensorMeta(
+            tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
+            tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
+            new_with_shape_env(shape_env, fake.storage_offset()),
+            fake.is_nested,
+        )
+
+    # Checks whether the given shape_env fails when produce_guards is called.
+    def check_shapeenv_fails(
+        shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
+    ) -> Optional[ValidationException]:
+        assert tracked_fakes is not None
+        try:
+            # This produce_guards call is a best-effort replication, since we
+            # don't populate EqualityConstraint list. Reason: we would also have
+            # to save OutputGraph.tracked_fakes_id_to_source.
+            shape_env.produce_guards(
+                [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
+                [a.source for a in tracked_fakes],
+                input_contexts=[a.symbolic_context for a in tracked_fakes],
+            )
+            return None
+        except ValidationException as e:
+            return e
+
+    # Checks whether the ShapeEnv reconstructed by replaying the events until
+    # node is created fails when produce_guards is called.
+    def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
+        number = node.meta[SHAPEENV_EVENT_KEY]
+        # Reconstruct shape_env until the event at event_number.
+        shape_env = replay_shape_env_events(events[: number + 1])
+        shape_env.graph.lint()
+        return check_shapeenv_fails(shape_env, events[number].tracked_fakes)
+
+    last_exception = check_shapeenv_fails(
+        shape_env, shape_env._snapshot_tracked_fakes()
+    )
+
+    if not last_exception:
+        # We don't actually fail due to a produce_guards call.
+        # Stop and don't bisect.
+        log.info("translation validation succeeded: no errors found.")
+        return
+
+    if not shape_env.should_record_events or config.translation_validation_no_bisect:
+        # Bisection is off.
+        # Return the last ValidationException we got.
+        raise last_exception
+
+    # Cache the raised exception (if any) at each bisection point.
+    exception = {}
+
+    # Bisection happens on the assertion nodes of the recorded FX graph for
+    # dynamic shapes.
+    assert_nodes = [
+        node for node in shape_env.graph.nodes if node.target is torch._assert
+    ]
+
+    # Preparing the indices for binary search.
+    # The overall invariants are
+    # - for all i < left, assert_node[i] doesn't fail
+    # - for all i >= right, assert_node[i] fails
+    # - `right in exception` always holds
+    # - `left <= right` always holds
+    left, mid, right = 0, 0, len(assert_nodes) - 1
+    exception[right] = check_node_fails(assert_nodes[right])
+
+    while left < right:
+        mid = (left + right) // 2
+
+        node = assert_nodes[mid]
+        log.debug("bisecting at %s: %s", mid, get_node_event(node))
+
+        # Check whether the new shape_env raises a ValidationException or not.
+        exception[mid] = check_node_fails(node)
+
+        if exception[mid]:
+            right = mid
+        else:
+            left = mid + 1
+
+    assert left in exception and isinstance(exception[left], ValidationException)
+
+    node = assert_nodes[left]
+    event = get_node_event(node)
+
+    if event.is_evaluate_expr():
+        failed_action = "evaluating"
+    else:
+        assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
+        failed_action = "adding runtime assert"
+
+    args = event.args
+    assert args is not None
+    assert len(args) >= 2, (
+        f"bisecting expects {event.name} to have at least 2 positional arguments. "
+        f"Got: {len(args)}"
+    )
+    assert isinstance(args[1], sympy.Basic), (
+        f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
+        f"Got: {type(args[1])}"
+    )
+
+    raise BisectValidationException(
+        exception[left],
+        expr=args[1],
+        failed_action=failed_action,
+        traced_node=node.meta[CURRENT_NODE_KEY],
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bcb6e1d75a17cbbcf2881b48edc713bd66aa303
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__init__.py
@@ -0,0 +1,15 @@
+from . import (
+    graph_drawer,
+    graph_manipulation,
+    net_min_base,
+    operator_support,
+    param_fetch,
+    regional_inductor,
+    reinplace,
+    runtime_assert,
+    shape_prop,
+    split_module,
+    split_utils,
+    splitter_base,
+    tools_common,
+)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b54ba5cc7f8a367e5bb58ab4295538cab7914c9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d83f503a2bf0321f6f55031129c3ef7e8cff4737
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6710d44d4cee4b35c7ac619386e55f09f8683972
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..058863ab01722db2106cb8eae0ab16ef1a84168b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fba5fa19701ac07eb31585679b34947811f0f45f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72a1241778eae50037d1685f87c20952593e4ec5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d735050b9f9b1d4d681fa928f8e51c3bb1952d2b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1642245af802938b2d794286d1eb1184686b810b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d68452faf249cc12e113753c82355f505f72b52
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66027e7a4b77862791fc42e360a6783b90667f3e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cba005f5c373df46115601c9333b80df67719d87
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/regional_inductor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/regional_inductor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ec80c66bcc7bdea49def69c38b713e6ad4b3940
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/regional_inductor.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4c5d3fe4bd533d48fa8a6e90bf63378edb44050
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3f2cd2db152ebabb14bdc6fd78ecb22a99e0a2e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15a8a31e0072fe19669152328898e03f9af24133
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f46d87336fcbdfaf38b948c938520d53bb37abe
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cbad602ad4673267472b8ebd9b1d3cc48e0e2f5f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3765db3bff3dcbdb2edf8db716d95f76786049ce
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44ad2da13635a63fe66d11c517d14296c6d7c3c9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/_tensorify_python_scalars.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/_tensorify_python_scalars.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e4c6c56bddf9244276680aee1935b5a6c1cb048
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/_tensorify_python_scalars.py
@@ -0,0 +1,407 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any, TYPE_CHECKING, Union
+
+from sympy import Integer, Number, Symbol
+from sympy.logic.boolalg import BooleanAtom
+
+import torch
+import torch.fx as fx
+from torch._dynamo.exc import TensorifyScalarRestartAnalysis
+from torch._dynamo.symbolic_convert import TensorifyState
+from torch._dynamo.utils import get_metrics_context
+from torch._prims_common import get_computation_dtype
+from torch._subclasses.fake_tensor import FakeTensor
+from torch._utils_internal import justknobs_check
+from torch.fx._utils import lazy_format_graph_code
+from torch.fx.experimental.symbolic_shapes import (
+    guard_scalar,
+    has_free_symbols,
+    ShapeEnv,
+)
+
+# TODO: refactor
+from torch.fx.passes.runtime_assert import _get_sym_val
+from torch.fx.proxy import MetaProxy
+from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
+from torch.utils._sympy.reference import TensorReferenceAnalysis
+from torch.utils._sympy.symbol import symbol_is_type, SymT
+
+
+if TYPE_CHECKING:
+    from torch._subclasses import fake_tensor
+    from torch.fx.graph_module import GraphModule
+
+
+__all__: list[str] = []
+
+log = logging.getLogger(__name__)
+graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose")
+
+# The general shape of this transformation is to look for Tensor operations
+# that take a backed SymFloat as an argument, and then redo them as tensor
+# compute (with ints and tensors as inputs). For example, add(Tensor, Scalar)
+# can be translated into add(Tensor, Tensor). Because Dynamo has already
+# arranged for floats to be Tensor inputs to the graph, for typical float
+# compute you can entirely translate the Python float operations into Tensor
+# operations with only Tensor inputs.
+#
+# This pass is also responsible for doing CSE on the fly as we do this, since
+# you don't want to keep recomputing the same quantity over and over again if
+# it's used multiple times.
+#
+# This pass runs on the JOINT graph produced by AOT Autograd, prior to partitioning.
+# The primary goal of this pass is to eliminate floats by replacing TensorScalar
+# operations with TensorTensor operations and then Dead Code Elimination (DCE) of
+# the item calls, which effectively removes the floats.
+#
+# This needs to happen before partitioning because it influences partitioning decisions,
+# specifically by ensuring that we don't need to save floats across partitions.
+# Additionally, there is a separate pass that changes which device computations
+# occur on. That pass must be run after this one, but still before partitioning.
+#
+# HISTORY NOTE: Originally, I wanted to formulate this pass as pushing item()
+# calls down, transforming float compute into int compute as we went. If you
+# manage to eliminate all float compute, this ends up being equivalent, but
+# there is a critical difference when some floats cannot be eliminated: when
+# we call item() on them, what should it's SymFloat be? Ideally, it would
+# be the same backed SymFloat we had before. But without symbolic expression
+# propagation on tensor quantities, repropagating would instead give you an
+# unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation
+# on 0d scalar tensors, but I decided to go for something simpler to start.
+#
+# The boring stuff:
+#
+# * What operators can I Tensor-ify? (Anything with a Scalar argument)
+# * How do I Tensor-ify a SymFloat sympy expression (Sympy -> Op Handler -> Tensor)
+#
+# TODO: make sure this runs before CPU->CUDA pass for cudagraph friendliness
+
+
+SUPPORTED_OPS = {
+    torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor,
+    torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor,
+    torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor,
+    torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor,
+    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
+    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
+    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
+    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
+    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
+    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
+}
+
+
+@torch.fx._compatibility.compatibility(is_backward_compatible=False)
+def tensorify_python_scalars(
+    gm: GraphModule, shape_env: ShapeEnv, fake_mode: fake_tensor.FakeTensorMode
+) -> None:
+    """
+    Converts Python scalar operations into Tensor operations within the graph. This pass looks for
+    Tensor operations that involve SymFloat arguments and transforms them into equivalent operations
+    that use only Tensor inputs.
+
+    Args:
+        gm: The FX graph module representing the computation graph.
+        shape_env: The shape environment responsible for symbolic shape tracking and propagation
+        during graph transformations.
+
+    Returns:
+        None
+    """
+    import sympy
+
+    knob = True
+    if (env := os.getenv("TENSORIFY_PYTHON_SCALARS")) is not None:
+        if env in ("0", "FALSE"):
+            knob = False
+    else:
+        knob = justknobs_check("pytorch/compiler:tensorify_python_scalars")
+    if not knob:
+        return None
+
+    graph = gm.graph
+    tracer = fx.proxy.GraphAppendingTracer(graph)
+    expr_to_sym_proxy: dict[sympy.Expr, MetaProxy] = {}
+    expr_to_tensor_proxy: dict[sympy.Expr, MetaProxy] = {}
+    tensorified_symbols: set[sympy.Symbol] = set()
+    should_restart = False
+
+    first_non_placeholder = None
+    placeholders = set()
+    for node in graph.nodes:
+        if node.op != "placeholder":
+            first_non_placeholder = node
+            break
+        else:
+            placeholders.add(node)
+
+    Analysis = TensorReferenceAnalysis
+
+    def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
+        # sympy_interp() with hash consing, and special handling for
+        # generating constants correctly
+
+        # hash cons
+        if isinstance(expr, Symbol) and expr not in expr_to_tensor_proxy:
+            # This is guaranteed to be populated by invariant established by
+            # insert_deferred_runtime_asserts
+            expr_to_tensor_proxy[expr] = torch.ops.aten.scalar_tensor.default(
+                expr_to_sym_proxy[expr]
+            )
+
+        # cache constants, why not
+        if isinstance(expr, (Integer, Number, BooleanAtom)):
+            dtype = None
+            c: Union[bool, int, float]
+            if isinstance(expr, BooleanAtom):
+                dtype = torch.bool
+                c = bool(expr)
+            elif isinstance(expr, sympy.Integer):
+                dtype = torch.int64
+                c = int(expr)
+            elif isinstance(expr, sympy.Number):
+                dtype = torch.float64
+                c = float(expr)
+
+            node = graph.call_function(
+                torch.ops.aten.scalar_tensor.default,
+                # pyrefly: ignore [unbound-name]
+                (c,),
+                {"dtype": dtype},
+            )
+            with fake_mode:
+                # pyrefly: ignore [unbound-name]
+                node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype)
+            expr_to_tensor_proxy[expr] = MetaProxy(
+                node,
+                tracer=tracer,
+                fake_mode=fake_mode,
+            )
+
+        if expr in expr_to_tensor_proxy:
+            return expr_to_tensor_proxy[expr]
+
+        # don't cache
+        if isinstance(expr, Symbol):
+            return sympy_interp(Analysis, expr_to_tensor_proxy, expr)  # type: ignore[arg-type]
+
+        # hash cons on arguments, run expr handler
+        expr_to_tensor_proxy[expr] = _run_sympy_handler(
+            Analysis,
+            [_sympy_interp(arg) for arg in expr.args],  # type: ignore[arg-type]
+            expr,
+        )
+
+        return expr_to_tensor_proxy[expr]
+
+    failed_tensorify_ops: set[str] = set()
+    nodes = list(graph.nodes)
+    for i, node in enumerate(nodes[:-1]):
+        with graph.inserting_before(
+            nodes[i + 1] if node not in placeholders else first_non_placeholder
+        ):
+            # Look for tensor.item() calls on placeholders
+            if (
+                node is not None
+                and node.op == "call_function"
+                and node.target is torch.ops.aten._local_scalar_dense.default
+            ):
+                dtype = node.args[0].meta["val"].dtype
+
+                assert isinstance(node.args[0], fx.Node), node.args[0]
+
+                s = node.meta["val"].node.expr
+
+                expr_to_sym_proxy[s] = MetaProxy(
+                    node, tracer=tracer, fake_mode=fake_mode
+                )
+
+                # only tensorify if the dtype is floating point
+                if not dtype.is_floating_point:
+                    continue
+
+                expr_to_tensor_proxy[s] = MetaProxy(
+                    node.args[0], tracer=tracer, fake_mode=fake_mode
+                )
+                # Upcast the float tensor to torch.float64 to avoid precision problem
+                expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
+                    expr_to_tensor_proxy[s], torch.float64
+                )
+
+            # pyrefly: ignore [bad-argument-type]
+            elif (sym_expr := _get_sym_val(node)) is not None:
+                if sym_expr not in expr_to_sym_proxy and not isinstance(
+                    sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
+                ):
+                    expr_to_sym_proxy[sym_expr] = MetaProxy(
+                        # pyrefly: ignore [bad-argument-type]
+                        node,
+                        tracer=tracer,
+                        fake_mode=fake_mode,
+                    )
+
+            # Specialize all dimensions that contain symfloats. Here's
+            # an example test that requires this:
+            # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950
+            # pyrefly: ignore [missing-attribute]
+            val = node.meta.get("val")
+            if isinstance(val, FakeTensor):
+                for dim in val.shape:
+                    if isinstance(dim, torch.SymInt):
+                        for s in dim.node.expr.free_symbols:
+                            name = str(s)
+                            if symbol_is_type(
+                                s, SymT.FLOAT
+                            ) and not TensorifyState.should_specialize(name):
+                                # In principle, we could support float input that
+                                # is used to do size compute. The problem is that
+                                # we don't actually want to tensorify the compute
+                                # in this case, which means we need codegen support for
+                                # all symfloats.
+                                TensorifyState.specialize(name)
+                                should_restart = True
+
+            # Look for functions to convert
+            # pyrefly: ignore [missing-attribute]
+            if node.op == "call_function" and (
+                # pyrefly: ignore [missing-attribute]
+                replacement_op := SUPPORTED_OPS.get(node.target)
+            ):
+                args: list[Any] = []
+                transform = False
+                # pyrefly: ignore [missing-attribute]
+                compute_dtype = get_computation_dtype(node.meta["val"].dtype)
+
+                # pyrefly: ignore [missing-attribute]
+                for a in node.args:
+                    if (
+                        isinstance(a, fx.Node)
+                        and "val" in a.meta
+                        and isinstance(zf := a.meta["val"], torch.SymFloat)
+                    ):
+                        transform = True
+                        try:
+                            proxy = _sympy_interp(zf.node.expr)
+                        except NotImplementedError:
+                            transform = False
+                            break
+
+                        # We use _expr instead of expr b/c we want the symbol not the replacement
+                        tensorified_symbols.add(a.meta["val"].node._expr)
+
+                        # The upcasting is irrelevant when the compute dtype is bool. This happens
+                        # in cases where we are tensorifying a comparison operator such as
+                        # torch.ops.aten.gt.Tensor
+                        if (
+                            compute_dtype != torch.bool
+                            and proxy.node.meta["val"].dtype != compute_dtype
+                        ):
+                            proxy = torch.ops.prims.convert_element_type.default(
+                                proxy, compute_dtype
+                            )
+
+                        args.append(proxy)
+                    elif isinstance(a, fx.Node):
+                        args.append(MetaProxy(a, tracer=tracer, fake_mode=fake_mode))
+                    else:
+                        args.append(a)
+
+                if transform:
+                    replacement_proxy = replacement_op(*args)
+
+                    # pyrefly: ignore [missing-attribute]
+                    if compute_dtype != node.meta["val"].dtype:
+                        replacement_proxy = (
+                            torch.ops.prims.convert_element_type.default(
+                                replacement_proxy,
+                                node.meta["val"].dtype,
+                            )
+                        )
+
+                    # pyrefly: ignore [missing-attribute]
+                    node.replace_all_uses_with(replacement_proxy.node)
+                    # pyrefly: ignore [bad-argument-type]
+                    graph.erase_node(node)
+
+                    metrics_context = get_metrics_context()
+                    if metrics_context.in_progress():
+                        metrics_context.set(
+                            "tensorify_float_success", True, overwrite=True
+                        )
+            else:
+                # pyrefly: ignore [missing-attribute]
+                for a in node.args:
+                    if (
+                        isinstance(a, fx.Node)
+                        and "val" in a.meta
+                        and isinstance(zf := a.meta["val"], torch.SymFloat)
+                    ):
+                        # pyrefly: ignore [missing-attribute]
+                        failed_tensorify_ops.update(str(node.target))
+                        # pyrefly: ignore [missing-attribute]
+                        log.info("Failed to tensorify %s", str(node.target))
+
+    # Now do one more pass that specializes all symfloats we didn't manage
+    # to tensorify away.
+    for node in reversed(graph.nodes):
+        if node.op == "output" or node.op == "placeholder":
+            continue
+
+        with graph.inserting_before(node):
+            if len(node.users) == 0 and not node.is_impure():
+                graph.erase_node(node)
+                continue
+
+            if isinstance(
+                (val := node.meta.get("val")),
+                (torch.SymFloat, torch.SymInt, torch.SymBool),
+            ):
+                if has_free_symbols(val.node.expr) and all(
+                    symbol_is_type(s, SymT.FLOAT) for s in val.node.expr.free_symbols
+                ):
+                    # If all symbols are backed symfloats, we can just specialize the whole node
+                    # and get more precise guards. eg.
+                    #
+                    # zf = a.item()
+                    # zf2 = zf // 2
+                    # op(.. zf2 ..)
+                    #
+                    # It's better to guard on zf // 2 == 2.0 than zf == 5.0
+
+                    node.replace_all_uses_with(guard_scalar(val))
+                    graph.erase_node(node)
+
+    # Sometimes by the time we get to tensorify, there have already been
+    # specializations, eg. in python_arg_parser.h. In these cases,
+    # placeholder nodes no longer have a reference to their original
+    # symfloat and thus we need to deduce specializations have happened
+    # via shape_env.replacements. NB: there's an important invariant here
+    # that symfloats keep consistent names across restarts.
+    for k, v in shape_env.var_to_val.items():
+        if symbol_is_type(k, SymT.FLOAT) and isinstance(v, sympy.core.numbers.Float):
+            name = str(k)
+            if (
+                not TensorifyState.should_specialize(name)
+                and k not in tensorified_symbols
+            ):
+                TensorifyState.specialize(name)
+                should_restart = True
+
+    if should_restart:
+        # Sledgehammer time. Restart dynamo analysis, keeping track of which input sources
+        # are no longer needed and should be specialized. Restarting analysis is necessary
+        # because we need to instruct Dynamo to NOT make these as inputs.
+        metrics_context = get_metrics_context()
+        if metrics_context.in_progress():
+            metrics_context.set(
+                "tensorify_float_failure", failed_tensorify_ops, overwrite=True
+            )
+            metrics_context.set("tensorify_float_success", True, overwrite=True)
+        raise TensorifyScalarRestartAnalysis
+
+    graph_code_log.debug(
+        "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b77f6396206e37e51bbb1ff68479b55bc062fd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py
@@ -0,0 +1,59 @@
+import operator
+
+import torch
+
+
+def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
+    """
+    Annotate the type of getitem nodes, inferred from the type of sequence node.
+    If sequence node is not annotated with a type, do nothing.
+    Currently support getitem nodes from tuple, list, and NamedTuple sequence node.
+
+    This is helpful since annotations on local names within function are lost during FX transforms.
+    Adding back known type annotation for getitem nodes to improve jit scriptability.
+
+    Args:
+        graph (Graph): The graph to be annotated
+    """
+    for node in graph.nodes:
+        if node.target is operator.getitem:
+            sequence_node, index_node = node.args
+            if not sequence_node.type:
+                continue
+            # container types
+            if hasattr(sequence_node.type, "_name"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type._name == "Tuple":
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type._name == "List":
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
+            # Generic Alias Type
+            elif hasattr(sequence_node.type, "__origin__"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type.__origin__ is tuple:
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type.__origin__ is list:
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
+            # NamedTuple type
+            elif hasattr(sequence_node.type, "__annotations__"):
+                if sequence_node.type == torch.Tensor:
+                    continue
+                sequence_node_field_types = sequence_node.type.__annotations__
+                field_name = sequence_node.type._fields[index_node]
+                node.type = sequence_node_field_types[field_name]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f69f3d9aa1cf60c874f3c93abc6d3a8039693e99
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc25df3897457259f625f21f098fa2df359eb980
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..97496fbc9b2a2439b687bc09c58bb4031b8fc670
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py
@@ -0,0 +1,61 @@
+# mypy: allow-untyped-defs
+import operator
+
+import torch
+from torch.fx.passes.fake_tensor_prop import FakeTensorProp
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.operator_support import OperatorSupport
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
+from torch.utils import _pytree as pytree
+
+
+class CudaGraphsSupport(OperatorSupport):
+    # TODO: why is submodules passed here
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        if node.op not in CALLABLE_NODE_OPS:
+            return False
+
+        if node.target is torch.ops.aten.embedding_dense_backward.default:
+            return False
+
+        if node.target is operator.getitem:
+            return True
+
+        found_not_cuda = False
+
+        def meta_fk(meta):
+            return meta["val"] if "val" in meta else meta["fake_result"]
+
+        def find_not_cuda(t):
+            nonlocal found_not_cuda
+            if isinstance(t, torch.Tensor) and t.device.type != "cuda":
+                found_not_cuda = True
+
+        for n in node.all_input_nodes:
+            pytree.tree_map_(find_not_cuda, meta_fk(n.meta))
+
+        pytree.tree_map_(find_not_cuda, meta_fk(node.meta))
+
+        # NB: factory function is accounted for because the result would be
+        # cpu or cuda
+
+        return not found_not_cuda
+
+
+def partition_cudagraphs(gm, inputs):
+    """
+    Partition an FX graph into sub-GraphModules that can be validly run under
+    CUDA graphs.  For a subgraph to be runnable under CUDA, all of the operations
+    must involve CUDA tensors only/
+    """
+
+    FakeTensorProp(gm).propagate(*inputs)
+    supported_ops = CudaGraphsSupport()
+    # TODO: single node partition may be wrong due to the pessimization
+    # from copying in and out the data.  Check in benchmarks, perhaps
+    partitioner = CapabilityBasedPartitioner(
+        gm, supported_ops, allows_single_node_partition=True
+    )
+    partitions = partitioner.propose_partitions()
+    fused_graph = partitioner.fuse_partitions(partitions)
+    return fused_graph
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc246935af3c978f1fb1992bbbf03b87ff072517
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48344c3f5348a1d6d3581f8da3272852c911c971
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..753af3e0c53342182d754ba9e36b977590e92273
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5889375bb07ae0f56917aff9950db67ff3f4bec
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py
@@ -0,0 +1,155 @@
+# mypy: allow-untyped-defs
+from typing import Any
+
+import torch
+from torch.fx import Graph, GraphModule, Node
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.utils._pytree import tree_flatten
+
+
+aten = torch.ops.aten
+
+
+# stateful ops are banned from CSE
+rand_ops = {
+    aten.dropout,
+    aten._fused_dropout,
+    aten._standard_gamma,
+    aten.bernoulli,
+    aten.multinomial,
+    aten.native_dropout,
+    aten.normal,
+    aten.poisson,
+    aten.binomial,
+    aten.rrelu,
+    aten.rand_like,
+    aten.rand,
+    aten.randint,
+    aten.randn,
+    aten.randperm,
+}  # noqa: E501,B950
+
+inplace_ops = {
+    aten.add_,
+    aten.sub_,
+    aten.mul_,
+    aten.div_,
+    aten.pow_,
+    aten.lerp_,
+    aten.relu_,
+    aten.sigmoid_,
+    aten.tanh_,
+}  # noqa: E501
+
+
+@torch.fx._compatibility.compatibility(is_backward_compatible=False)
+def get_CSE_banned_ops():
+    return rand_ops.union(inplace_ops)
+
+
+@torch.fx._compatibility.compatibility(is_backward_compatible=False)
+class CSEPass(PassBase):
+    def __init__(self, banned_ops=None):
+        """
+        This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node.
+
+        For functional dialects, user would only need to specify the random ops in ban list.
+
+        Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects.
+        If your dialect contains stateful operators, please customized the banned_ops.
+
+        """
+        if banned_ops is None:
+            banned_ops = set()
+        self.banned_ops = banned_ops
+        super().__init__()
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        """
+        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph
+
+        Example usage:
+
+        from torch.fx.experimental.proxy_tensor import make_fx
+        def f(a):
+            b = a * a
+            c = a * a
+            return b+c
+
+        p = CSEPass()
+        traced_graph = make_fx(f)(torch.tensor(1))
+        print(traced_graph)
+        result = p(traced_graph)
+        print(result.graph_module)
+        """
+
+        def get_aten_target(node):
+            if hasattr(node.target, "overloadpacket"):
+                return node.target.overloadpacket
+            return node.target
+
+        modified = False
+        new_graph = Graph()
+        env: dict[
+            Node, Node
+        ] = {}  # map from node in the old graph to node in the new graph
+        hash_env: dict[
+            tuple[torch._ops.OpOverload, int], Node
+        ] = {}  # map from hash to a node in the new graph
+        token_map: dict[
+            tuple[torch._ops.OpOverload, int], dict[str, Any]
+        ] = {}  # map from hash to token
+        for n in graph_module.graph.nodes:
+            # The placeholder, output, and get_attr nodes are copied to the new graph without change
+            # do not CSE away random operations
+            if (
+                n.op == "placeholder"
+                or n.op == "output"
+                or n.op == "get_attr"
+                or get_aten_target(n) in self.banned_ops
+            ):
+                new_node = new_graph.node_copy(n, lambda x: env[x])
+                env[n] = new_node
+            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
+                # substitute args and kwargs members to their mapping in env if exists
+                # specs can be used to reconstruct nested list/dictionaries
+                def substitute(arg_list):
+                    arg_list, spec = tree_flatten(arg_list)
+                    for i in range(len(arg_list)):
+                        v = arg_list[i]
+                        if isinstance(v, Node) and v in env:
+                            arg_list[i] = env[v]
+                    return tuple(arg_list), spec
+
+                args, args_spec = substitute(n.args)
+                kwargs, kwargs_spec = substitute(n.kwargs)
+
+                # each token corresponds to a unique node
+                # nodes with the same token can be substituted
+                token = {
+                    "target": n.target,
+                    "args": args,
+                    "args_spec": args_spec,
+                    "kwargs": kwargs,
+                    "kwargs_spec": kwargs_spec,
+                }
+
+                # hash substituted args to a number, do not hash specs because specs are not hashable
+                hash_arg = hash((args, kwargs))
+                hash_val = (n.target, hash_arg)
+
+                # check if a node has a substitute and can be eliminated
+                hash_val_in_hash_env = hash_val in hash_env
+                if hash_val_in_hash_env and token_map[hash_val] == token:
+                    modified = True  # substitution happens and the graph is modified
+                    env[n] = hash_env[hash_val]
+                    continue
+
+                new_node = new_graph.node_copy(n, lambda x: env[x])
+                env[n] = new_node
+                if not hash_val_in_hash_env:
+                    hash_env[hash_val] = new_node
+                    token_map[hash_val] = token
+
+        csed_gm = GraphModule(graph_module, new_graph)
+        return PassResult(csed_gm, modified)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/fake_tensor_prop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/fake_tensor_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..43dbe86c7370f66aa30b5fbc5853d5a0d12cd8ad
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/fake_tensor_prop.py
@@ -0,0 +1,109 @@
+# mypy: allow-untyped-defs
+from typing import Optional
+
+import torch.fx
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx import Node
+from torch.fx._compatibility import compatibility
+from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
+from torch.fx.node import map_aggregate
+from torch.utils._ordered_set import OrderedSet
+
+
+__all__ = ["FakeTensorProp"]
+
+
+@compatibility(is_backward_compatible=False)
+class FakeTensorProp(torch.fx.Interpreter):
+    """
+    Execute an FX graph Node-by-Node and record a fake tensor representing
+    the metadata for the node.  Unlike ShapeProp, (1) this propagation
+    is cheap--it does the propagation with meta tensors which do not actually
+    store data, and (2) the fake tensors have much more fine grained information,
+    e.g., they have accurate alias information that can be consulted by looking
+    at the storages.
+
+    Args:
+         module (GraphModule): The module to be executed
+         mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
+    """
+
+    def __init__(
+        self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None
+    ):
+        super().__init__(module)
+        if mode is None:
+            mode = FakeTensorMode()
+        self._mode = mode
+        mode.epoch += 1
+        mode.reset_nt_tensor_id_counter()
+        self.seen_subgraphs: OrderedSet[str] = OrderedSet()
+
+    def run_node(self, n: Node):
+        from torch.fx.experimental.symbolic_shapes import (
+            compute_unbacked_bindings,
+            rebind_unbacked,
+        )
+
+        if (
+            n.op == "call_function"
+            and n.target is torch.ops.higher_order.invoke_subgraph
+            and n.args[1] not in self.seen_subgraphs
+        ):
+            # Prevent redundant fake tensor prop for invoke_subgraphs. Note that
+            # there is also fake tensor caching for the entire subgraph. This
+            # happens the next time we call `run_node` for the same subgraph,
+            # which goes through super.run_node and caches the fake tensor prop.
+            # Therefore, we are propagating fake tensor through the subgraphs
+            # twice.
+            assert isinstance(n.args[1], str)
+            assert (
+                isinstance(n.args[0], torch.fx.Node)
+                and n.args[0].op == "get_attr"
+                and isinstance(n.args[0].target, str)
+            )
+            self.seen_subgraphs.add(n.args[1])
+            operands = n.args[2:]
+            example_inputs = []
+            for operand in operands:
+                assert isinstance(operand, torch.fx.Node) and "val" in operand.meta
+                example_inputs.append(operand.meta["val"])
+            return FakeTensorProp(
+                getattr(self.module, n.args[0].target), mode=self._mode
+            ).propagate(*example_inputs)
+
+        result = super().run_node(n)
+        rebind_unbacked(self._mode.shape_env, n, result)
+
+        def extract_val(obj):
+            if isinstance(obj, FakeTensor):
+                return snapshot_fake(obj)
+            elif isinstance(obj, torch.Tensor):
+                # TODO: How is it possible that we get a non fake tensor?  We
+                # should be running under the mode...
+                return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
+            elif isinstance(obj, py_sym_types):
+                return obj
+            else:
+                return None
+
+        meta = map_aggregate(result, extract_val)
+        if meta is not None:
+            n.meta["val"] = meta
+            if (shape_env := self._mode.shape_env) and (
+                symbol_to_path := compute_unbacked_bindings(shape_env, result)
+            ):
+                n.meta["unbacked_bindings"] = symbol_to_path
+
+        return result
+
+    def propagate(self, *args):
+        fake_args = [
+            self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
+            for a in args
+        ]
+        return self.propagate_dont_convert_inputs(*fake_args)
+
+    def propagate_dont_convert_inputs(self, *args):
+        with self._mode:
+            return super().run(*args)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ce645df8fa92e03e912da7d66f9b8622edeec7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py
@@ -0,0 +1,504 @@
+# mypy: allow-untyped-defs
+
+import hashlib
+from itertools import chain
+from types import ModuleType
+from typing import Any, Optional, TYPE_CHECKING
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import _parse_stack_trace
+from torch.fx.node import _format_arg, _get_qualified_name
+from torch.fx.operator_schemas import normalize_function
+from torch.fx.passes.shape_prop import TensorMetadata
+
+
+if TYPE_CHECKING:
+    import pydot
+
+    HAS_PYDOT = True
+else:
+    pydot: Optional[ModuleType]
+    try:
+        import pydot
+
+        HAS_PYDOT = True
+    except ModuleNotFoundError:
+        HAS_PYDOT = False
+        pydot = None
+
+
+__all__ = ["FxGraphDrawer"]
+
+_COLOR_MAP = {
+    "placeholder": '"AliceBlue"',
+    "call_module": "LemonChiffon1",
+    "get_param": "Yellow2",
+    "get_attr": "LightGrey",
+    "output": "PowderBlue",
+}
+
+_HASH_COLOR_MAP = [
+    "CadetBlue1",
+    "Coral",
+    "DarkOliveGreen1",
+    "DarkSeaGreen1",
+    "GhostWhite",
+    "Khaki1",
+    "LavenderBlush1",
+    "LightSkyBlue",
+    "MistyRose1",
+    "MistyRose2",
+    "PaleTurquoise2",
+    "PeachPuff1",
+    "Salmon",
+    "Thistle1",
+    "Thistle3",
+    "Wheat1",
+]
+
+_WEIGHT_TEMPLATE = {
+    "fillcolor": "Salmon",
+    "style": '"filled,rounded"',
+    "fontcolor": "#000000",
+}
+
+if HAS_PYDOT:
+
+    @compatibility(is_backward_compatible=False)
+    class FxGraphDrawer:
+        """
+        Visualize a torch.fx.Graph with graphviz
+        Basic usage:
+            g = FxGraphDrawer(symbolic_traced, "resnet18")
+            g.get_dot_graph().write_svg("a.svg")
+        """
+
+        def __init__(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool = False,
+            ignore_parameters_and_buffers: bool = False,
+            skip_node_names_in_args: bool = True,
+            parse_stack_trace: bool = False,
+            dot_graph_shape: Optional[str] = None,
+            normalize_args: bool = False,
+        ):
+            self._name = name
+            self.dot_graph_shape = (
+                dot_graph_shape if dot_graph_shape is not None else "record"
+            )
+            self.normalize_args = normalize_args
+            _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape
+
+            self._dot_graphs = {
+                name: self._to_dot(
+                    graph_module,
+                    name,
+                    ignore_getattr,
+                    ignore_parameters_and_buffers,
+                    skip_node_names_in_args,
+                    parse_stack_trace,
+                )
+            }
+
+            for node in graph_module.graph.nodes:
+                if node.op != "call_module":
+                    continue
+
+                leaf_node = self._get_leaf_node(graph_module, node)
+
+                if not isinstance(leaf_node, torch.fx.GraphModule):
+                    continue
+
+                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
+                    leaf_node,
+                    f"{name}_{node.target}",
+                    ignore_getattr,
+                    ignore_parameters_and_buffers,
+                    skip_node_names_in_args,
+                    parse_stack_trace,
+                )
+
+        def get_dot_graph(self, submod_name=None) -> pydot.Dot:
+            """
+            Visualize a torch.fx.Graph with graphviz
+            Example:
+                >>> # xdoctest: +REQUIRES(module:pydot)
+                >>> # xdoctest: +REQUIRES(module:ubelt)
+                >>> # define module
+                >>> class MyModule(torch.nn.Module):
+                >>>     def __init__(self) -> None:
+                >>>         super().__init__()
+                >>>         self.linear = torch.nn.Linear(4, 5)
+                >>>     def forward(self, x):
+                >>>         return self.linear(x).clamp(min=0.0, max=1.0)
+                >>> module = MyModule()
+                >>> # trace the module
+                >>> symbolic_traced = torch.fx.symbolic_trace(module)
+                >>> # setup output file
+                >>> import ubelt as ub
+                >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
+                >>> fpath = dpath / "linear.svg"
+                >>> # draw the graph
+                >>> g = FxGraphDrawer(symbolic_traced, "linear")
+                >>> g.get_dot_graph().write_svg(fpath)
+            """
+            if submod_name is None:
+                return self.get_main_dot_graph()
+            else:
+                return self.get_submod_dot_graph(submod_name)
+
+        def get_main_dot_graph(self) -> pydot.Dot:
+            return self._dot_graphs[self._name]
+
+        def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
+            return self._dot_graphs[f"{self._name}_{submod_name}"]
+
+        def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
+            return self._dot_graphs
+
+        def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
+            template = {
+                "shape": self.dot_graph_shape,
+                "fillcolor": "#CAFFE3",
+                "style": '"filled,rounded"',
+                "fontcolor": "#000000",
+            }
+            if node.op in _COLOR_MAP:
+                template["fillcolor"] = _COLOR_MAP[node.op]
+            else:
+                # Use a random color for each node; based on its name so it's stable.
+                target_name = node._pretty_print_target(node.target)
+                target_hash = int(
+                    hashlib.md5(
+                        target_name.encode(), usedforsecurity=False
+                    ).hexdigest()[:8],
+                    16,
+                )
+                template["fillcolor"] = _HASH_COLOR_MAP[
+                    target_hash % len(_HASH_COLOR_MAP)
+                ]
+            return template
+
+        def _get_leaf_node(
+            self, module: torch.nn.Module, node: torch.fx.Node
+        ) -> torch.nn.Module:
+            py_obj = module
+            assert isinstance(node.target, str)
+            atoms = node.target.split(".")
+            for atom in atoms:
+                if not hasattr(py_obj, atom):
+                    raise RuntimeError(
+                        str(py_obj) + " does not have attribute " + atom + "!"
+                    )
+                py_obj = getattr(py_obj, atom)
+            return py_obj
+
+        def _typename(self, target: Any) -> str:
+            if isinstance(target, torch.nn.Module):
+                ret = torch.typename(target)
+            elif isinstance(target, str):
+                ret = target
+            else:
+                ret = _get_qualified_name(target)
+
+            # Escape "{" and "}" to prevent dot files like:
+            # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
+            # which triggers `Error: bad label format (...)` from dot
+            return ret.replace("{", r"\{").replace("}", r"\}")
+
+        # shorten path to avoid drawing long boxes
+        # for full path = '/home/weif/pytorch/test.py'
+        # return short path = 'pytorch/test.py'
+        def _shorten_file_name(
+            self,
+            full_file_name: str,
+            truncate_to_last_n: int = 2,
+        ):
+            splits = full_file_name.split("/")
+            if len(splits) >= truncate_to_last_n:
+                return "/".join(splits[-truncate_to_last_n:])
+            return full_file_name
+
+        def _get_node_label(
+            self,
+            module: torch.fx.GraphModule,
+            node: torch.fx.Node,
+            skip_node_names_in_args: bool,
+            parse_stack_trace: bool,
+        ) -> str:
+            def _get_str_for_args_kwargs(arg):
+                if isinstance(arg, tuple):
+                    prefix, suffix = r"|args=(\l", r",\n)\l"
+                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
+                elif isinstance(arg, dict):
+                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
+                    arg_strs_list = [
+                        f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items()
+                    ]
+                else:  # Fall back to nothing in unexpected case.
+                    return ""
+
+                # Strip out node names if requested.
+                if skip_node_names_in_args:
+                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
+                if len(arg_strs_list) == 0:
+                    return ""
+                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
+                if len(arg_strs_list) == 1:
+                    arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
+                return arg_strs.replace("{", r"\{").replace("}", r"\}")
+
+            label = "{" + f"name=%{node.name}|op_code={node.op}\n"
+
+            if node.op == "call_module":
+                leaf_module = self._get_leaf_node(module, node)
+                label += r"\n" + self._typename(leaf_module) + r"\n|"
+                extra = ""
+                if hasattr(leaf_module, "__constants__"):
+                    extra = r"\n".join(
+                        [
+                            f"{c}: {getattr(leaf_module, c)}"
+                            for c in leaf_module.__constants__  # type: ignore[union-attr]
+                        ]  # type: ignore[union-attr]
+                    )
+                label += extra + r"\n"
+            else:
+                label += f"|target={self._typename(node.target)}" + r"\n"
+                if self.normalize_args:
+                    try:
+                        args, kwargs = normalize_function(  # type: ignore[misc]
+                            node.target,  # type: ignore[arg-type]
+                            node.args,  # type: ignore[arg-type]
+                            node.kwargs,
+                            normalize_to_only_use_kwargs=True,
+                        )
+                    except Exception:
+                        # Fallback to not normalizing if there's an exception.
+                        # Some functions need overloads specified to normalize.
+                        args, kwargs = node.args, node.kwargs
+                else:
+                    args, kwargs = node.args, node.kwargs
+                if len(args) > 0:
+                    label += _get_str_for_args_kwargs(args)
+                if len(kwargs) > 0:
+                    label += _get_str_for_args_kwargs(kwargs)
+                label += f"|num_users={len(node.users)}" + r"\n"
+
+            tensor_meta = node.meta.get("tensor_meta")
+            label += self._tensor_meta_to_label(tensor_meta)
+
+            # for original fx graph
+            # print buf=buf0, n_origin=6
+            buf_meta = node.meta.get("buf_meta", None)
+            if buf_meta is not None:
+                label += f"|buf={buf_meta.name}" + r"\n"
+                label += f"|n_origin={buf_meta.n_origin}" + r"\n"
+
+            # for original fx graph
+            # print file:lineno code
+            if parse_stack_trace and node.stack_trace is not None:
+                parsed_stack_trace = _parse_stack_trace(node.stack_trace)
+                fname = self._shorten_file_name(parsed_stack_trace.file)
+                label += (
+                    f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
+                    + r"\n"
+                )
+
+            return label + "}"
+
+        def _tensor_meta_to_label(self, tm) -> str:
+            if tm is None:
+                return ""
+            elif isinstance(tm, TensorMetadata):
+                return self._stringify_tensor_meta(tm)
+            elif isinstance(tm, list):
+                result = ""
+                for item in tm:
+                    result += self._tensor_meta_to_label(item)
+                return result
+            elif isinstance(tm, dict):
+                result = ""
+                for v in tm.values():
+                    result += self._tensor_meta_to_label(v)
+                return result
+            elif isinstance(tm, tuple):
+                result = ""
+                for item in tm:
+                    result += self._tensor_meta_to_label(item)
+                return result
+            else:
+                raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
+
+        def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
+            result = ""
+            if not hasattr(tm, "dtype"):
+                print("tm", tm)
+            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
+            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
+            result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
+            result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
+            if tm.is_quantized:
+                assert tm.qparams is not None
+                assert "qscheme" in tm.qparams
+                qscheme = tm.qparams["qscheme"]
+                if qscheme in {
+                    torch.per_tensor_affine,
+                    torch.per_tensor_symmetric,
+                }:
+                    result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
+                    result += (
+                        "|"
+                        + "q_zero_point"
+                        + "="
+                        + str(tm.qparams["zero_point"])
+                        + r"\n"
+                    )
+                elif qscheme in {
+                    torch.per_channel_affine,
+                    torch.per_channel_symmetric,
+                    torch.per_channel_affine_float_qparams,
+                }:
+                    result += (
+                        "|"
+                        + "q_per_channel_scale"
+                        + "="
+                        + str(tm.qparams["scale"])
+                        + r"\n"
+                    )
+                    result += (
+                        "|"
+                        + "q_per_channel_zero_point"
+                        + "="
+                        + str(tm.qparams["zero_point"])
+                        + r"\n"
+                    )
+                    result += (
+                        "|"
+                        + "q_per_channel_axis"
+                        + "="
+                        + str(tm.qparams["axis"])
+                        + r"\n"
+                    )
+                else:
+                    raise RuntimeError(f"Unsupported qscheme: {qscheme}")
+                result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
+            return result
+
+        def _get_tensor_label(self, t: torch.Tensor) -> str:
+            return str(t.dtype) + str(list(t.shape)) + r"\n"
+
+        # when parse_stack_trace=True
+        # print file:lineno code
+        def _to_dot(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool,
+            ignore_parameters_and_buffers: bool,
+            skip_node_names_in_args: bool,
+            parse_stack_trace: bool,
+        ) -> pydot.Dot:
+            """
+            Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
+            If ignore_parameters_and_buffers is True, the parameters and buffers
+            created with the module will not be added as nodes and edges.
+            """
+
+            # "TB" means top-to-bottom rank direction in layout
+            dot_graph = pydot.Dot(name, rankdir="TB")
+
+            buf_name_to_subgraph = {}
+
+            for node in graph_module.graph.nodes:
+                if ignore_getattr and node.op == "get_attr":
+                    continue
+
+                style = self._get_node_style(node)
+                dot_node = pydot.Node(
+                    node.name,
+                    label=self._get_node_label(
+                        graph_module, node, skip_node_names_in_args, parse_stack_trace
+                    ),
+                    **style,  # type: ignore[arg-type]
+                )
+
+                current_graph = dot_graph
+
+                buf_meta = node.meta.get("buf_meta", None)
+                if buf_meta is not None and buf_meta.n_origin > 1:
+                    buf_name = buf_meta.name
+                    if buf_name not in buf_name_to_subgraph:
+                        buf_name_to_subgraph[buf_name] = pydot.Cluster(
+                            buf_name, label=buf_name
+                        )
+                    current_graph = buf_name_to_subgraph.get(buf_name)  # type: ignore[assignment]
+
+                # pyrefly: ignore [missing-attribute]
+                current_graph.add_node(dot_node)
+
+                def get_module_params_or_buffers():
+                    for pname, ptensor in chain(
+                        leaf_module.named_parameters(),
+                        # pyrefly: ignore [bad-argument-type]
+                        leaf_module.named_buffers(),
+                    ):
+                        pname1 = node.name + "." + pname
+                        label1 = (
+                            pname1 + "|op_code=get_" + "parameter"
+                            if isinstance(ptensor, torch.nn.Parameter)
+                            else "buffer" + r"\l"
+                        )
+                        dot_w_node = pydot.Node(
+                            pname1,
+                            label="{" + label1 + self._get_tensor_label(ptensor) + "}",
+                            **_WEIGHT_TEMPLATE,  # type: ignore[arg-type]
+                        )
+                        dot_graph.add_node(dot_w_node)
+                        dot_graph.add_edge(pydot.Edge(pname1, node.name))
+
+                if node.op == "call_module":
+                    leaf_module = self._get_leaf_node(graph_module, node)
+
+                    if not ignore_parameters_and_buffers and not isinstance(
+                        leaf_module, torch.fx.GraphModule
+                    ):
+                        get_module_params_or_buffers()
+
+            for subgraph in buf_name_to_subgraph.values():
+                subgraph.set("color", "royalblue")
+                subgraph.set("penwidth", "2")
+                dot_graph.add_subgraph(subgraph)  # type: ignore[arg-type]
+
+            for node in graph_module.graph.nodes:
+                if ignore_getattr and node.op == "get_attr":
+                    continue
+
+                for user in node.users:
+                    dot_graph.add_edge(pydot.Edge(node.name, user.name))
+
+            return dot_graph
+
+else:
+    if not TYPE_CHECKING:
+
+        @compatibility(is_backward_compatible=False)
+        class FxGraphDrawer:
+            def __init__(
+                self,
+                graph_module: torch.fx.GraphModule,
+                name: str,
+                ignore_getattr: bool = False,
+                ignore_parameters_and_buffers: bool = False,
+                skip_node_names_in_args: bool = True,
+                parse_stack_trace: bool = False,
+                dot_graph_shape: Optional[str] = None,
+                normalize_args: bool = False,
+            ):
+                raise RuntimeError(
+                    "FXGraphDrawer requires the pydot package to be installed. Please install "
+                    "pydot through your favorite Python package manager."
+                )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_manipulation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_manipulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6026e9ca25c05cfb4bdc941d5beb638175d00fc6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_manipulation.py
@@ -0,0 +1,113 @@
+# mypy: allow-untyped-defs
+from typing import Any, NamedTuple, Optional
+
+import torch
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import Graph
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import map_arg, Node, Target
+from torch.fx.passes.shape_prop import ShapeProp
+
+
+__all__ = [
+    "replace_target_nodes_with",
+    "size_bytes",
+    "get_size_of_all_nodes",
+    "get_tensor_meta",
+    "get_size_of_node",
+]
+
+
+@compatibility(is_backward_compatible=False)
+def replace_target_nodes_with(
+    fx_module: GraphModule,
+    old_op: str,
+    old_target: Target,
+    new_op: str,
+    new_target: Target,
+):
+    """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
+    and updates them to match the new op code and target"""
+    new_graph = Graph()
+    val_map: dict[Node, Node] = {}
+    for node in fx_module.graph.nodes:
+        if node.op == old_op and node.target == old_target:
+            args = map_arg(node.args, lambda n: val_map[n])
+            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
+            assert isinstance(args, tuple)
+            assert isinstance(kwargs, dict)
+            val_map[node] = new_graph.create_node(
+                new_op, new_target, args, kwargs, node.name
+            )
+        else:
+            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
+    fx_module.graph = new_graph
+
+
+@compatibility(is_backward_compatible=False)
+class size_bytes(NamedTuple):
+    output_size: int
+    total_size: int
+
+
+@compatibility(is_backward_compatible=False)
+def get_size_of_all_nodes(
+    fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
+) -> None:
+    """Given a fx graph module, update each node with its total size (weights + bias + output)
+    and its output_size(output). For a non-module node, the total size is the output size.
+    return total size"""
+    if args is not None:
+        # Mark shape and dtype for each node (node.shape and node.dtype)
+        ShapeProp(fx_module).propagate(*args)
+    # Calculate the total size of the whole fx graph
+    for node in fx_module.graph.nodes:
+        if node.op == "output":
+            break
+        node.size_bytes = get_size_of_node(fx_module, node)
+    return
+
+
+@compatibility(is_backward_compatible=False)
+def get_tensor_meta(node: Node) -> Any:
+    tensor_meta = node.meta.get("tensor_meta")
+
+    if not tensor_meta:
+        raise RuntimeError(
+            f"Node {node} has no tensor metadata associated with it! "
+            f"Check that shape propagation has run."
+        )
+
+    return tensor_meta
+
+
+@compatibility(is_backward_compatible=False)
+def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
+    """Given a node with node.dtype and node.shape, return its total size and its output size.
+    total_size = weights + bias + output_size
+    """
+    # Total num of elements
+    total_num_of_elems = 0
+    # For a module, consider all parameters
+    if node.op == "call_module":
+        submodule_dict = dict(fx_module.named_modules())
+        submodule = submodule_dict[node.target]
+        parameters = submodule.named_parameters()
+        # Parameters are named tuples
+        for _name, p in parameters:
+            total_num_of_elems += p.numel()
+    # Don't forget the output size
+    # node.shape is the shape of this node's output
+    tensor_meta = get_tensor_meta(node)
+    output_elem = tensor_meta.shape.numel()
+    total_num_of_elems += output_elem
+    # Assume for now if it's quantized then it's qint8 or quint8
+    if tensor_meta.is_quantized:
+        size_per_elem_bytes = torch._empty_affine_quantized(
+            [], dtype=tensor_meta.dtype
+        ).element_size()
+    else:
+        size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
+    total_size = size_per_elem_bytes * total_num_of_elems
+    output_size = size_per_elem_bytes * output_elem
+    return size_bytes(output_size, total_size)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e762b8a60d10cf9dea401a501b9cd3840411ed17
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py
@@ -0,0 +1,229 @@
+# mypy: allow-untyped-defs
+import os
+from collections.abc import Callable
+from typing import Optional, TypeVar
+
+from torch.fx import Graph, Node
+from torch.fx._compatibility import compatibility
+from torch.fx.graph_module import GraphModule
+from torch.fx.traceback import NodeSource, NodeSourceAction
+
+
+T = TypeVar("T")
+
+
+from .graph_drawer import FxGraphDrawer
+
+
+__all__ = ["GraphTransformObserver"]
+
+
+@compatibility(is_backward_compatible=False)
+class GraphTransformObserver:
+    __pass_count = 0
+
+    def __init__(
+        self,
+        gm: GraphModule,
+        passname: str,
+        subsystem: Optional[str] = None,
+        log_url: Optional[str] = None,
+    ):
+        """
+        log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
+        """
+        from torch._inductor import config as inductor_config
+
+        self.gm = gm
+        self.passname = passname
+        self.subsystem = subsystem
+
+        if log_url is None:
+            log_url = inductor_config.trace.log_url_for_graph_xform
+
+        self.log_url = log_url
+
+        self.active = (
+            self.log_url is not None
+            or inductor_config.trace.provenance_tracking_level == 1
+        )
+
+        if self.active:
+            self.erased_nodes: set[str] = set()
+            self.created_nodes: set[str] = set()
+            self.name_to_node: dict[str, Node] = {}
+            # record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context
+            self.copied_gms: list[GraphModule] = []
+
+            self._node_creation_hook = self.get_node_creation_hook()
+            self._node_erase_hook = self.get_node_erase_hook()
+            self._node_replace_hook = self.get_node_replace_hook()
+            self._deepcopy_hook = self.get_deepcopy_hook()
+
+        # If log_url is None, we don't log anything
+        if self.log_url is None:
+            return
+        GraphTransformObserver.__pass_count += 1
+
+        self.input_dot_graph = FxGraphDrawer(
+            self.gm,
+            self.passname,
+            ignore_getattr=True,
+            ignore_parameters_and_buffers=True,
+        ).get_dot_graph()
+
+    @classmethod
+    def get_current_pass_count(cls):
+        return cls.__pass_count
+
+    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm)
+
+        return None
+
+    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm.graph)
+
+        return None
+
+    def _check_disable_pass(self):
+        if self.subsystem is None:
+            return False
+
+        debug_info = lambda: self.passname  # noqa: E731
+        from torch._inductor.compiler_bisector import CompilerBisector
+
+        return CompilerBisector.disable_subsystem(
+            "inductor", self.subsystem, debug_info
+        )
+
+    def __enter__(self):
+        if not self.active:
+            return self
+        self.gm._register_create_node_hook(self._node_creation_hook)
+        self.gm._register_erase_node_hook(self._node_erase_hook)
+        self.gm._register_replace_node_hook(self._node_replace_hook)
+        self.gm._register_deepcopy_hook(self._deepcopy_hook)
+
+        self.erased_nodes.clear()
+        self.created_nodes.clear()
+        self.name_to_node.clear()
+        self.copied_gms.clear()
+
+        for node in self.gm.graph.nodes:
+            self.name_to_node[node.name] = node
+
+        return self
+
+    def __exit__(self, type, value, tb):
+        if not self.active:
+            return
+        for gm in self.copied_gms + [self.gm]:
+            gm._unregister_create_node_hook(self._node_creation_hook)
+            gm._unregister_erase_node_hook(self._node_erase_hook)
+            gm._unregister_replace_node_hook(self._node_replace_hook)
+            gm._unregister_deepcopy_hook(self._deepcopy_hook)
+
+        if self.log_url is None:
+            return
+
+        if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0:
+            for e in self.input_dot_graph.get_node_list():
+                if e.get_name() in self.erased_nodes:
+                    e.obj_dict["attributes"]["fillcolor"] = "yellow"
+                else:
+                    e.obj_dict["attributes"]["fillcolor"] = "grey"
+            assert self.log_url is not None
+            self.input_dot_graph.write(
+                os.path.join(
+                    self.log_url,
+                    f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.dot",
+                )
+            )
+
+            output_dot_graph = FxGraphDrawer(
+                self.gm,
+                self.passname,
+                ignore_getattr=True,
+                ignore_parameters_and_buffers=True,
+            ).get_dot_graph()
+            for e in output_dot_graph.get_node_list():
+                if e.get_name() in self.created_nodes:
+                    e.obj_dict["attributes"]["fillcolor"] = "yellow"
+                else:
+                    e.obj_dict["attributes"]["fillcolor"] = "grey"
+            output_dot_graph.write(
+                os.path.join(
+                    self.log_url,
+                    f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.dot",
+                )
+            )
+
+    def get_node_creation_hook(self):
+        # We have to return a function instead of using a class method directly
+        # to avoid max recursion issue when deepcopy a graph module within the context manager.
+        def on_node_creation(node):
+            self.created_nodes.add(node.name)
+            self.name_to_node[node.name] = node
+            source = NodeSource(None, self.passname, NodeSourceAction.CREATE)
+            if "from_node" not in node.meta:
+                node.meta["from_node"] = [source]
+            else:
+                node.meta["from_node"].append(source)
+
+        return on_node_creation
+
+    def get_node_erase_hook(self):
+        def on_node_erase(node):
+            self.erased_nodes.add(node.name)
+            self.name_to_node.pop(node.name, None)
+
+        return on_node_erase
+
+    def get_node_replace_hook(self):
+        def on_node_replace(old: Node, new: str, user: Node):
+            # Update node meta when replacing old node with new node
+            new_node = self.name_to_node.get(new, None)
+
+            if not new_node:
+                return
+
+            assert isinstance(new_node, Node)
+
+            # replace hook is called once for each user of old
+            # this avoids adding duplicated source nodes
+            added_nodes = {s.name for s in new_node.meta.get("from_node", [])}
+            if old.name in added_nodes:
+                return
+
+            action = [NodeSourceAction.REPLACE]
+            if new_node.name in self.created_nodes:
+                action.append(NodeSourceAction.CREATE)
+
+            def created_this_pass(source):
+                return source.pass_name == self.passname and source.action == [
+                    NodeSourceAction.CREATE
+                ]
+
+            # remove redundant source added on node creation
+            new_from_node = new_node.meta.get("from_node", [])
+            new_from_node = [
+                source for source in new_from_node if not created_this_pass(source)
+            ]
+
+            # add new source
+            new_node_source = NodeSource(old, self.passname, action)
+            new_from_node.append(new_node_source)
+            new_node.meta["from_node"] = new_from_node
+
+        return on_node_replace
+
+    def get_deepcopy_hook(self):
+        def on_deepcopy(gm):
+            self.copied_gms.append(gm)
+
+        return on_deepcopy
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939157f1302e75e3cf17ec3c1e93d1b8993d67a0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py
@@ -0,0 +1 @@
+from . import pass_manager
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8aaf04fe11f5c30cd2d78a80f5676eedd99e9a80
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c48af1aa97fb3404d114408eab6fc458d0172cc
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9699ef35e4454f968727e1b6aecf0e75723b228
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac617327d1c8a2654593cd8bcd6ef2cbe900a096
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bb536dbba9399d9ae5d5df53966f75f3a2a18a8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py
@@ -0,0 +1,400 @@
+# mypy: allow-untyped-defs
+import collections
+import itertools
+import logging
+import operator
+from collections.abc import Iterable, Sequence
+from typing import Optional
+
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import _get_qualified_name, Node
+from torch.fx.passes.operator_support import OperatorSupportBase
+from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+class Partition:
+    def __init__(
+        self,
+        id: Optional[int] = None,
+        nodes: Optional[Iterable[Node]] = None,
+        node_orders: Optional[Iterable[int]] = None,
+    ):
+        self.id = id
+        self.nodes: dict[Node, Optional[int]] = {}
+        if nodes is not None:
+            if node_orders is None:
+                self.nodes = dict.fromkeys(nodes, None)
+            else:
+                nodes_list = list(nodes)
+                node_orders_list = list(node_orders)
+                assert len(nodes_list) == len(node_orders_list), (
+                    "nodes and node_orders must have the same length"
+                )
+                self.nodes = dict(zip(nodes_list, node_orders_list))
+
+    def __repr__(self) -> str:
+        return str(self.nodes)
+
+    def add_node(self, node: Node, node_order: Optional[int] = None):
+        self.nodes.update({node: node_order})
+
+    def remove_node(self, node: Node):
+        del self.nodes[node]
+
+    def size(self):
+        return len(self.nodes)
+
+
+class _DependencyViewer:
+    def __init__(self, graph_module: GraphModule):
+        self.downstreams = collections.defaultdict(set)
+
+        for node in reversed(graph_module.graph.nodes):
+            for output_node in node.users:
+                # add output_node and output_node's downstream dependency
+                self.downstreams[node].add(output_node)
+                self.downstreams[node].update(self.downstreams[output_node])
+
+    def downstreams_of(self, node: Node) -> set[Node]:
+        return self.downstreams[node]
+
+
+class CapabilityBasedPartitioner:
+    def __init__(
+        self,
+        graph_module: GraphModule,
+        operator_support: OperatorSupportBase,
+        allows_single_node_partition: bool = False,
+        non_compute_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+    ) -> None:
+        self.graph_module = graph_module
+        self.operator_support = operator_support
+        self.allows_single_node_partition = allows_single_node_partition
+        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
+        self.allowed_single_node_partition_ops = (
+            allowed_single_node_partition_ops
+            if allowed_single_node_partition_ops is not None
+            else []
+        )
+        self.dependency_viewer = _DependencyViewer(graph_module)
+
+    def _is_node_supported(self, node: Node) -> bool:
+        return self.operator_support.is_node_supported(
+            dict(self.graph_module.named_modules()), node
+        )
+
+    def propose_partitions(self) -> list[Partition]:
+        # partition_map is a mapping from partition id to a set of partition id's.
+        # The value set contains all the partition ids that can be reached by doing a
+        # DFS starting from the partition id in the key.
+        partition_map: dict[int, set] = collections.defaultdict(set)
+
+        # assumptions: nodes in candidate list is sorted in topological order
+        assignment: dict[Node, int] = {}  # mapping from node to partition_id
+        partitions_by_id: dict[
+            int, Partition
+        ] = {}  # mapping from partition_id to partition
+        nodes_order: dict[
+            Node, int
+        ] = {}  # mapping from nodes to reversed topological order
+        partitions_order: dict[
+            int, int
+        ] = {}  # mapping from partition_id to minimum topo order of nodes in partition
+        partition_users: dict[
+            int, set
+        ] = {}  # mapping from partition_id to partition users
+        new_partition_id = itertools.count()
+
+        # try to merge partition other_id into partition self_id
+        # merge only happens if the end graph doesn't contain cyclic dependency
+        # returns `True` when merge happens, `False` otherwise.
+        def maybe_merge_partition(self_id: int, other_id: int):
+            # merged_nodes is the union of nodes in two partition to-be-merged
+            self_nodes = partitions_by_id[self_id].nodes
+            other_nodes = partitions_by_id[other_id].nodes
+
+            def dfs_iter_find_cycle(all_user_nodes: set[Node]):
+                for user_node in all_user_nodes:
+                    visited_partition_ids = set()
+
+                    for path_node in self.dependency_viewer.downstreams_of(user_node):
+                        # If any of the nodes in the dfs path of this node are in the merged_nodes
+                        # list then there is a cycle in the graph.
+                        if path_node in self_nodes or path_node in other_nodes:
+                            return True
+
+                        # If any of the nodes in the dfs path of this node are in the assignment
+                        # map then we have to make sure that the partitions that these nodes belong
+                        # to do not form a cycle with the current partitions being merged. This means
+                        # iterating through all the nodes in all the parititons that are traversed in
+                        # the dfs path and checking if they are in the merged_nodes list.
+                        if path_node in assignment:
+                            partition_id = assignment[path_node]
+                            # If the partition id has already been visited then we know that it doesn't
+                            # form a cycle with the current partitions being merged.
+                            if partition_id in visited_partition_ids:
+                                continue
+                            p_map = partition_map[partition_id]
+                            if self_id in p_map or other_id in p_map:
+                                return True
+
+                            visited_partition_ids.add(partition_id)
+
+                return False
+
+            # find new partition users if merge.
+            all_user_nodes = partition_users[self_id] | partition_users[other_id]
+            all_user_nodes.difference_update(other_nodes, self_nodes)
+
+            # check if merge would create cyclic dependency.
+            if dfs_iter_find_cycle(all_user_nodes):
+                # return false indicating cyclic dependency found and
+                # merge is aborted
+                return self_id, False
+
+            # merge the smaller partition into the larger.
+            merge_id, removed_id = self_id, other_id
+            if len(self_nodes) < len(other_nodes):
+                merge_id, removed_id = removed_id, merge_id
+            # no cyclic dependency found, move forward with the merge
+            # updating partition nodes
+            partitions_by_id[merge_id].nodes.update(partitions_by_id[removed_id].nodes)
+            # updating assignment map
+            for node in partitions_by_id[removed_id].nodes:
+                assignment[node] = merge_id
+            # delete other partition
+            del partitions_by_id[removed_id]
+
+            partitions_order[merge_id] = min(
+                partitions_order[merge_id], partitions_order[removed_id]
+            )
+            del partitions_order[removed_id]
+
+            partition_map[merge_id] = partition_map[merge_id].union(
+                partition_map[removed_id]
+            )
+            del partition_map[removed_id]
+
+            partition_users[merge_id] = all_user_nodes
+            del partition_users[removed_id]
+
+            return merge_id, True
+
+        def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]):
+            def _update_partition_map(node: Node, id: int):
+                # Iterate through all the users of this node and update the partition map to indicate
+                # that there is a path from the partition id of this node to the target partition id.
+                for user_node in node.users:
+                    target_id = assignment.get(user_node)
+                    if target_id is not None:
+                        partition_map[id].add(target_id)
+                        partition_map[id].update(partition_map[target_id])
+
+            if node in assignment:
+                partitions_by_id[assignment[node]].remove_node(node)
+
+            if id is None:
+                assignment.pop(node)
+            elif id not in partitions_by_id:
+                assignment[node] = id
+                assert node_order is not None
+                partitions_by_id[id] = Partition(
+                    id=id, nodes=[node], node_orders=[node_order]
+                )
+                partition_users[id] = set(node.users)
+                _update_partition_map(node, id)
+            else:
+                assignment[node] = id
+                partitions_by_id[id].add_node(node, node_order)
+
+        logger.debug("Proposing partitions...")
+
+        for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)):
+            # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
+            merge_candidates: dict[int, None] = {}
+
+            # Note a limited horizontal fusion is enabled:
+            #   when `node` is not supported, the code below attempts to fuse consumer of `node`.
+            #
+            # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
+            # the fusion by adding an `else` block here to skip horizontal fusion.
+            if self._is_node_supported(node) and node not in assignment:
+                partition_id = next(new_partition_id)
+                nodes_order[node] = partition_id
+                partitions_order[partition_id] = partition_id
+                merge_single_node(node, node_order, partition_id)
+                merge_candidates[partition_id] = None
+
+            # merge all possible partitions
+            for partition_id, _ in sorted(
+                partitions_order.items(), key=operator.itemgetter(1)
+            ):
+                merge_candidates[partition_id] = None
+
+            merge_candidates_list = list(merge_candidates.keys())
+            if len(merge_candidates_list) > 1:
+                self_id = merge_candidates_list[0]
+                for other_id in merge_candidates_list[1:]:
+                    # note: merge partitions if it doesn't create cyclic dependency
+                    # in the graph, otherwise, this is a no-op
+                    self_id, _ = maybe_merge_partition(self_id, other_id)
+
+        # sort partition nodes based on descending node order
+        for partition in partitions_by_id.values():
+            partition.nodes = dict(
+                sorted(
+                    partition.nodes.items(), key=operator.itemgetter(1), reverse=True
+                )
+            )
+
+        # post processing to re-assign "getitem" nodes into upstream partition
+        logger.debug("Reassigning getitem nodes to its producer node's partition...")
+        nodes_reassignment: dict[Node, int] = {}
+        for node in self.graph_module.graph.nodes:
+            is_tuple_output = True
+            for user in node.users:
+                if (
+                    user.op != "call_function"
+                    or _get_qualified_name(user.target) != "_operator.getitem"
+                ):  # type: ignore[arg-type]
+                    is_tuple_output = False
+                    break
+
+            # node has tuple outputs, re-assign all following getitem node into node's partition
+            if is_tuple_output:
+                id = assignment.get(node)  # type: ignore[arg-type]
+                for user in node.users:
+                    if assignment.get(user) != id:  # type: ignore[arg-type]
+                        nodes_reassignment[user] = id  # type: ignore[assignment]
+        for node, id in nodes_reassignment.items():
+            merge_single_node(node, None, id)
+
+        # filter out single node partitions
+        if not self.allows_single_node_partition:
+            logger.debug("Filtering out single node partitions...")
+            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
+            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
+            partitions_to_remove: list[int] = []
+            for id, partition in partitions_by_id.items():
+                compute_node_count = 0
+                for node in partition.nodes:
+                    if node.op == "call_function":
+                        assert callable(node.target)
+                        if _get_qualified_name(node.target) not in non_compute_ops:
+                            compute_node_count += 1
+                        if (
+                            _get_qualified_name(node.target)
+                            in self.allowed_single_node_partition_ops
+                        ):
+                            compute_node_count += 1
+                if compute_node_count <= 1:
+                    partitions_to_remove.append(id)
+            for id in partitions_to_remove:
+                del partitions_by_id[id]
+
+        logger.debug("Partitions proposed:")
+        for id, partition in partitions_by_id.items():
+            logger.debug(
+                "partition #%s: %s", id, [node.name for node in partition.nodes]
+            )
+
+        return [
+            partition for partition in partitions_by_id.values() if partition.size() > 0
+        ]
+
+    def fuse_partitions(
+        self, partitions: list[Partition], prefix: str = "fused_"
+    ) -> GraphModule:
+        logger.debug("Fusing partitions...")
+        # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
+        return fuse_by_partitions(
+            self.graph_module,
+            [partition.nodes for partition in partitions],
+            prefix=prefix,
+        )
+
+    # remove non-compute-ops that sits at the boundary of a partition.
+    def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
+        non_compute_ops = set(self.non_compute_ops)
+
+        def is_non_compute_node(node: Node):
+            return (
+                node.op == "call_function"
+                and _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]
+            )
+
+        # cache transparent nodes
+        transparent_input_nodes: dict[Node, bool] = {}
+        transparent_output_nodes: dict[Node, bool] = {}
+
+        def is_transparent_input_node(
+            node: Node, partition: set[Node], removed_nodes: set[Node]
+        ):
+            if (
+                node.op == "placeholder"
+                or (node not in partition)
+                or (node in removed_nodes)
+            ):
+                return True
+            if node in transparent_input_nodes:
+                return transparent_input_nodes[node]
+            if is_non_compute_node(node):
+                for input_n in node.all_input_nodes:
+                    if not is_transparent_input_node(input_n, partition, removed_nodes):
+                        transparent_input_nodes[node] = False
+                        return False
+                transparent_input_nodes[node] = True
+                return True
+            transparent_input_nodes[node] = False
+            return False
+
+        def is_transparent_output_node(
+            node: Node, partition: set[Node], removed_nodes: set[Node]
+        ):
+            if (
+                node.op == "placeholder"
+                or (node not in partition)
+                or (node in removed_nodes)
+            ):
+                return True
+            if node in transparent_output_nodes:
+                return transparent_output_nodes[node]
+            if is_non_compute_node(node):
+                for output_n in node.users:
+                    if not is_transparent_output_node(
+                        output_n, partition, removed_nodes
+                    ):
+                        transparent_output_nodes[node] = False
+                        return False
+                transparent_output_nodes[node] = True
+                return True
+            transparent_output_nodes[node] = False
+            return False
+
+        for partition in partitions:
+            # Note it's ok to use `set` here, since we are only query if a node
+            # has been removed. We are NEVER going to iterate on nodes inside
+            # the set.
+            remove_node: set[Node] = set()
+            for node in partition.nodes:
+                if is_non_compute_node(node) and (
+                    is_transparent_input_node(node, set(partition.nodes), remove_node)
+                    or is_transparent_output_node(
+                        node, set(partition.nodes), remove_node
+                    )
+                ):
+                    remove_node.add(node)
+
+            if len(remove_node) != 0:
+                for node in remove_node:
+                    partition.nodes.pop(node, None)
+
+    def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
+        partitions = self.propose_partitions()
+        fused_gm = self.fuse_partitions(partitions, prefix=prefix)
+        return fused_gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..32c641031b31f2c49ca76daac6751b356e740213
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py
@@ -0,0 +1,79 @@
+# mypy: allow-untyped-defs
+import abc
+from collections import namedtuple
+from typing import Optional
+
+from torch.fx._compatibility import compatibility
+from torch.fx.graph_module import GraphModule
+
+
+__all__ = ["PassResult", "PassBase"]
+
+
+@compatibility(is_backward_compatible=False)
+# pyrefly: ignore [invalid-inheritance]
+class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
+    """
+    Result of a pass:
+        graph_module: The modified graph module
+        modified: A flag for if the pass has modified the graph module
+    """
+
+    __slots__ = ()
+
+    def __new__(cls, graph_module, modified):
+        return super().__new__(cls, graph_module, modified)
+
+
+@compatibility(is_backward_compatible=False)
+class PassBase(abc.ABC):
+    """
+    Base interface for implementing passes.
+
+    It is required to implement the `call` function so that we can directly
+    pass instances of the Pass directly to the PassManager and call them as a
+    function.
+
+    We can directly pass an instance of a class implementing this interface into
+    the PassManager's `passes` attribute.
+    """
+
+    def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
+        """
+        Runs the precondition check, the pass itself, and the postcondition check.
+        """
+
+        self.requires(graph_module)
+        res = self.call(graph_module)
+        self.ensures(graph_module)
+        return res
+
+    @abc.abstractmethod
+    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
+        """
+        The pass that is run through the given graph module. To implement a
+        pass, it is required to implement this function.
+
+        Args:
+            graph_module: The graph module we will run a pass on
+        """
+
+    def requires(self, graph_module: GraphModule) -> None:  # noqa: B027
+        """
+        This function will be called before the pass is run and will check that
+        the given graph module contains the preconditions needed to run the
+        pass. It is not required to implement this function.
+
+        Args:
+            graph_module: The graph module we will run checks on
+        """
+
+    def ensures(self, graph_module: GraphModule) -> None:  # noqa: B027
+        """
+        This function will be called after the pass is run and will check that
+        the given graph module contains the postconditions needed to run the
+        pass. It is not required to implement this function.
+
+        Args:
+            graph_module: The graph module we will run checks on
+        """
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..87fb6e70037f9a00c46143f87efc5a832a7db3ae
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py
@@ -0,0 +1,310 @@
+# mypy: allow-untyped-defs
+import inspect
+import logging
+from collections.abc import Callable
+from functools import wraps
+from queue import Queue
+
+import torch.nn as nn
+from torch.fx._compatibility import compatibility
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes.infra.pass_base import PassResult
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"]
+
+
+@compatibility(is_backward_compatible=False)
+def pass_result_wrapper(fn: Callable) -> Callable:
+    """
+    Wrapper for passes which currently do not return a PassResult.
+    This wrapper makes them return a PassResult containing the modified object
+    and True for the "modified" flag.
+
+    Args:
+        fn (Callable[Module, Any])
+
+    Returns:
+        wrapped_fn (Callable[Module, PassResult])
+    """
+    if fn is None:
+        # pyrefly: ignore [bad-return]
+        return None
+
+    @wraps(fn)
+    def wrapped_fn(gm):
+        res = fn(gm)
+        if res is None:
+            return PassResult(gm, True)
+        if isinstance(res, PassResult):
+            return res
+        elif isinstance(res, nn.Module):
+            return PassResult(res, True)
+
+    if not inspect.isfunction(fn):
+        wrapped_fn.__name__ = type(fn).__name__
+
+    return wrapped_fn
+
+
+def _validate_pass_schedule_constraint(
+    constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
+) -> None:
+    for i, a in enumerate(passes):
+        for j, b in enumerate(passes[i + 1 :]):
+            if constraint(a, b):
+                continue
+            raise RuntimeError(
+                f"pass schedule constraint violated. Expected {a} before {b}"
+                f" but found {a} at index {i} and {b} at index{j} in pass"
+                f" list."
+            )
+
+
+def _topological_sort_passes(
+    passes: list[Callable], constraints: list[Callable]
+) -> list[Callable]:
+    """
+    Args
+        passes: Passes that we are ordering
+        constraints: Constraints applied on these passes
+
+    Returns
+        A sorted list of callables and a boolean of if a circular dependency
+        existed
+    """
+    if len(constraints) == 0:
+        return passes
+
+    # Construct a graph mapping nodes to a list of their users
+    graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
+    indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
+    candidates: Queue = Queue()
+    for a in passes:
+        for b in passes:
+            if a == b:
+                continue
+
+            for constraint in constraints:
+                if not constraint(a, b):
+                    graph[b].append(a)
+                    indegree_map[a] += 1
+
+        if indegree_map[a] == 0:
+            candidates.put(a)
+
+    visited: dict[Callable, bool] = dict.fromkeys(passes, False)
+    sorted_passes: list[Callable] = []
+
+    while not candidates.empty():
+        p = candidates.get()
+        sorted_passes.append(p)
+        visited[p] = True
+
+        for n in graph[p]:
+            if not visited[n]:
+                indegree_map[n] -= 1
+                if indegree_map[n] == 0:
+                    candidates.put(n)
+
+    # Check if there are unvisited nodes (aka cycles in the graph)
+    cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
+    if len(cycle_passes) != 0:
+        error = (
+            f"Circular dependency detected within the following passes: {cycle_passes}"
+        )
+        raise RuntimeError(error)
+
+    return sorted_passes
+
+
+@compatibility(is_backward_compatible=False)
+def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
+    """
+    Defines a partial order ('depends on' function) where `this` must occur
+    before `that`.
+
+    For example, the following pass list and constraint list would be invalid.
+    ```
+    passes = [pass_b, pass_a]
+
+    constraints = [this_before_that_pass_constraint(pass_a, pass_b)]
+    ```
+
+    Args:
+        this (Callable): pass which should occur first
+        that (Callable): pass which should occur later
+
+    Returns:
+        depends_on (Callable[[Object, Object], bool]
+    """
+
+    def depends_on(a: Callable, b: Callable):
+        return a != that or b != this
+
+    return depends_on
+
+
+@compatibility(is_backward_compatible=False)
+class PassManager:
+    """
+    Construct a PassManager.
+
+    Collects passes and constraints. This defines the pass schedule, manages
+    pass constraints and pass execution.
+
+    Args:
+        passes (Optional[List[Callable]]): List of passes. A pass is a
+            callable which modifies an object and returns a PassResult
+        constraint (Optional[List[Callable]]): List of constraints. A
+            constraint is a callable which takes two passes (A, B) and returns
+            True if A depends on B and False otherwise. See implementation of
+            `this_before_that_pass_constraint` for example.
+        steps (int): Max number of times we run the passes (default = 1).
+        run_checks_after_each_pass (bool): Whether to run checks and linting
+            after each pass
+        suppress_check_failures (bool): Whether to raise errors when running
+            checks
+    """
+
+    passes: list[Callable[[nn.Module], PassResult]]
+    constraints: list[Callable[[Callable, Callable], bool]]
+    _validated: bool = False
+    steps: int = 1
+
+    def __init__(
+        self,
+        passes=None,
+        constraints=None,
+        steps=None,
+        run_checks_after_each_pass: bool = False,
+        suppress_check_failures: bool = False,
+    ):
+        self.passes = passes or []
+        self.constraints = constraints or []
+        if steps:
+            self.steps = steps
+
+        self.run_checks_after_each_pass = run_checks_after_each_pass
+        self.suppress_check_failures = suppress_check_failures
+
+    def add_pass(self, _pass: Callable):
+        """
+        Adds a pass into the current list of passes.
+        """
+        self.passes.append(_pass)
+        self._validated = False
+
+    def add_constraint(self, constraint: Callable):
+        """
+        Adds a constraint into the current list of constraints.
+        """
+        self.constraints.append(constraint)
+        self._validated = False
+
+    def validate_constraints(self):
+        """
+        Validates that current pass schedule defined by `self.passes` is valid
+        according to all constraints in `self.constraints`
+        """
+        if self._validated:
+            return
+        for constraint in self.constraints:
+            _validate_pass_schedule_constraint(constraint, self.passes)
+        self._validated = True
+
+    def solve_constraints(self):
+        """
+        Finds a valid traversal order based on the given constraints and orders
+        the passes based on this order.
+
+        If a circular dependency exists between the constraints and steps = 1,
+        then we will raise an error because if steps != 1 this means that we
+        will re-run the passes, allowing for circular dependencies.
+        """
+        self.passes = _topological_sort_passes(self.passes, self.constraints)
+        self._validated = True
+
+    def add_checks(self, check: Callable) -> None:
+        """
+        Adds a function which takes runs various checks on a given graph module.
+        This function is run before and after each pass if the
+        `run_checks_after_each_pass` flag is enabled.
+        """
+        sig = inspect.signature(check)
+
+        if len(list(sig.parameters.values())) != 1:
+            raise TypeError(
+                "PassManager check function should only take in one variable, a module"
+            )
+
+        setattr(self, "check", check)  # noqa: B010
+
+    def check(self, module: nn.Module) -> None:
+        pass
+
+    def __call__(self, module: nn.Module) -> PassResult:
+        """
+        Runs a list of passes in the order based on `self.passes` on the given
+        graph module. Each time a pass is run, checks and linting will be run on
+        the graph module if `run_checks_after_each_pass` is set.
+
+        If the module is a graph module, we will run the list of passes until
+        the graph stops changing, or until `steps` number of times.
+        """
+        # Order the passes based on the constraints
+        if not self._validated:
+            self.solve_constraints()
+
+        # Check graph invariants
+        self.check(module)
+
+        # Run the set of passes `steps` number of times or until the graph stops
+        # changing
+        overall_modified = False
+        for _ in range(self.steps):
+            modified = False
+
+            # Run the set of passes on the graph module
+            for i, fn in enumerate(self.passes):
+                fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
+                logger.debug("Running pass '%s'", fn_name)
+
+                try:
+                    res = fn(module)
+
+                    if not isinstance(res, PassResult) and not hasattr(
+                        res, "graph_module"
+                    ):
+                        raise TypeError(
+                            f"The result of the pass {fn_name} should be type PassResult."
+                            + "Please wrap it with pass_result_wrapper()"
+                        )
+                    module = res.graph_module
+                    modified = modified or res.modified
+
+                    if isinstance(module, GraphModule):
+                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
+                        module.recompile()
+
+                    # Check graph invariants
+                    if self.run_checks_after_each_pass:
+                        self.check(module)
+
+                except Exception as e:
+                    prev_pass_names = [
+                        p.__name__ if inspect.isfunction(p) else type(p).__name__
+                        for p in self.passes[:i]
+                    ]
+                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
+                    raise Exception(msg) from e  # noqa: TRY002
+
+            # If the graph no longer changes, then we can stop running these passes
+            overall_modified = overall_modified or modified
+            if not modified:
+                break
+
+        return PassResult(module, overall_modified)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/net_min_base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/net_min_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e98bad06e5a55eedbc5177f7053d000c25f23e8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/net_min_base.py
@@ -0,0 +1,983 @@
+# mypy: allow-untyped-defs
+import logging
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, cast, Optional
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.node import map_arg
+
+from .shape_prop import ShapeProp
+from .split_utils import split_by_tags
+from .tools_common import (
+    CALLABLE_NODE_OPS,
+    FxNetAccFusionsFinder,
+    Names,
+    NodeList,
+    NodeSet,
+    TensorOrTensors,
+    Tensors,
+)
+
+
+__all__ = [
+    "FxNetMinimizerBadModuleError",
+    "FxNetMinimizerRunFuncError",
+    "FxNetMinimizerResultMismatchError",
+]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetMinimizerBadModuleError(Exception):
+    """
+    Raised if failed to split out a minimize module
+    """
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetMinimizerRunFuncError(Exception):
+    """
+    Raised if error occurs during run_a or run_b functions
+    """
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetMinimizerResultMismatchError(Exception):
+    """
+    Raised if comparing function thinks the results are mismatching.
+    """
+
+
+@dataclass
+class _MinimizerSettingBase:
+    """
+    Args:
+    `accumulate_error`: Instead of using a's input for both converted module to verify
+    , use the previous outputs of each converted module as input to accumulate the
+    errors.
+
+    `traverse_method`: "sequential" or "binary" or "accumulate"
+    Determine the way of traverse the nodes in FX module.
+
+    `find_all`: Minimizer will go through the entire model and return all problematic nodes.
+
+    `return_intermediate`: If true, when using `run_nodes()` function to run the
+    model, intermediate results of all the ops will be returned as output.
+
+    `all_outputs`: If true, when using `_run_and_compare()` function,
+    all the output nodes in the subgraph will be used for comparison.
+    """
+
+    accumulate_error: bool = False
+    traverse_method: str = "sequential"
+    find_all: bool = False
+    return_intermediate: bool = False
+    all_outputs: bool = False
+
+    def __str__(self):
+        settings_str = "FX Minimizer Settings:\n"
+
+        for k, v in vars(self).items():
+            settings_str += f"\t{k}: {v}\n"
+
+        return settings_str
+
+
+class _MinimizerBase:
+    """
+    This class is used to automatically find problematic nodes in a model. It takes a FX
+    graphmodule and generate some submodules while traverse the graph. Then two functions
+    `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn`
+    will be used to compare the results.
+
+    Currently we provides two ways to traverse the graph and generate submodules.
+        1. Sequential traversal: this will traverse the graph node by node and generate
+           one submodule with one single node.
+        2. Binary searching: this will do a binary search style traversal on the graph.
+
+    For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tensors,
+        compare_fn: Callable[
+            [TensorOrTensors, TensorOrTensors, Names], tuple[float, bool]
+        ],
+        settings: _MinimizerSettingBase,
+        module_exporter: Optional[
+            Callable[[Tensors, torch.fx.GraphModule, str], None]
+        ] = None,
+        exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None,
+    ):
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+        self.sample_input = sample_input
+        self.compare_fn = compare_fn
+        self.module_exporter = module_exporter
+        self.settings = settings
+        self.exclusion_fn = exclusion_fn
+
+        # Stores outputs of run_a function
+        self.a_outputs: dict[str, Any] = {}
+
+        # Stores outputs of run_b function
+        self.b_outputs: dict[str, Any] = {}
+
+        # Stores the results of compare_fn
+        self.results: dict[Any, Any] = {}
+
+        # Stores the report for the runs
+        self.reports: list[list[str]] = []
+
+        # Current iteration
+        self.iteration: int = 0
+
+        callable_nodes = {
+            node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
+        }
+        self.run_shape_prop()
+        self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
+
+        # Check if number of input in sample_input matches the number of placeholders
+        placeholders = [
+            node.name for node in self.module.graph.nodes if node.op == "placeholder"
+        ]
+        assert len(placeholders) == len(self.sample_input)
+
+        # Store sample_input
+        for i, name in enumerate(placeholders):
+            self.a_outputs[name] = sample_input[i]
+            self.b_outputs[name] = sample_input[i]
+
+    def run_shape_prop(self) -> None:
+        """
+        Helper function to run shape propagation on module. Can be overridden by
+        subclasses for custom shape propagation logic.
+        """
+        ShapeProp(self.module).propagate(*self.sample_input)
+
+    def run_a(
+        self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
+    ) -> TensorOrTensors:
+        """
+        Run `mod` with `inputs` and generate output. The output will be compared with
+        output of run_b().
+        """
+        raise RuntimeError("run_a() is not implemented.")
+
+    def run_b(
+        self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
+    ) -> TensorOrTensors:
+        """
+        Run `mod` with `inputs` and generate output. The output will be compared with
+        output of run_a().
+        """
+        raise RuntimeError("run_b() is not implemented.")
+
+    def _store_outputs(
+        self,
+        a_result: TensorOrTensors,
+        b_result: TensorOrTensors,
+        submodule: torch.fx.GraphModule,
+    ):
+        """
+        Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
+        self.b_outputs, so that we can use them when execute preceding nodes that
+        use those outputs as inputs.
+
+        Args:
+            a_result: Output of self.run_a(). Could be a tensor or tensors.
+            b_result: Output of self.run_b(). Could be a tensor or tensors.
+            submodule: The module that generates a_result and b_result.
+        """
+        output_node = next(
+            node for node in submodule.graph.nodes if node.op == "output"
+        )
+
+        # Only one output
+        if isinstance(output_node.args[0], torch.fx.Node):
+            self.a_outputs[output_node.args[0].name] = a_result
+            self.b_outputs[output_node.args[0].name] = b_result
+        # Multiple outputs
+        else:
+            for i, arg in enumerate(output_node.args[0]):
+                self.a_outputs[arg.name] = a_result[i]
+                self.b_outputs[arg.name] = b_result[i]
+
+    def _get_submod_inputs(
+        self, main_module: torch.fx.GraphModule, submod_path: str
+    ) -> tuple[Tensors, Tensors]:
+        """
+        Try get submodule inputs from stored outputs. If not found then use
+        torch_glow.get_submod_inputs to get the inputs.
+
+        If accumulate_error is False, use a_input for run_a() and run_b()
+        otherwise use a_input for run_a and b_input for run_b.
+
+        Args:
+            main_module: Top-levlel fx module.
+            submod_path: Path to the submodule we want to run and compare results.
+
+        Returns:
+            a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
+            b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
+        """
+        a_input = []
+        b_input = []
+        submodule = getattr(main_module, submod_path)
+        placeholders = [
+            node.name for node in submodule.graph.nodes if node.op == "placeholder"
+        ]
+
+        # If all placeholder can be found in stored outputs, use stored
+        # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs`
+        # to get the inputs.
+        if set(placeholders) <= self.a_outputs.keys():
+            for name in placeholders:
+                a_input.append(self.a_outputs[name])
+                b_input.append(self.b_outputs[name])
+        else:
+            if self.settings.accumulate_error:
+                print(f"Can't find previous stored outputs named {placeholders}!")
+
+            def get_inputs(self: torch.nn.Module, inputs: Any):
+                nonlocal a_input
+                a_input = inputs
+
+            # Use forward hook to get the inputs to the submodule
+            handle = submodule.register_forward_pre_hook(get_inputs)
+            main_module(*self.sample_input)
+            handle.remove()
+
+            b_input = a_input
+
+        if not self.settings.accumulate_error:
+            return a_input, a_input
+
+        return a_input, b_input
+
+    def _tag_nodes(self, selected_nodes: NodeSet):
+        """
+        Tag selected nodes with tag "minimize". Nodes with the same tags will
+        be split to the same submodule afterwards.
+
+        Args:
+            selected_nodes: Nodes that we want to minimize. We will tag those nodes
+                with "minimize", all preceding nodes with "main_0" and all following
+                nodes with "main_1".
+        """
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            if node in selected_nodes:
+                node.tag = "minimize"
+            elif any(
+                n.tag in {"minimize", "main_1"}
+                for n in node.all_input_nodes
+                if n.op in CALLABLE_NODE_OPS
+            ):
+                node.tag = "main_1"
+            else:
+                node.tag = "main_0"
+
+    def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]:
+        """
+        Split self.module so that one submodule consists of `nodes` and only `nodes`.
+
+        Args:
+            nodes: Nodes that we want to include in the minimize submodule.
+
+        Returns:
+            split_module (torch.fx.GraphModule): the module after split.
+            submodule_name (str): the name of the submodule that consists of `nodes`.
+        """
+        # Color provided nodes
+        self._tag_nodes(nodes)
+
+        # Split module based on coloring
+        split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
+
+        # Find submodule containing colored nodes
+        submodule_name: str = ""
+        for child_name, _ in split_module.named_children():  # type: ignore[union-attr]
+            # Skip submodules we're not interested in at the moment
+            if "minimize" not in child_name:
+                continue
+
+            if submodule_name == "":
+                submodule_name = child_name
+            else:
+                raise FxNetMinimizerBadModuleError(
+                    f"Expected only one minimize submodule with nodes {nodes}"
+                )
+
+        if submodule_name == "":
+            raise FxNetMinimizerBadModuleError(
+                f"Minimize submodule was not found with nodes {nodes}"
+            )
+
+        return split_module, submodule_name  # type: ignore[return-value]
+
+    def _run_and_compare(
+        self,
+        split_module: torch.fx.GraphModule,
+        submod_name: str,
+        output_names: Names,
+        report_idx: int = -1,
+    ):
+        """
+        Run the submodule in `split_module` that has name `submod_name`
+        using `self.run_a` and `self.run_b` and compare their results.
+
+        Args:
+            split_module: Main module that contains the minimize submodule.
+            submod_name: Name of the minimize submodule.
+            output_names: Names of the node we want to output. If None, we
+                will use the original output.
+        """
+        submodule = getattr(split_module, submod_name)
+        a_input, b_input = self._get_submod_inputs(split_module, submod_name)
+
+        if len(self.reports) == 0:
+            self.reports.append([])
+            self.iteration = 1
+
+        report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1]
+        report.append("Run and compare ...")
+
+        if output_names and not self.settings.all_outputs:
+            output_nodes: NodeList = []
+            for node in submodule.graph.nodes:
+                if node.op == "output":
+                    submodule.graph.erase_node(node)
+
+                if node.name in output_names:
+                    output_nodes.append(node)
+
+            submodule.graph.output(
+                output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
+            )
+            submodule.graph.lint()
+            submodule.recompile()
+
+        # Use name of args in output node as key to store comparison result
+        for node in submodule.graph.nodes:
+            if node.op == "output":
+                result_key = map_arg(node.args, lambda x: x.name)
+
+        try:
+            a_result = self.run_a(submodule, a_input, report_idx)
+            b_result = self.run_b(submodule, b_input, report_idx)
+            self._store_outputs(a_result, b_result, submodule)
+        except Exception as e:
+            report.append(f"Exception raised when running {submod_name}: {e}")
+            raise FxNetMinimizerRunFuncError(  # noqa: B904
+                f"Exception raised when running {submod_name}: {e}"
+            )
+
+        # Compare results
+        names: Names = output_names
+        if output_names is None:
+            names = [str(v) for v in result_key]  # type: ignore[possibly-undefined]
+
+        numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
+
+        self.results[result_key] = numeric_result  # type: ignore[possibly-undefined]
+        report.append(f"Numerical accuracy = {numeric_result}")
+        if not bool_result:
+            report.append(f"Result mismatch for {result_key}")  # type: ignore[possibly-undefined]
+            if self.module_exporter:
+                if isinstance(result_key, tuple):  # type: ignore[possibly-undefined]
+                    # pyrefly: ignore [unbound-name]
+                    result_key = result_key[-1]
+                # If the result is still a tuple (happens in non-sequential mode),
+                # we only use the first element as name.
+                if isinstance(result_key, tuple):  # type: ignore[possibly-undefined]
+                    # pyrefly: ignore [unbound-name]
+                    result_key = str(result_key[0])
+                # pyre-ignore[29]: not a function
+                self.module_exporter(
+                    a_input,
+                    submodule,
+                    # pyrefly: ignore [unbound-name]
+                    result_key + "_cpu",
+                )
+                # pyre-ignore[29]: not a function
+                self.module_exporter(
+                    b_input,
+                    submodule,
+                    # pyrefly: ignore [unbound-name]
+                    result_key + "_acc",
+                )
+            raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")  # type: ignore[possibly-undefined]
+
+    def _binary_search_impl(
+        self, all_nodes: NodeList, start_idx: int, end_idx: int
+    ) -> NodeSet:
+        """
+        Recursive binary search implementation.
+        """
+        culprits: NodeSet = set()
+        nodes: NodeList = all_nodes[start_idx:end_idx]
+
+        report: list[str] = []
+        if self.exclusion_fn is not None:
+            self.exclusion_fn(nodes, start_idx, end_idx)
+            if len(nodes) == 0:
+                report = ["All nodes are excluded by user"]
+                self.reports.append(report)
+                return culprits
+
+        first_node_name = nodes[0].name
+        output_node_name = nodes[-1].name
+        self.iteration += 1
+        self.reports.append(report)
+        report.append(f"Binary search iteration {self.iteration}")
+        report.append(
+            f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. "
+            f"Size of the interested node list is {len(nodes)}"
+        )
+        cur_nodes: NodeSet = set(nodes)
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, [output_node_name])
+
+        except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
+            if len(nodes) == 1:
+                report.append(
+                    f"This is the last node in the sub-module. "
+                    f"Search in the current branch is successful with culprit = {cur_nodes}."
+                )
+                self.print_report(report)
+                return cur_nodes
+
+            report.append(
+                "Proceed to split and lower the halves of the current "
+                "sub-module individually."
+            )
+            self.print_report(report)
+
+            mid = len(nodes) // 2
+            culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
+
+            if len(culprits) != 0 and not self.settings.find_all:
+                return culprits
+
+            culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
+
+            if len(culprits) == 0:
+                report.append(
+                    f"Further split and lowering found no errors. "
+                    f"Unable to minimize the submodule with list of nodes: {nodes}"
+                )
+                self.print_report(report)
+
+            return culprits
+        else:
+            report.append("No discrepancy found.")
+            self.print_report(report)
+            return set()
+
+    def _binary_traverse(self, nodes: NodeList) -> NodeSet:
+        """
+        Binary search on `nodes` for culprit.
+        """
+        return self._binary_search_impl(nodes, 0, len(nodes))
+
+    def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
+        """
+        Traverse `nodes` one by one and determine if any of them is a culprit.
+        """
+        culprits: NodeSet = set()
+
+        for node in nodes:
+            report: list[str] = []
+            self.reports.append(report)
+            self.iteration += 1
+            report.append(f"Sequential traverse iteration {self.iteration}.")
+            report.append(f"Visit node: {node.name}")
+
+            _LOGGER.info("Visit node: %s", node.name)
+            node_list: NodeList = [node]
+            if self.exclusion_fn is not None:
+                self.exclusion_fn(node_list, -1, -1)
+                if len(node_list) == 0:
+                    report.append(f"User exclusion : {node.name}")
+                    self.print_report(report)
+                    if not self.settings.find_all:
+                        return culprits
+                    else:
+                        continue
+
+            cur_nodes: NodeSet = {node}
+
+            if node in self.fusions:
+                cur_nodes = self.fusions[node]
+
+            try:
+                split_module, submod_name = self._build_submodule(cur_nodes)
+                self._run_and_compare(split_module, submod_name, [node.name])
+                self.print_report(report)
+            except FxNetMinimizerResultMismatchError:
+                culprits.add(node)
+                report.append(f"Found culprit from numeric error: {node}")
+                self.print_report(report)
+                if not self.settings.find_all:
+                    return culprits
+            except FxNetMinimizerRunFuncError:
+                culprits.update(cur_nodes)
+                report.append(f"Found culprit from run error: {node}")
+                self.print_report(report)
+                if not self.settings.find_all:
+                    return culprits
+
+        return culprits
+
+    def _block_traverse_impl(
+        self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool
+    ) -> Optional[int]:
+        """
+        Recursive block search implementation.
+        find_last_node: If True, search for the last node which result in numerics difference
+        if False: find first node in sorted node list
+        """
+        report: list[str] = []
+
+        mid = (start_idx + end_idx) // 2
+        cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]
+
+        if self.exclusion_fn:
+            self.exclusion_fn(cur_nodes_list, -1, -1)
+
+        cur_nodes = set(cur_nodes_list)
+
+        first_node_name = cur_nodes_list[0].name
+        last_node_name = cur_nodes_list[-1].name
+        target_node_name = last_node_name if find_last_node else first_node_name
+
+        self.iteration += 1
+        self.reports.append(report)
+        report.extend(
+            [
+                "=" * 30,
+                f"Block search iteration {self.iteration}",
+            ]
+        )
+        report.extend(
+            [
+                f"Search for {'last' if find_last_node else 'first'} node in culprits",
+                f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ",
+                f"Subgraph constructed by {first_node_name} to {last_node_name}",
+                f"Targeting node: {target_node_name}",
+                f"Size of the interested node list is {end_idx - start_idx + 1}",
+            ]
+        )
+        report_idx = len(self.reports) - 1
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(
+                split_module, submod_name, [last_node_name], report_idx
+            )
+        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
+            report.append(
+                f"Culprits found from node {first_node_name} to {last_node_name}."
+            )
+
+            if start_idx == mid == end_idx:
+                report.extend(
+                    [
+                        "This is the last node in the sub-module. ",
+                        "Search in the current branch is successful with node :",
+                        f"{start_idx}, node name: {nodes[start_idx].name}.",
+                    ]
+                )
+                self.print_report(report)
+                return start_idx
+
+            report.append(
+                "Proceed to split and lower the halves of the current "
+                "sub-module individually."
+            )
+            self.print_report(report)
+
+            if find_last_node:
+                return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
+            else:
+                return self._block_traverse_impl(
+                    nodes, mid + 1, end_idx, find_last_node
+                )
+        else:
+            report.append(
+                f"Culprits not found from node start to {mid}:{nodes[mid].name}."
+            )
+
+            if start_idx == mid == end_idx:
+                # We did not find anything if the pointers have not moved
+                if (start_idx == 0 and not find_last_node) or (
+                    start_idx == len(nodes) - 1 and find_last_node
+                ):
+                    report.append(
+                        f"At {'last' if find_last_node else 'first'} node, no culprits found."
+                    )
+                    self.print_report(report)
+                    return None
+
+                # Otherwise, we have converged on the border between discrepancy and valid
+                return start_idx + (1 if find_last_node else -1)
+
+            report.append(
+                "Proceed to split and lower the halves of the current "
+                "sub-module individually."
+            )
+            self.print_report(report)
+
+            if find_last_node:
+                return self._block_traverse_impl(
+                    nodes, mid + 1, end_idx, find_last_node
+                )
+            else:
+                return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
+
+    def _block_traverse(
+        self, nodes: NodeList, find_last_node: Optional[bool]
+    ) -> NodeSet:
+        """
+        Traverse topologically sorted node list
+        Find minimum block (start_idx, end_idx) which contains the culprit
+        1st pass: search for end_idx by finding the last node in culprit block
+        where Numerical accuracy (0, end_idx) > threshold
+        2nd pass: search for start_idx by finding the first node in culprit block
+        where Numerical accuracy (start_idx, end_idx) < threshold
+        Form minimum block by (start_idx - 1, end_idx)
+        """
+        culprits: NodeSet = set()
+        first_node_name = nodes[0].name
+        last_node_name = nodes[-1].name
+        last_node_report = [f"Block search from {first_node_name} to {last_node_name}"]
+        last_node_report.append("*" * 50)
+        self.reports.append(last_node_report)
+
+        start_idx = 0
+        end_idx = len(nodes) - 1
+
+        final_start_idx: Optional[int] = start_idx
+        final_end_idx: Optional[int] = end_idx
+
+        run_both = find_last_node is None
+
+        # step 1: find (0, end_idx) of culprit block
+        if run_both or find_last_node:
+            last_node_report.append("Start searching for last node in culprit")
+            self.print_report(last_node_report)
+            final_end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
+
+            if final_end_idx is None:
+                last_node_report.append("No culprits found")
+                self.print_report(last_node_report)
+                return culprits
+
+            last_node_report.extend(
+                [
+                    "Finish Pass 1",
+                    f"Find end_idx = {final_end_idx}:{nodes[final_end_idx].name}",
+                ]
+            )
+            self.print_report(last_node_report)
+
+        # step 2: reduce culprit block to (start_idx, end_idx)
+        if run_both or not find_last_node:
+            first_node_report = ["Start searching for first node in culprit"]
+            self.print_report(first_node_report)
+            final_start_idx = self._block_traverse_impl(
+                nodes[0 : end_idx + 1], start_idx, final_end_idx or end_idx, False
+            )
+
+            if final_start_idx is None:
+                last_node_report.append("No culprits found")
+                self.print_report(last_node_report)
+                return culprits
+
+            first_node_report.append("*" * 50)
+            self.reports.append(first_node_report)
+            first_node_report.extend(
+                [
+                    "Finish Pass 2",
+                    f"Find start_idx = {final_start_idx}:{nodes[final_start_idx].name}",
+                ]
+            )
+            self.print_report(first_node_report)
+
+        # step 3: form module with minimum culprits. These indexes are guaranteed to exist
+        range_start, range_end = cast(int, final_start_idx), cast(int, final_end_idx)
+        culprits.update(nodes[range_start : range_end + 1])
+        result_report = [
+            f"Finish searching, found minimum block ({nodes[range_start]},{nodes[range_end]})"
+        ]
+        self.reports.append(result_report)
+        self.print_report(result_report)
+        return culprits
+
+    def _defined_traverse(self, nodes: NodeList) -> NodeSet:
+        """
+        run user defined `nodes` and determine if it is a culprit.
+        """
+        culprits: NodeSet = set()
+        if self.exclusion_fn is not None:
+            self.exclusion_fn(nodes, -1, -1)
+        if len(nodes) == 0:
+            report = ["All nodes are excluded by user"]
+            self.reports.append(report)
+            return culprits
+
+        first_node_name = nodes[0].name
+        output_node_name = nodes[-1].name
+        report = [f"Defined graph from {first_node_name} to {output_node_name}"]
+        cur_nodes: NodeSet = set(nodes)
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, [output_node_name])
+            self.print_report(report)
+        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
+            report.append(f"Found culprit {cur_nodes}")
+            self.print_report(report)
+            return culprits
+
+        return culprits
+
+    def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
+        culprits: NodeSet = set()
+        nodes_to_run: NodeSet = set()
+
+        # find_all is not supported for accumulate traversal because all the
+        # ops run on NNPI. So we return after the first op that raises error.
+        if self.settings.find_all:
+            print("'Find All' mode is not supported in accumulate traversal.")
+            return culprits
+
+        for node in nodes:
+            report: list[str] = []
+            self.reports.append(report)
+            self.iteration += 1
+            report.append(f"Accumulate traverse iteration {self.iteration}.")
+
+            nodes_to_run.add(node)
+
+            node_name = node.name
+            if node_name is not None and isinstance(node_name, tuple):
+                node_name = node_name[0]
+            assert node_name is not None and isinstance(node_name, str), (
+                f"minimize: node_name: {node_name}"
+            )
+
+            report.append(f"Add node: {node_name}")
+
+            try:
+                split_module, submod_name = self._build_submodule(nodes_to_run)
+                self._run_and_compare(split_module, submod_name, [node_name])
+                self.print_report(report)
+            except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
+                culprits.add(node)
+                report.append(f"Found culprit {node}")
+                self.print_report(report)
+                return culprits
+
+        return culprits
+
+    def _skip_traverse_impl(
+        self, all_nodes: NodeList, start_idx: int, end_idx: int
+    ) -> NodeSet:
+        """
+        Skip certain nodes in graph based on settings
+        """
+        culprits: NodeSet = set()
+        nodes: NodeList = all_nodes[start_idx:end_idx]
+        cur_nodes: NodeSet = set(nodes)
+        if self.exclusion_fn is not None:
+            self.exclusion_fn(nodes, start_idx, end_idx)
+            cur_nodes = set(nodes)
+        else:
+            for node in nodes:
+                if node in self.fusions:
+                    cur_nodes.update(self.fusions[node])
+        report: list[str] = []
+        self.reports.append(report)
+        self.iteration += 1
+        report.append(f" Nodes block {self.iteration}.")
+        report.append(
+            f"From node index {start_idx} to {end_idx - 1}. "
+            f"Size of the interested node list is {len(nodes)}"
+        )
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, [])
+        except FxNetMinimizerResultMismatchError:
+            culprits.update(cur_nodes)
+            report.append(f"Found culprit from numeric error: {cur_nodes}")
+            self.print_report(report)
+            return culprits
+        except FxNetMinimizerRunFuncError:
+            culprits.update(cur_nodes)
+            report.append(f"Found culprit from run error: {cur_nodes}")
+            self.print_report(report)
+            return culprits
+        else:
+            report.append("No discrepancy found.")
+            self.print_report(report)
+            return set()
+
+    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet:
+        """
+        Skip certain nodes in graph based on settings
+        """
+        start_idx = 0
+        num_nodes = len(all_nodes)
+        idx = 0
+        culprits = set()
+        while idx < num_nodes:
+            node = all_nodes[idx]
+            if node.name in skip_nodes:  # skip the node
+                if idx > start_idx:
+                    culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
+                start_idx = idx + 1
+            elif idx == num_nodes - 1 and start_idx <= idx:  # last node
+                culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
+            idx += 1
+
+        return culprits
+
+    def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
+        """
+        Collect nodes in the model that between nodes with name of `start` and `end`.
+        These two nodes are also included.
+        """
+        nodes: NodeList = []
+        add_node = start is None
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            if node.name == start:
+                add_node = True
+
+            if add_node:
+                nodes.append(node)
+
+            if node.name == end:
+                break
+
+        return nodes
+
+    def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
+        """
+        Run part of the model from `start` node to `end` node. If `start` is None
+        then we start from the beginning of the model. If `end` is None then we
+        stop at the end of the model.
+
+        Args:
+            start: The name of the node which is the first node of the submodule
+                we want to run. If set to None, then we'll start with the first
+                node of the model.
+            end: The name of the node which is the last node of the submodule we
+                want to run. If set to None, we'll end with the last node of the
+                model.
+        """
+        nodes = self._collect_nodes(start, end)
+        cur_nodes = set(nodes)
+
+        for node in nodes:
+            if node in self.fusions:
+                cur_nodes.update(self.fusions[node])
+
+        output_names = []
+        if self.settings.return_intermediate:
+            output_names = [node.name for node in nodes]
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, output_names)
+        except (
+            FxNetMinimizerRunFuncError,
+            FxNetMinimizerResultMismatchError,
+        ) as e:
+            print(e)
+
+    def print_report(self, report: list[str]):
+        for i in range(len(report)):
+            if i > 0:
+                print(" . " + report[i])
+            else:
+                print(report[i])
+
+    def print_reports(self):
+        for report in self.reports:
+            self.print_report(report)
+
+    def minimize(
+        self,
+        start: Optional[str] = None,
+        end: Optional[str] = None,
+        skip_nodes: Optional[list] = None,
+        find_last_node: Optional[bool] = None,
+    ) -> NodeSet:
+        """
+        Minimizing the model from node with name `start` to node with name `end` base
+        on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or
+        FxNetMinimizerResultMismatchError errors.
+
+        Args:
+            start: The name of the node where we want to start minimizing. If set
+                to None, then we'll start with the first node of the model.
+            end: The name of the node where we want to terminate minimizing. If
+                set to None, we'll end with the last node of the model.
+            skip_nodes: The names of nodes where we want to skip during minimizing.
+                It'll create subgraphs without these skip nodes under the hood.
+                Only applicable in mode "skip".
+            find_last_node: True if only last_node of a culprits is needed in mode "block".
+                False if only the first_node of a culprits is needed.
+                Only applicable in mode "block".
+
+        Returns:
+            nodes: A list of nodes that causes FxNetMinimizerRunFuncError or
+                FxNetMinimizerResultMismatchError errors during minimizing.
+        """
+
+        print(self.settings)
+        print(self.module.graph)
+
+        nodes = self._collect_nodes(start, end)
+
+        if self.settings.traverse_method == "sequential":
+            return self._sequential_traverse(nodes)
+
+        if self.settings.traverse_method == "binary":
+            return self._binary_traverse(nodes)
+
+        if self.settings.traverse_method == "accumulate":
+            return self._accumulate_traverse(nodes)
+
+        if self.settings.traverse_method == "skip":
+            if skip_nodes is None:
+                raise RuntimeError(
+                    "'skip_nodes' can't be None when 'traverse_method' is 'skip'."
+                )
+            return self._skip_traverse(nodes, skip_nodes)
+
+        if self.settings.traverse_method == "defined":
+            return self._defined_traverse(nodes)
+
+        if self.settings.traverse_method == "block":
+            return self._block_traverse(nodes, find_last_node)
+
+        raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/operator_support.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/operator_support.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cb14d312b60b0209195706488dd48a359c40b3f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/operator_support.py
@@ -0,0 +1,229 @@
+# mypy: allow-untyped-defs
+import abc
+import typing as t
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+
+from .shape_prop import TensorMetadata
+from .tools_common import CALLABLE_NODE_OPS, get_node_target
+
+
+__all__ = [
+    "OperatorSupportBase",
+    "OperatorSupport",
+    "create_op_support",
+    "chain",
+    "OpSupports",
+    "any_chain",
+]
+
+# fx.Node.target typename, as returned by `get_node_target()`
+TargetTypeName = str
+
+# Arguments' dtypes for a given node, see `OperatorSupport`
+SupportedArgumentDTypes = t.Optional[
+    tuple[
+        t.Sequence[t.Sequence[torch.dtype]],
+        dict[str, t.Sequence[torch.dtype]],
+    ]
+]
+
+SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
+
+
+@compatibility(is_backward_compatible=False)
+class OperatorSupportBase(abc.ABC):
+    """Interface for determining if a fx.Node is supported by a backend"""
+
+    @abc.abstractmethod
+    def is_node_supported(
+        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        raise NotImplementedError
+
+
+@compatibility(is_backward_compatible=False)
+class OperatorSupport(OperatorSupportBase):
+    """
+    `_support_dict` maps node.target typename to supported inputs dtypes.
+
+    node.target typename is retrieved using helper function `get_node_target()`
+
+    If supported inputs dtypes is None, it means any dtype is supported, else
+    we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).
+
+    The first tuple ([dtypes], ...) indicates what dtypes are supported for
+    inputs in node.args and the second dict {"name": [dtypes], ...} indicates
+    what dtypes are supported for inputs in node.kwargs.
+
+    For inputs in args, if we don't want to check it, we can put None there,
+    e.g. (None, [torch.float]) indicates that we don't care about the type of
+    the first input in args. And for inputs in kwargs, if not listed, will not
+    be checked.
+    """
+
+    _support_dict: SupportDict
+
+    def __init__(self, support_dict: t.Optional[SupportDict] = None):
+        self._support_dict = support_dict or {}
+
+    def is_node_supported(
+        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        """
+        Args:
+            `submodules`: mapping from module name to the module. This can be
+                          retrieved by calling model.named_modules().
+
+            `node`: a Fx node that we want to determine whether it's supported.
+
+        Returns:
+            `is_supported`: whether the arg `node` is supported.
+        """
+        if node.op not in CALLABLE_NODE_OPS:
+            return True
+
+        target = get_node_target(submodules, node)
+
+        # Target not found in _support_dict meaning that we don't support this op at all
+        if target not in self._support_dict:
+            return False
+
+        # The rule for target is None meaning that we accept any dtype
+        if self._support_dict[target] is None:
+            return True
+
+        args_dtypes, kwargs_dtypes = self._support_dict[target]  # type: ignore[misc]
+
+        # Check args dtypes
+        for i, dtypes in enumerate(args_dtypes):
+            if len(node.args) <= i:
+                break
+
+            # None indicates we don't care about the dtype of args[i]
+            if dtypes is None:
+                continue
+
+            # If arg is not a node then we don't check it
+            if not isinstance(node.args[i], torch.fx.Node):
+                continue
+
+            arg_dtype = _get_arg_dtype(node.args[i])  # type: ignore[arg-type]
+            if arg_dtype not in dtypes:
+                return False
+
+        # Check kwargs dtypes
+        for k, dtypes in kwargs_dtypes.items():
+            if k not in node.kwargs:
+                continue
+
+            # If arg is not a node then we don't check it
+            if not isinstance(node.kwargs[k], torch.fx.Node):
+                continue
+
+            kwarg_dtype = _get_arg_dtype(node.kwargs[k])  # type: ignore[arg-type]
+            if kwarg_dtype not in dtypes:
+                return False
+
+        return True
+
+
+# ======================================================================
+# Functional interfaces and utils for defining basic operator support logic
+# and composing them into more complex ones
+# ======================================================================
+
+IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
+
+
+@compatibility(is_backward_compatible=False)
+def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
+    """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance
+
+    `IsNodeSupported` has the same call signature as
+    `OperatorSupportBase.is_node_supported`
+    """
+
+    class FunctionalOperatorSupport(OperatorSupportBase):
+        def is_node_supported(
+            self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
+        ) -> bool:
+            return is_node_supported(submodules, node)
+
+    return FunctionalOperatorSupport()
+
+
+@compatibility(is_backward_compatible=False)
+def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
+    """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
+    instance by evaluating each input `OperatorSupportBase` instance, and returns False if
+    any of it reports False.
+    """
+
+    def _chain(submods, node) -> bool:
+        return all(x.is_node_supported(submods, node) for x in op_support)
+
+    return create_op_support(_chain)
+
+
+@compatibility(is_backward_compatible=False)
+def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
+    """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
+    instance by evaluating each input `OperatorSupportBase` instance, and returns True if
+    any of it reports True.
+    """
+
+    def _any_chain(submods, node) -> bool:
+        return any(x.is_node_supported(submods, node) for x in op_support)
+
+    return create_op_support(_any_chain)
+
+
+@compatibility(is_backward_compatible=False)
+class OpSupports:
+    """A set of atomic `OperatorSupportBase` instances that can be combined together
+    to form more complex operator support logic.
+    """
+
+    @classmethod
+    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
+        """Report a node as non-supported, if any of its arguments is of dtype"""
+
+        def _decline_if_input_dtype(
+            submodules: t.Mapping[str, torch.nn.Module],
+            node: torch.fx.Node,
+        ) -> bool:
+            for arg in node.all_input_nodes:
+                arg_dtype = _get_arg_dtype(arg)
+                if arg_dtype == dtype:
+                    return False
+            return True
+
+        return create_op_support(_decline_if_input_dtype)
+
+    @classmethod
+    def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase:
+        """
+        If a node has a name that is in the disallow set, reported it as non-supported.
+        """
+
+        def _decline_if_node_in_names(
+            submodules: t.Mapping[str, torch.nn.Module],
+            node: torch.fx.Node,
+        ) -> bool:
+            return node.name not in disallow_set
+
+        return create_op_support(_decline_if_node_in_names)
+
+
+def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
+    assert isinstance(arg, torch.fx.Node)
+    tensor_meta = arg.meta.get("tensor_meta")  # type: ignore[union-attr]
+    dtype = (
+        tensor_meta.dtype
+        if isinstance(tensor_meta, TensorMetadata)
+        else arg.meta["type"]
+    )
+    return dtype
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e17a8040e6a9573200e10bb1fa670bb71219a26
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py
@@ -0,0 +1,97 @@
+from collections.abc import Callable
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.fx._compatibility import compatibility
+from torch.fx.graph_module import GraphModule
+
+
+__all__ = [
+    "default_matching",
+    "extract_attrs_for_lowering",
+    "lift_lowering_attrs_to_nodes",
+]
+
+
+# Matching method matches the attribute name of current version to the attribute name of `target_version`
+@compatibility(is_backward_compatible=False)
+def default_matching(name: str, target_version: int) -> str:
+    """Default matching method"""
+    return name
+
+
+# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
+# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
+# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
+module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
+    torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
+    torch.nn.modules.conv.Conv2d: (
+        1,
+        [
+            "weight",
+            "bias",
+            "kernel_size",
+            "stride",
+            "padding",
+            "dilation",
+            "groups",
+            "padding_mode",
+        ],
+        default_matching,
+    ),
+    torch.nn.modules.batchnorm.BatchNorm2d: (
+        2,
+        ["weight", "bias", "running_mean", "running_var", "eps"],
+        default_matching,
+    ),
+    torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
+    torch.nn.modules.pooling.MaxPool2d: (
+        1,
+        ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"],
+        default_matching,
+    ),
+    torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
+}
+
+
+@compatibility(is_backward_compatible=False)
+def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
+    """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
+    after checking module's version is compatible with the `module_fetch_book`.
+    """
+    attrs_for_lowering: dict[str, Any] = {}
+    attrs_for_lowering["name"] = torch.typename(mod)
+
+    if type(mod) in module_fetch_book:
+        version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
+        if version < mod._version:
+            raise RuntimeError(
+                f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
+                "please upgrade the module_fetch_book, open an issue and @842974287 "
+                "or report a bug to AIACC team directly."
+            )
+        for attr in param_to_fetch:
+            attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
+    else:
+        raise RuntimeError(
+            f"{torch.typename(mod)} is not in the module_fetch_book yet, "
+            "please add it to the module_fetch_book, open an issue and @842974287 "
+            "or report a bug to AIACC team directly."
+        )
+    return attrs_for_lowering
+
+
+@compatibility(is_backward_compatible=False)
+def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
+    """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module."""
+    submodules = dict(fx_module.named_modules())
+
+    for node in fx_module.graph.nodes:
+        if node.op == "call_module":
+            if isinstance(submodules[node.target], GraphModule):
+                lift_lowering_attrs_to_nodes(submodules[node.target])
+            else:
+                node.attrs_for_lowering = extract_attrs_for_lowering(
+                    submodules[node.target]
+                )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/pass_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/pass_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..297d50a68f474dd8ba791f734027c4edc0e5ff70
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/pass_manager.py
@@ -0,0 +1,254 @@
+# mypy: allow-untyped-defs
+import logging
+from collections.abc import Callable
+from functools import wraps
+from inspect import unwrap
+from typing import Optional
+
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "PassManager",
+    "inplace_wrapper",
+    "log_hook",
+    "loop_pass",
+    "this_before_that_pass_constraint",
+    "these_before_those_pass_constraint",
+]
+
+
+# for callables which modify object inplace and return something other than
+# the object on which they act
+def inplace_wrapper(fn: Callable) -> Callable:
+    """
+    Convenience wrapper for passes which modify an object inplace. This
+    wrapper makes them return the modified object instead.
+
+    Args:
+        fn (Callable[Object, Any])
+
+    Returns:
+        wrapped_fn (Callable[Object, Object])
+    """
+
+    @wraps(fn)
+    def wrapped_fn(gm):
+        fn(gm)
+        return gm
+
+    return wrapped_fn
+
+
+def log_hook(fn: Callable, level=logging.INFO) -> Callable:
+    """
+    Logs callable output.
+
+    This is useful for logging output of passes. Note inplace_wrapper replaces
+    the pass output with the modified object. If we want to log the original
+    output, apply this wrapper before inplace_wrapper.
+
+
+    ```
+    def my_pass(d: Dict) -> bool:
+        changed = False
+        if "foo" in d:
+            d["foo"] = "bar"
+            changed = True
+        return changed
+
+
+    pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))])
+    ```
+
+    Args:
+        fn (Callable[Type1, Type2])
+        level: logging level (e.g. logging.INFO)
+
+    Returns:
+        wrapped_fn (Callable[Type1, Type2])
+    """
+
+    @wraps(fn)
+    def wrapped_fn(gm):
+        val = fn(gm)
+        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
+        return val
+
+    return wrapped_fn
+
+
+def loop_pass(
+    base_pass: Callable,
+    n_iter: Optional[int] = None,
+    predicate: Optional[Callable] = None,
+):
+    """
+    Convenience wrapper for passes which need to be applied multiple times.
+
+    Exactly one of `n_iter`or `predicate` must be specified.
+
+    Args:
+        base_pass (Callable[Object, Object]): pass to be applied in loop
+        n_iter (int, optional): number of times to loop pass
+        predicate (Callable[Object, bool], optional):
+
+    """
+    assert (n_iter is not None) ^ (predicate is not None), (
+        "Exactly one of `n_iter`or `predicate` must be specified."
+    )
+
+    @wraps(base_pass)
+    def new_pass(source):
+        output = source
+        if n_iter is not None and n_iter > 0:
+            for _ in range(n_iter):
+                output = base_pass(output)
+        elif predicate is not None:
+            while predicate(output):
+                output = base_pass(output)
+        else:
+            raise RuntimeError(
+                f"loop_pass must be given positive int n_iter (given "
+                f"{n_iter}) xor predicate (given {predicate})"
+            )
+        return output
+
+    return new_pass
+
+
+# Pass Schedule Constraints:
+#
+# Implemented as 'depends on' operators. A constraint is satisfied iff a list
+# has a valid partial ordering according to this comparison operator.
+def _validate_pass_schedule_constraint(
+    constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
+):
+    for i, a in enumerate(passes):
+        for j, b in enumerate(passes[i + 1 :]):
+            if constraint(a, b):
+                continue
+            raise RuntimeError(
+                f"pass schedule constraint violated. Expected {a} before {b}"
+                f" but found {a} at index {i} and {b} at index{j} in pass"
+                f" list."
+            )
+
+
+def this_before_that_pass_constraint(this: Callable, that: Callable):
+    """
+    Defines a partial order ('depends on' function) where `this` must occur
+    before `that`.
+    """
+
+    def depends_on(a: Callable, b: Callable):
+        return a != that or b != this
+
+    return depends_on
+
+
+def these_before_those_pass_constraint(these: Callable, those: Callable):
+    """
+    Defines a partial order ('depends on' function) where `these` must occur
+    before `those`. Where the inputs are 'unwrapped' before comparison.
+
+    For example, the following pass list and constraint list would be invalid.
+    ```
+    passes = [
+        loop_pass(pass_b, 3),
+        loop_pass(pass_a, 5),
+    ]
+
+    constraints = [these_before_those_pass_constraint(pass_a, pass_b)]
+    ```
+
+    Args:
+        these (Callable): pass which should occur first
+        those (Callable): pass which should occur later
+
+    Returns:
+        depends_on (Callable[[Object, Object], bool]
+    """
+
+    def depends_on(a: Callable, b: Callable):
+        return unwrap(a) != those or unwrap(b) != these
+
+    return depends_on
+
+
+class PassManager:
+    """
+    Construct a PassManager.
+
+    Collects passes and constraints. This defines the pass schedule, manages
+    pass constraints and pass execution.
+
+    Args:
+        passes (Optional[List[Callable]]): list of passes. A pass is a
+            callable which modifies an object and returns modified object
+        constraint (Optional[List[Callable]]): list of constraints. A
+            constraint is a callable which takes two passes (A, B) and returns
+            True if A depends on B and False otherwise. See implementation of
+            `this_before_that_pass_constraint` for example.
+    """
+
+    passes: list[Callable]
+    constraints: list[Callable]
+    _validated: bool = False
+
+    def __init__(
+        self,
+        passes=None,
+        constraints=None,
+    ):
+        self.passes = passes or []
+        self.constraints = constraints or []
+
+    @classmethod
+    def build_from_passlist(cls, passes):
+        pm = PassManager(passes)
+        # TODO(alexbeloi): add constraint management/validation
+        return pm
+
+    def add_pass(self, _pass: Callable):
+        self.passes.append(_pass)
+        self._validated = False
+
+    def add_constraint(self, constraint):
+        self.constraints.append(constraint)
+        self._validated = False
+
+    def remove_pass(self, _passes: list[str]):
+        if _passes is None:
+            return
+        passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
+        self.passes = passes_left
+        self._validated = False
+
+    def replace_pass(self, _target, _replacement):
+        passes_left = []
+        for ps in self.passes:
+            if ps.__name__ == _target.__name__:
+                passes_left.append(_replacement)
+            else:
+                passes_left.append(ps)
+        self.passes = passes_left
+        self._validated = False
+
+    def validate(self):
+        """
+        Validates that current pass schedule defined by `self.passes` is valid
+        according to all constraints in `self.constraints`
+        """
+        if self._validated:
+            return
+        for constraint in self.constraints:
+            _validate_pass_schedule_constraint(constraint, self.passes)
+        self._validated = True
+
+    def __call__(self, source):
+        self.validate()
+        out = source
+        for _pass in self.passes:
+            out = _pass(out)
+        return out
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/regional_inductor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/regional_inductor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae98950ab60b0348d80d44541b2d38ebed5d1d67
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/regional_inductor.py
@@ -0,0 +1,240 @@
+# mypy: allow-untyped-defs
+
+import functools
+import logging
+
+import torch
+from torch.fx._compatibility import compatibility
+
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["regional_inductor"]
+
+
+# standalone_inductor returns a callable class object - this does not sit well
+# with Fx graph node op call_function which expects a function. So this is just
+# a wrapper function to make Fx graph codegen happy.
+def _dummy_wrapper(fn):
+    @functools.wraps(fn)
+    def inner(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    return inner
+
+
+def _partition_by_supported_nodes(gm, supported_ops, prefix):
+    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+    from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
+
+    partitioner = CapabilityBasedPartitioner(
+        gm, supported_ops, allows_single_node_partition=True
+    )
+
+    candidate_partitions = partitioner.propose_partitions()
+    partitioned_gm = fuse_by_partitions(
+        partitioner.graph_module,
+        [partition.nodes for partition in candidate_partitions],
+        prefix=prefix,
+        always_return_tuple=True,
+    )
+
+    return partitioned_gm
+
+
+def _compile_submod(gm, prefix):
+    from torch._inductor.standalone_compile import AOTCompiledArtifact
+
+    for node in gm.graph.nodes:
+        if node.op == "call_module" and node.target.startswith(prefix):
+            fake_inputs = []
+            for inp_node in node.all_input_nodes:
+                if hasattr(inp_node, "meta") and "val" in inp_node.meta:
+                    fake_inputs.append(inp_node.meta["val"])
+                else:
+                    raise RuntimeError(
+                        f"Partition is bad because non fake tensor value is seen {inp_node}"
+                    )
+
+            submod = getattr(gm, node.target)
+
+            # Get inductor configs from annotation
+            # TODO we should change partition when there are multiple differently
+            # annotated regions.
+            inductor_options = {}
+            for sub_node in submod.graph.nodes:
+                if hasattr(sub_node, "meta") and sub_node.meta.get("custom", None):
+                    custom = sub_node.meta["custom"]
+                    if isinstance(custom, dict) and "compile_with_inductor" in custom:
+                        compile_value = custom["compile_with_inductor"]
+                        if (
+                            isinstance(compile_value, dict)
+                            and "inductor_configs" in compile_value
+                        ):
+                            inductor_options = compile_value["inductor_configs"]
+                            break
+
+            # Log the options being used
+            logger.info(
+                "Compiling submodule %s with inductor options: %s",
+                node.target,
+                inductor_options,
+            )
+
+            # Apply config patches before compilation
+            import torch._inductor.config as inductor_config
+
+            # Validate that all config keys exist
+            for key in inductor_options:
+                if not hasattr(inductor_config, key):
+                    raise ValueError(
+                        f"Invalid inductor config key '{key}' in regional_inductor annotation. "
+                        f"Available config keys can be found in torch._inductor.config"
+                    )
+
+            with inductor_config.patch(inductor_options):
+                compiled_fn = torch._inductor.standalone_compile(
+                    submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
+                )
+            assert isinstance(compiled_fn, AOTCompiledArtifact)
+            # _dummy_wrapper is to make call_function happy
+            compiled_submod = _dummy_wrapper(compiled_fn)
+            with gm.graph.inserting_after(node):
+                new_node = gm.graph.call_function(
+                    compiled_submod, args=node.args, kwargs=node.kwargs
+                )
+                new_node.meta = node.meta
+                node.replace_all_uses_with(new_node)
+                gm.graph.erase_node(node)
+                del gm._modules[node.target]
+
+    gm.recompile()
+    return gm
+
+
+def _needs_inductor_compile(node: torch.fx.Node):
+    return (
+        node.op not in ("placeholder", "output")
+        and hasattr(node, "meta")
+        and node.meta.get("custom", None)
+        and "compile_with_inductor" in node.meta["custom"]
+    )
+
+
+class _RegionScooper:
+    """
+    Scoops out the inductor marked regions. It does NOT compile them.
+    """
+
+    @staticmethod
+    def scoop_regions(gm):
+        from torch.fx.passes.operator_support import OperatorSupport
+
+        found_marked_node = False
+        for node in gm.graph.nodes:
+            if _needs_inductor_compile(node):
+                found_marked_node = True
+                break
+
+        if not found_marked_node:
+            logger.info("No inductor marked nodes found")
+            return gm
+
+        class InductorMarkedNodes(OperatorSupport):
+            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+                return _needs_inductor_compile(node)
+
+        marked_nodes = InductorMarkedNodes()
+        return _partition_by_supported_nodes(
+            gm, marked_nodes, "__marked_inductor_submod"
+        )
+
+    @staticmethod
+    def recursively_scoop_regions(gm):
+        for node in gm.graph.find_nodes(op="get_attr"):
+            if _needs_inductor_compile(node):
+                # If the get_attr itself is marked for compile, the outer graph will
+                # take care of it. If we dont do that, we end up with nested
+                # regional inductor compiles that do not work well.
+                continue
+            submod = getattr(gm, node.target)
+            if isinstance(submod, torch.fx.GraphModule):
+                _RegionScooper.recursively_scoop_regions(submod)
+
+        return _RegionScooper.scoop_regions(gm)
+
+    def __call__(self, gm):
+        with torch.fx.traceback.preserve_node_meta(enable=False):
+            return _RegionScooper.recursively_scoop_regions(gm)
+
+
+class _RegionCompiler:
+    """
+    Compiles the scooped out regions.
+    """
+
+    @staticmethod
+    def compile_region(gm):
+        from torch.fx.graph import _BoxedCodeGen
+
+        gm = _compile_submod(gm, "__marked_inductor_submod")
+        gm.graph.set_codegen(_BoxedCodeGen())
+        gm.recompile()
+        return gm
+
+    @staticmethod
+    def recursively_compile_regions(gm):
+        # Find if the graph module has a scooped out region
+        found_region = False
+        for node in gm.graph.find_nodes(op="call_module"):
+            submod = getattr(gm, node.target)
+            if isinstance(submod, torch.fx.GraphModule):
+                if node.target.startswith("__marked_inductor_submod"):
+                    found_region = True
+
+        # Recurse through the subgraphs
+        for node in gm.graph.find_nodes(op="get_attr"):
+            submod = getattr(gm, node.target)
+            if isinstance(submod, torch.fx.GraphModule):
+                _RegionCompiler.recursively_compile_regions(submod)
+
+        if found_region:
+            return _RegionCompiler.compile_region(gm)
+        return gm
+
+    def __call__(self, gm):
+        with torch.fx.traceback.preserve_node_meta(enable=False):
+            return _RegionCompiler.recursively_compile_regions(gm)
+
+
+def _create_inductor_marked_regions(gm):
+    with torch.fx.traceback.preserve_node_meta(enable=False):
+        return _RegionScooper()(gm)
+
+
+def _compile_inductor_marked_regions(gm):
+    with torch.fx.traceback.preserve_node_meta(enable=False):
+        return _RegionCompiler()(gm)
+
+
+@compatibility(is_backward_compatible=False)
+def regional_inductor(gm, *example_args):
+    """
+    Scoops out inductor marked regions and compiles them with inductor.
+
+    Inductor options should be provided via the annotation API:
+    with fx_traceback.annotate({
+        "compile_with_inductor": {
+            "inductor_configs": {
+                "max_autotune": True,
+                "triton.cudagraphs": False
+            }
+        }
+    }):
+    """
+    # fuser utils create new nodes using create_proxy which retains the seq_nr
+    # metadata and cause issues
+    with torch.fx.traceback.preserve_node_meta(enable=False):
+        gm = _create_inductor_marked_regions(gm)
+        gm = _compile_inductor_marked_regions(gm)
+        return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/reinplace.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/reinplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dba9f0ca12f0d5545cfab18e84edfd19a1a7758
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/reinplace.py
@@ -0,0 +1,755 @@
+# mypy: allow-untyped-defs
+import _operator
+import itertools
+from collections import defaultdict
+from collections.abc import Callable
+from enum import Enum
+from typing import Any
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx import Node
+from torch.fx._compatibility import compatibility
+from torch.multiprocessing.reductions import StorageWeakRef
+from torch.utils import _pytree as pytree
+from torch.utils._pytree import tree_map_only
+
+
+__all__ = ["reinplace"]
+
+
+class _ViewType(Enum):
+    NonView = 0
+    SingleOutputView = 1
+    MultiOutputView = 2
+
+
+def _is_view_op(tgt):
+    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
+        schema = tgt._schema
+        if len(schema.arguments) > 0:
+            first_arg = schema.arguments[0]
+            # check if op is a view
+            return (
+                first_arg.alias_info is not None and not first_arg.alias_info.is_write
+            )
+
+
+def _get_view_type(tgt) -> _ViewType:
+    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
+        schema = tgt._schema
+        if len(schema.arguments) > 0:
+            first_arg = schema.arguments[0]
+            # check if op is a view
+            if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
+                # check if op is a multi-output view
+                if "*" in first_arg.alias_info.after_set:
+                    return _ViewType.MultiOutputView
+                else:
+                    return _ViewType.SingleOutputView
+    return _ViewType.NonView
+
+
+# Stores a bunch of metadata related to functionalization each node.
+# Relevant metadata:
+# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors)
+#   The fake tensor output from running the current node
+# n.meta['view_of']: Node
+#   If the current node n is a view of some base tensor, the 'view_of' field tells us which
+#   view node was used to generate the current node (a view tensor).
+#   This information actually makes `fake_result` redundant, but we can use `fake_result`
+#   to sanity check that our aliasing information is correct.
+@compatibility(is_backward_compatible=False)
+class _FunctionalizationMetadataProp(torch.fx.Interpreter):
+    def run_node(self, node: Node):
+        self.node_counter += 1
+        result = super().run_node(node)
+        node.meta["fake_result"] = result
+        node.meta["node_idx"] = self.node_counter
+
+        # (1) Update metadata with the list of nodes that are used by this node
+        # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
+        # We don't want to treat it as "being used as an input".
+        node_args = node.args
+        if node.target is torch.ops.aten.copy_.default:
+            node_args = node_args[1:]
+
+        # (2) Update metadata to track aliasing information about view tensor nodes.
+        if node.op == "call_function":
+            view_type = _get_view_type(node.target)
+            if view_type == _ViewType.SingleOutputView:
+                assert isinstance(node.args[0], Node)
+                node.meta["view_of"] = node.args[0]
+            elif view_type == _ViewType.MultiOutputView:
+                self.multi_output_view_nodes[node] = node.args[0]
+
+            # Check if we returned a multi-output view,
+            # and we're now grabbing the individual views from the output.
+            #
+            # For multi-output views, we want to map each output view to the base,
+            # but this mapping involves two separate nodes in FX IR.
+            # e.g. "a, b = x_1.split(...)" becomes:
+            #    %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
+            #    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
+            #    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
+            # And we'd like to set:
+            #    getitem1.meta['view_of'] = x_1
+            elif node.target is _operator.getitem:
+                list_arg = node.args[0]
+                maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
+                if maybe_base_of_view is not None:
+                    # Note: we could also track indexing info here for multi-output views.
+                    # I don't think this metadata is strictly needed for de-functionalization.
+                    assert isinstance(maybe_base_of_view, Node)
+                    node.meta["view_of"] = maybe_base_of_view
+
+        if "view_of" in node.meta:
+            # We're linking the current node with its first argument as views.
+            # Assert here that this is actually the case, and their storages are the same.
+            assert isinstance(node.meta["fake_result"], FakeTensor)
+            assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor)
+            view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage())
+            base_storage = StorageWeakRef(
+                node.meta["view_of"].meta["fake_result"]._typed_storage()
+            )
+            assert view_storage == base_storage
+        return result
+
+    def propagate(self, *args):
+        self.multi_output_view_nodes = {}
+        self.node_counter = -1
+
+        with FakeTensorMode() as mode:
+            fake_args = [
+                mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
+            ]
+            return super().run(*fake_args)
+
+
+def _schemas_match(functional_schema, inplace_schema):
+    names_match = (
+        inplace_schema.name.endswith("_")
+        and inplace_schema.name[:-1] == functional_schema.name
+    )
+    arg_types_match = len(functional_schema.arguments) == len(
+        inplace_schema.arguments
+    ) and all(
+        a1.type == a2.type
+        for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)
+    )
+    # for the inplace op, its first argument should be mutable
+    assert (
+        inplace_schema.arguments[0].alias_info is not None
+        and inplace_schema.arguments[0].alias_info.is_write
+    )
+    # and its remaining arguments shouldn't be.
+    assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
+    return names_match and arg_types_match
+
+
+# TODO: this should be beefed up to be able to properly re-inplace with:
+# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
+# - out= ops (e.g. angle -> angle.out)
+# TODO: we should also figure this info out using torchgen.
+def _maybe_get_inplace_op(op):
+    # __module__ seems broken; it returns torch._ops.aten which doesn't exist
+    if not isinstance(op, torch._ops.OpOverload):
+        return None
+    # Some view ops have inplace variants (as_strided_, etc),
+    # but we do NOT want the reinplacing pass to directly add these into the program.
+    # (they'll require extra special handling, aren't aren't really useful for perf anyway)
+    if _is_view_op(op):
+        return None
+    op_namespace = op.__module__.split(".")[-1]
+    op_base_name = op.overloadpacket.__name__
+    maybe_namespace_module = getattr(torch.ops, op_namespace)
+    maybe_inplace_op = (
+        None
+        if maybe_namespace_module is None
+        else getattr(maybe_namespace_module, f"{op_base_name}_", None)
+    )
+    if maybe_inplace_op is None:
+        return None
+
+    inplace_overloads = [
+        getattr(maybe_inplace_op, overload_name)
+        for overload_name in maybe_inplace_op.overloads()
+    ]
+    inplace_overloads_with_matching_schemas = [
+        f for f in inplace_overloads if _schemas_match(op._schema, f._schema)
+    ]
+    # Just because foo() and foo_() are both existing operators,
+    # They aren't guaranteed to have compatible schemas.
+    # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
+    # Even though several overloads of pow_ exist.
+    if len(inplace_overloads_with_matching_schemas) == 0:
+        return None
+    assert len(inplace_overloads_with_matching_schemas) == 1
+    inplace_op = inplace_overloads_with_matching_schemas[0]
+    return inplace_op
+
+
+_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = {
+    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
+    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
+    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
+    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
+}
+
+
+# This function, given a set of set of (aliased) tensor nodes,
+# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
+# in the node ordering.
+def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int):
+    def _add_if_tensor(x, set_):
+        if isinstance(x, FakeTensor):
+            set_.add(StorageWeakRef(x._typed_storage()))
+
+    nodes_used_after = set()
+    for t in tensor_aliases:
+        # get all nodes that use the current alias
+        usage_nodes = t.users
+        for n in usage_nodes:
+            # We only care about usages after the current node
+            if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index:
+                continue
+            # We also don't care about intermediate view ops.
+            # They only matter if their output is then used elsewhere
+            # (either in an out-of-place op, or as an output to the function).
+            if n in tensor_aliases:
+                if (
+                    isinstance(n.target, torch._ops.OpOverload)
+                    or n.target is _operator.getitem
+                ):
+                    continue
+            nodes_used_after.add(n)
+    return nodes_used_after
+
+
+# Given an op that we're trying to re-inplace, "b = foo(a)",
+# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
+# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF:
+# If there are any aliases in the alias_set(a) that satisfy:
+# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
+# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
+#     as "alias"
+def _get_view_inverse_node_usages(
+    later_node_usages: set[Node], self_aliases: set[Node]
+) -> set[Node]:
+    def matching_view_metadata(a, b):
+        return (
+            a.size() == b.size()
+            and a.stride() == b.stride()
+            and a.storage_offset() == b.storage_offset()
+        )
+
+    view_inverse_nodes = set()
+    # Go through them in node order, so we can see chains of view_scatter ops.
+    for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]):
+        if n.target not in _VIEW_INVERSE_MAP:
+            continue
+        base = n.args[0]
+        mutated_view = n.args[1]
+        assert isinstance(base, Node)
+        assert isinstance(base.meta["fake_result"], FakeTensor)
+        assert isinstance(mutated_view, Node)
+        assert isinstance(mutated_view.meta["fake_result"], FakeTensor)
+        assert not isinstance(n.target, str)
+        # Check that this view_inverse op actually corresponds to taking doing the inverse
+        # of one of our existing self_alias nodes.
+        original_view = _VIEW_INVERSE_MAP[n.target]
+        for self_alias in self_aliases:
+            # We're looking for some alias of the self arg, "alias",
+            # that was created from some op `alias = foo(base, args...)`
+            # such that the current _scatter op "inverts" that foo call.
+            # We can check that by running the original op again, and checking that the strides match.
+            if "view_of" not in self_alias.meta:
+                continue
+            self_alias_base = self_alias.meta["view_of"]
+            try:
+                # The we're trying to reuse the args from the view_scatter call inside of the corresponding
+                # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse
+                # of the current alias we're looking at.
+                view_replay_metadata = original_view(
+                    self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs
+                )
+                expected_metadata = self_alias.meta["fake_result"]
+                # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
+                if matching_view_metadata(
+                    self_alias_base.meta["fake_result"], base.meta["fake_result"]
+                ) and matching_view_metadata(view_replay_metadata, expected_metadata):
+                    view_inverse_nodes.add(n)
+            except Exception:
+                continue
+
+    return view_inverse_nodes
+
+
+@compatibility(is_backward_compatible=True)
+def reinplace(gm, *sample_args):
+    """
+    Given an fx.GraphModule, modifies it to perform "reinplacing",
+    mutating the nodes of the graph.
+    We look for out-of-place op call sites like `b = a.add(...)`,
+    and convert them to be inplace (`b = a.add_(...)`),
+    as long as the input to the current operator ("a") isn't reused
+    anywhere later in the graph.
+
+    This pass currently expects to operate on a **functional, ATen** graph.
+    This can be obtained by running `make_fx(functionalize(f))`.
+
+    Sample inputs are needed to determine aliasing relationships of the inputs.
+    In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
+    inputs to the program.
+
+    Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:
+
+    (1) Perform some initial checks on the metadata of "a" and "args..."
+        that can disqualify them from being reinplaced.
+
+      (1a) Check that the self argument we're attempting to reinplace
+           has acceptable dtype/size metadata to reinplace with.
+
+           For example, if we have:
+             a = torch.ones(1)
+             b = torch.ones(10)
+             out = torch.add(a, b)
+           We can't turn that into
+             a.add_(b)
+           Because that would require resizing "a".
+
+           Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
+           because that would require changing a's dtype (from e.g. float32 to bool).
+           Note that in this specific example, we could technically do better..
+
+           If we see the pattern:
+             a_1 = a.ge(b)
+             a_2 = aten._to_copy(a_1, a.dtype)
+           Then we this should be valid to completely re-inplace
+           (this is exactly what functionalization will emit when it sees a.ge_(b)).
+
+           This optimization is only really important for user programs
+           that directly use inplace comparison ops though.
+
+           We also cannot re-inplace on tensors that have overlapping memory,
+           e.g. torch.ones(1).expand(4, 4).add_(1)
+
+      (1b) Check if "a" is an alias of any of the program inputs.
+
+          If it is, skip and move to the next node.
+          Inplace'ing an op that would cause it to mutate a program is not sound,
+          because that would be a side effect visible to the user.
+
+          NOTE: there's a future optimization that we should make:
+          if "a" is a (alias of a)  program input, but later in the program
+          there is a node that looks like "a.copy_(...)",
+          Then re-inplacing is ok to do - we are temporarily reusing a's buffer,
+          which will later be overwritten by the copy_() call.
+
+          This will be an important optimization to have for programs that mutate
+          their inputs. It currently isn't implemented though.
+
+      (1c) Check if "a" and "args..." alias
+
+          For example, re-inplacing to create code like the below
+          isn't guaranteed to be sound:
+
+            aten.mul_(a, a)
+
+    (2) Check that "a" and all of its outstanding aliases are not used anywhere
+        later in the graph. If this is the case, then it's safe to re-inplace
+        to "b = foo_(a)".
+
+        There are a few caveats to this, explained in more detail below:
+        (a) If "a" is used later as an argument to a view op, that is okay.
+            It's only a problem if "a" (or that view) is later passed
+            into a normal operator, or if it is returned as the program output.
+        (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
+            Most ATen kernels don't make any guarantees that this is sound,
+            e.g. if you do aten.mul_(a, a).
+            So we'll just ban re-inplacing in this case.
+            It's only a problem if "a" (or that view) is later passed
+        (c) If "a" is used as an input into a view "inverse" / "scatter"
+            operator, it is potentially fine to re-inplace
+            (and remove that scatter operator from the graph).
+            See below for a more detailed example.
+
+        NOTE: there is an optimization in this step that is crucial
+        to fully recovering performance from functionalization.
+
+        Given this program:
+        def f(x):
+            a = torch.ops.aten.add(x, x)
+            b = torch.ops.aten.diagonal(a)
+            torch.ops.aten.fill_(b, 0)
+            return d
+
+        Functionalization will emit the following:
+        def f(x):
+            a = torch.ops.aten.add(x, x)
+            b = torch.ops.aten.diagonal(a, 0, 1)
+            b_updated = torch.ops.aten.fill(b, 0)
+            a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
+            return a_updated
+
+        Ordinarily, we would not be able to reinplace the fill,
+        because "b" aliases with "a" which is used by the diagonal_scatter call.
+
+        "re-inplacing" is on the hook for figuring out that it is ok to
+        completely, the expensive diagonal_scatter call, if we re-inplace the add().
+
+        So, for every `alias in alias_set(a)`, instead of checking
+        that "alias" is not used anywhere later in the graph,
+        we check that
+            EITHER:
+          (a) alias is not used anywhere later in the graph
+            OR:
+          (b) alias is used exactly once later on in the graph,
+              in the following op:
+
+                out = foo_scatter(alias, x, args...)
+
+              where the following must hold:
+                (i) "foo_scatter" is the "inverse" operator for foo.
+                    This only applies to "foo" ops that are view operators,
+                    which view into a subset of the original tensor's memory.
+                    In practice, there are ~4 operators where this applies:
+                      diagonal -> diagonal_scatter
+                      slice -> slice_scatter
+                      select -> select_scatter
+                      as_strided -> as_strided_scatter
+                (ii) "args..." are the same between the foo() and foo_scatter() calls.
+
+    (3) Perform the actual re-inplacing on foo!
+
+      (3b) is the common case, but special care is needed for {view}_scatter (3a)
+
+      (3a) {view}_scatter ops.
+
+        Consider this program:
+          a = torch.zeros(2, 2)
+          b = torch.ones(2)
+          a[0] = b
+
+        Post functionalization, that will look like:
+          a = torch.zeros(2)
+          b = torch.ones(1)
+          a_updated = torch.select_scatter(a, b, 0, 0)
+
+        In this case though, there is no "functional" op to re-inplace!
+        Instead, we'd like to directly remove toe select_scatter call.
+        We already know from (3) that this is valid,
+        because "a" has no later usages in the graph.
+
+        We perform the re-inplacing on the {view}_scatter op like so
+        Before:
+          a_updated = torch.select_scatter(a, b, args...)
+        After:
+          a_slice = a.select(a, args...)
+          a_slice.copy_(b)
+
+      (3b) Otherwise, replace the functional op with its inplace variant.
+        Before:
+          b = foo(a, args...)
+        After:
+          a.foo_(args...)
+
+    (4) Finally, after converting either:
+          Before:
+            b = foo(a)
+          After:
+            foo_(a)
+        or
+          Before:
+            b = {slice}_scatter(a, mutated_slice, args...)
+          After:
+            slice = {slice}(a, args...)
+            slice.copy_(mutated_slice)
+
+        We now need to find all later nodes that use "b" as an argument
+        and update them to take in "a" instead.
+
+        Note that for the majority of inplace ops, this isn't actually necessary
+        (because most inplace ops return "self" as their output).
+        This isn't generally true for all mutable ops though, which is why
+        we need to actually replace all of the arguments.
+
+        We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
+        That maps a given tensor storage to the set of all nodes that take in that storage
+        as an input.
+        Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
+        together.
+
+    (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
+        during step (3) get manually deleted from the graph.
+        Their outputs are no longer used, so technically standard DCE would be able
+        to do this, but we can no longer run FX's DCE pass now that we have mutable
+        ops in the graph.
+    """
+    _FunctionalizationMetadataProp(gm).propagate(*sample_args)
+
+    # Useful debug printing
+    # def _print(x):
+    # if isinstance(x, FakeTensor):
+    # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
+
+    # for n in gm.graph.nodes:
+    # print(n.format_node())
+    # if hasattr(n, 'meta'):
+    # print(f'node_idx: {n.meta["node_idx"]}')
+    # if 'fake_result' in n.meta:
+    # tree_map(_print, n.meta['fake_result'])
+    # if 'view_of' in n.meta:
+    # print(f'view_of: {str(n.meta["view_of"])}')
+    # print()
+
+    # We need to know which nodes correspond to inputs (or their aliases)
+    # so we know not to re-inplace them.
+    # NOTE: later, we'll need to add an optimization for fully recovering performance
+    # on programs that mutate inputs.
+    input_storages = {
+        StorageWeakRef(node.meta["fake_result"]._typed_storage())
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.meta["fake_result"], torch.Tensor)
+        )
+    }
+
+    # We also need to know for a given node, what are all of its aliasing nodes.
+    storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set)
+    for n in gm.graph.nodes:
+        if "fake_result" in n.meta:
+            # Tree-mapping because some ops can return lists of tensors.
+            def _add_to_map(x):
+                if isinstance(x, FakeTensor):
+                    storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
+
+            pytree.tree_map_(_add_to_map, n.meta["fake_result"])
+
+    # inplace-ify functional ops, subject to the constraints written below.
+    all_later_view_inverse_nodes_to_delete = set()
+    for node in gm.graph.nodes:
+        if node.op == "call_function":
+            # Today, the re-inplace pass on directly acts on:
+            # - functional ops with an inplace variant
+            # - {view}_scatter ops that can be potentially removed from the graph.
+            # Both of these ops take in tensor first args, so filtering on this condition
+            # makes the later code simpler.
+            # We should revisit this at some point though, particularly when we also want
+            # the reinplacer to be able to handle out= and mutable operators
+            # and tensorlist first args (like `_foreach_` ops).
+            if not isinstance(node.target, torch._ops.OpOverload):
+                continue
+            if len(node.target._schema.arguments) < 1:
+                continue
+            if type(node.target._schema.arguments[0].type) is not torch.TensorType:
+                continue
+
+            # Step 1a: Check that the self argument we're attempting to reinplace
+            # has the same size/stride as the output.
+            # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
+            # As it would require resizing scalar_tensor.
+            # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
+            # this is probably an optimization to revisit later).
+            self_arg = node.args[0]
+            self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"])
+            node_flattened = pytree.tree_leaves(node.meta["fake_result"])
+            self_has_wrong_metadata = False
+            if len(self_flattened) == len(node_flattened):
+                for self_meta, node_meta in zip(self_flattened, node_flattened):
+                    if self_meta.numel() != node_meta.numel():
+                        self_has_wrong_metadata = True
+                    if self_meta.dtype != node_meta.dtype:
+                        self_has_wrong_metadata = True
+                    # We also cannot re-inplace on tensors that have internal memory overlap.
+                    # e.g. torch.ones(1).expand(4, 4).add_(1)
+                    if torch._debug_has_internal_overlap(self_meta) == 1:
+                        self_has_wrong_metadata = True
+            # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
+            # Since users should never really be calling the functional "torch.ops.aten.resize"
+            # op directly in their programs.
+            if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
+                continue
+
+            # Step 1b: ensure that the op we're trying to re-inplace isn't a program input
+            self_arg_storage = StorageWeakRef(
+                self_arg.meta["fake_result"]._typed_storage()
+            )
+            if self_arg_storage in input_storages:
+                # TODO: later, add the optimization for handling `copy_()` calls in the graph.
+                continue
+            if len([x for x in node.args if x is self_arg]) > 1:
+                # Step 1c:
+                # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
+                # so we prevent re-inplacing in this case.
+                continue
+
+            self_arg_storage = StorageWeakRef(
+                self_arg.meta["fake_result"]._typed_storage()
+            )
+            self_aliases = storage_to_nodes[self_arg_storage]
+
+            # First, we find all later usages of any of the aliases of self_arg.
+            later_node_usages = _get_all_later_node_usages(
+                self_aliases, node.meta["node_idx"]
+            )
+            # Then, we check if any of those later usages are actually view_scatter ops
+            # that are safe to fully remove.
+            later_view_inverse_node_usages = _get_view_inverse_node_usages(
+                later_node_usages, self_aliases
+            )
+
+            # Step 2: Check to see if the input to the op is reused later in the graph.
+            # If not (same goes for its aliases), then this op is safe to re-in place.
+            # This is a slightly roundabout way to check that there are no later usages of the current self argument.
+            # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
+            can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
+            if not can_reinplace:
+                continue
+
+            # Step 3a: Special handling for when we see *_scatter operators.
+            # When we see an operator like `b = torch.slice_scatter(a, ...)`,
+            # instead of trying to "inplace" it into a.slice_scatter_(..._),
+            # we would prefer to remove it from the graph entirely,
+            # and instead copy_() the slice directly into the larger tensor.
+            # See the description of the algorithm for a full example.
+            if (
+                node.target in _VIEW_INVERSE_MAP
+                and node not in all_later_view_inverse_nodes_to_delete
+            ):
+                view_op = _VIEW_INVERSE_MAP[node.target]
+                # Before:
+                #   base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
+                # After:
+                #   slice = torch.ops.aten.slice.default(base, args...)
+                #   slice.copy_(mutated_slice)
+                with gm.graph.inserting_before(node):
+                    mutated_slice_node = node.args[1]
+                    remaining_slice_args = node.args[2:]
+                    slice_node = gm.graph.create_node(
+                        "call_function",
+                        view_op,
+                        (self_arg,) + tuple(remaining_slice_args),
+                        node.kwargs,
+                    )
+                    gm.graph.create_node(
+                        "call_function",
+                        torch.ops.aten.copy_.default,
+                        (
+                            slice_node,
+                            mutated_slice_node,
+                        ),
+                        {},
+                    )
+                # Add the slice_scatter node to our "nodes to delete" list.
+                all_later_view_inverse_nodes_to_delete.add(node)
+
+            else:
+                # Step 3b: Check to see if this operator has an inplace variant.
+                maybe_inplace_op = _maybe_get_inplace_op(node.target)
+                if maybe_inplace_op is None:
+                    continue
+                # And if so, replace it with its inplace variant.
+                node.target = maybe_inplace_op
+
+            # At this point, 'storage_to_nodes' will be stale.
+            # Now that we're inplacing `b = foo(a)`, we need to effectively
+            # union together the dict values for b and a's storage.
+            # Hmm... morally I think we also want to keep the `fake_result` metadata
+            # up to date here, but I'm not sure how easy it is to do.
+            # Maybe it's fine to wait until the end of the pass to update it.
+            curr_node_storage = StorageWeakRef(
+                node.meta["fake_result"]._typed_storage()
+            )
+            storage_to_nodes[self_arg_storage].update(
+                storage_to_nodes[curr_node_storage]
+            )
+            storage_to_nodes[curr_node_storage].update(
+                storage_to_nodes[self_arg_storage]
+            )
+
+            # Need to remember the view_scatter view nodes we found so we can remove them alter.
+            all_later_view_inverse_nodes_to_delete.update(
+                later_view_inverse_node_usages
+            )
+
+            # Step 4:
+            # Now that we've replaced b = a.foo() with a.foo_(),
+            # We need to replace any later usages of "b" with "a"
+            for old in itertools.chain([node], later_view_inverse_node_usages):
+                new = old.args[0]
+                nodes_to_update = [
+                    n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"]
+                ]
+                for node_to_update in nodes_to_update:
+
+                    def replace_arg(a):
+                        if a == old:
+                            return new
+                        return a
+
+                    # First, replace usages of "b" with "a"
+                    node_to_update.args = tree_map_only(
+                        Node, replace_arg, node_to_update.args
+                    )
+                    node_to_update.kwargs = tree_map_only(
+                        Node, replace_arg, node_to_update.kwargs
+                    )
+
+                    # Second, update our storage_to_nodes data structure.
+                    old_flattened_res = pytree.tree_leaves(old.meta["fake_result"])
+                    node_flattened_res = pytree.tree_leaves(
+                        node_to_update.meta["fake_result"]
+                    )
+
+                    old_res_storage = {
+                        StorageWeakRef(x._typed_storage())
+                        for x in old_flattened_res
+                        if isinstance(x, FakeTensor)
+                    }
+                    node_res_storage = {
+                        StorageWeakRef(x._typed_storage())
+                        for x in node_flattened_res
+                        if isinstance(x, FakeTensor)
+                    }
+
+                    # This will happen if we're updating a view op, e.g.
+                    # e.g. replacing
+                    #     x = view(old)
+                    #     x = view(new)
+                    # When that happens, we need to make sure to keep our
+                    # storage mapping up to date.
+                    #
+                    # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
+                    # or multiple tensors that all share the same storage.
+                    # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
+                    if (
+                        len(old_res_storage) == 1
+                        and len(node_res_storage) == 1
+                        and old_res_storage == node_res_storage
+                    ):
+                        new_flattened_res = pytree.tree_leaves(new.meta["fake_result"])
+                        new_res_storage = {
+                            StorageWeakRef(x._typed_storage())
+                            for x in new_flattened_res
+                            if isinstance(x, FakeTensor)
+                        }
+                        assert len(new_res_storage) == 1
+                        (new_ref,) = new_res_storage
+                        (node_ref,) = node_res_storage
+                        # Technically, "old_ref" and all its aliases will remain
+                        # in our mapping.
+                        # That should be fine though, since we deleted "old"
+                        # from the graph at this point.
+                        storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
+                        storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
+
+    # Step 4: delete any _scatter nodes that we de-functionalized
+    # Need to take care not to delete any of these nodes until after *all* modifications
+    # to the graph are finished.
+    for to_delete in all_later_view_inverse_nodes_to_delete:
+        gm.graph.erase_node(to_delete)
+
+    gm.recompile()
+    return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py
new file mode 100644
index 0000000000000000000000000000000000000000..e475a5bc9b6df55dc640d80dc6510ed263368c4a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py
@@ -0,0 +1,658 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+import operator
+import sys
+from typing import Any, Optional, TYPE_CHECKING
+
+
+# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
+if TYPE_CHECKING:
+    import sympy
+
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+else:
+    ShapeEnv = Any
+
+import torch
+import torch.utils._pytree as pytree
+from torch import fx
+from torch._subclasses.meta_utils import is_sparse_any
+from torch.fx._compatibility import compatibility
+from torch.fx._utils import lazy_format_graph_code
+from torch.fx.experimental.proxy_tensor import py_sym_types
+from torch.fx.experimental.sym_node import SymNode
+from torch.fx.graph_module import GraphModule
+
+
+__all__ = ["insert_deferred_runtime_asserts"]
+
+log = logging.getLogger(__name__)
+graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose")
+
+
+def _get_example_value(node: fx.Node) -> Optional[str]:
+    """
+    Get the example value key for a node, since dynamo uses "example_value"
+    while non-strict export uses "val.
+    """
+    if "example_value" in node.meta:
+        return node.meta["example_value"]
+    elif "val" in node.meta:
+        return node.meta["val"]
+    else:
+        return None
+
+
+def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
+    val = _get_example_value(node)
+    if isinstance(val, py_sym_types):
+        return val.node.expr
+    return None
+
+
+@compatibility(is_backward_compatible=True)
+def insert_deferred_runtime_asserts(
+    gm: GraphModule,
+    shape_env: ShapeEnv,
+    name: str,
+    export: bool = False,
+) -> None:
+    """
+    During tracing, we may have discovered that some data-dependent values
+    had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime
+    that x.item() >= 0.  These asserts can happen unpredictably during fake
+    tensor propagation, so we cannot conveniently insert them into the FX graph
+    when they occur.  Instead, we accumulate them in the ShapeEnv, and in this
+    pass insert them into the graph as proper tests.
+
+    This pass also deduplicates size-related computation, CSE-ing ops that produce
+    symbolic values and/or are involved in runtime asserts. Additionally, shape calls
+    (size/stride/storage_offset) are turned into compute on input sizes if possible,
+    allowing intermediate tensors to be freed earlier. For example, here dynamo will
+    DCE the cat and repeat calls:
+
+        z = torch.cat([x, x], dim=0)  # 2*s0
+        w = z.repeat(y.shape[0])  # 2*s0*s1
+        _w = w.shape[0]
+        # something with _w, but not w ...
+
+        # turns into ->
+        _w0 = 2 * s0
+        _w = _w0 * s1
+
+        # where s0, s1 are either SymInt graph inputs, or the result of added size calls
+
+    Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
+    the same expression, and redundant constrain_range calls are also deduplicated.
+    Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
+    information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol,
+    and we delete all previous calls, adding bound checks at the end of this pass.
+    """
+
+    # Import sympy locally
+    import sympy
+
+    from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
+    from torch.fx.experimental.symbolic_shapes import (
+        _get_placeholder_expr,
+        _has_uninterpretable_sympy_function,
+        CallMethodKey,
+        cast_symbool_to_symint_guardless,
+        ConvertIntKey,
+        DivideByKey,
+        free_symbols,
+        InnerTensorKey,
+        resolve_unbacked_bindings,
+    )
+    from torch.utils._sympy.numbers import int_oo
+    from torch.utils._sympy.reference import (
+        OptimizedPythonReferenceAnalysis,
+        PythonReferenceAnalysis,
+    )
+    from torch.utils._sympy.value_ranges import ValueRanges
+
+    # TODO: Request simplification on runtime asserts before emitting them
+    ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
+    graph = gm.graph
+    tracer = fx.proxy.GraphAppendingTracer(graph)
+    graph_code_log.debug(
+        "%s",
+        lazy_format_graph_code(
+            f"pre insert_deferred_runtime_asserts {name}", gm, colored=True
+        ),
+    )
+
+    # We are going to mutate the dict
+    expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
+    placeholders = set()
+    first_non_placeholder = None
+    for node in graph.nodes:
+        if node.op != "placeholder":
+            first_non_placeholder = node
+            break
+        else:
+            placeholders.add(node)
+
+    def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool:
+        """
+        If a size/stride/storage offset call on an intermediate tensor,
+        we can try to compute the value from input shapes instead.
+        """
+        return (
+            (val := _get_sym_val(node)) is not None
+            and not isinstance(val, sympy.Number)
+            # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
+            and not _has_uninterpretable_sympy_function(val)
+            and any(
+                isinstance(arg, fx.Node)
+                and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
+                and arg.op != "placeholder"
+                for arg in node.args
+            )
+        )
+
+    # Figure out what key to use, val or example_value
+    val_key = "val"
+    for node in graph.nodes:
+        if "example_value" in node.meta:
+            val_key = "example_value"
+            break
+        elif "val" in node.meta:
+            break
+
+    def _node_metadata_hook(
+        node: torch.fx.Node,
+        stack_trace: Optional[str] = None,
+        nn_module_stack: Optional[dict[str, Any]] = None,
+        custom: Optional[dict[str, Any]] = None,
+    ) -> None:
+        fake_args = pytree.tree_map(
+            lambda arg: (
+                _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
+            ),
+            node.args,
+        )
+        try:
+            target = node.target
+            if node.op == "call_method":
+                assert isinstance(node.target, str)
+                target = getattr(fake_args[0], node.target)
+                fake_args = fake_args[1:]
+            node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
+        except NotImplementedError:
+            # This can happen when attempting to reify a symbol with an unsupported call_function node,
+            # e.g. with NestedTensors + sym_size.int via match_symbol().
+            # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input.
+            pass
+        if stack_trace is not None:
+            node.meta["stack_trace"] = stack_trace
+        if nn_module_stack is not None:
+            node.meta["nn_module_stack"] = nn_module_stack
+        if custom is not None:
+            node.meta["custom"] = custom
+
+    # Track asserts/checks we've added
+    added_asserts: set[sympy.Expr] = set()
+    constrained_unbacked_symbols: set[sympy.Symbol] = set()
+
+    Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
+
+    def _sympy_interp(expr_to_proxy, expr):
+        # sympy_interp() with hash consing
+        from sympy import Integer, Number, Symbol
+        from sympy.logic.boolalg import BooleanAtom
+
+        from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
+
+        # hash cons
+        if expr in expr_to_proxy:
+            return expr_to_proxy[expr]
+        # base cases, don't cache
+        if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
+            return sympy_interp(Analysis, expr_to_proxy, expr)
+
+        # hash cons on arguments, run expr handler
+        expr_to_proxy[expr] = _run_sympy_handler(
+            Analysis,
+            [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
+            expr,
+        )
+        return expr_to_proxy[expr]
+
+    def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
+        # This is probably unnecessary, but since torch._check() calls for single-symbol bounds
+        # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
+        # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
+        if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
+            return False
+        lhs, rhs = expr.args
+        return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
+            isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
+        )
+
+    def add_runtime_asserts(ras):
+        for ra in ras:
+            if (
+                # redundant
+                ra.expr in added_asserts
+                # if we've already added a constrain_range call for this symbol,
+                # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
+                or (
+                    len(ra.expr.free_symbols) == 1
+                    and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
+                    and _is_bound_expr_for_symbol(ra.expr)
+                )
+                # don't try to reify sympy functions we can't turn into FX nodes
+                or _has_uninterpretable_sympy_function(ra.expr)
+            ):
+                continue
+
+            log.debug("inserting runtime assert %s", ra.expr)
+            # Need to process ALL free symbols, not just unbacked ones
+            fvs = free_symbols(ra.expr)
+            missing = fvs - expr_to_proxy.keys()
+            if missing:
+                i1 = min(missing, key=str)
+                # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
+                # assert shape_env.is_unbacked_symint(i1), i1
+                ras_by_symbol.setdefault(i1, []).append(ra)
+            else:
+                # Convert the sympy expression into a sequence of FX
+                # nodes
+                with _set_node_metadata_hook(gm, _node_metadata_hook):
+                    res = _sympy_interp(expr_to_proxy, ra.expr).node
+
+                    graph.call_function(
+                        torch.ops.aten._assert_scalar.default,
+                        # TODO: use ra.msg here, but it's pretty
+                        # useless right now
+                        (
+                            res,
+                            f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
+                        ),
+                    )
+                added_asserts.add(ra.expr)
+
+    nodes = list(graph.nodes)
+    for i, node in enumerate(nodes[:-1]):
+        # Placeholders can match symbols, but when we destructure them
+        # with size we have to make sure we insert the nodes after all
+        # the placeholders
+        with graph.inserting_before(
+            nodes[i + 1] if node not in placeholders else first_non_placeholder
+        ):
+            # Unfortunately, this logic still must remain because manual
+            # make_fx calls may not explicitly bind all symbolic ints as
+            # arguments to the function, so we must infer it from the other
+            # arguments
+            if (
+                node in placeholders
+                and (example_value := _get_example_value(node)) is not None
+            ):
+
+                def match_symbol(symint, cb):
+                    if (
+                        isinstance(symint, torch.SymInt)
+                        and isinstance(symint.node, SymNode)
+                        and isinstance(
+                            s := _get_placeholder_expr(symint.node), sympy.Symbol
+                        )
+                        and s not in expr_to_proxy
+                    ):
+                        with _set_node_metadata_hook(gm, _node_metadata_hook):
+                            expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
+
+                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
+
+                match_symbol(example_value, lambda: node)
+
+                if isinstance(t := example_value, torch.Tensor):
+                    for i, s in enumerate(t.size()):
+                        match_symbol(
+                            s,
+                            lambda: graph.call_function(
+                                torch.ops.aten.sym_size.int, (node, i)
+                            ),
+                        )
+                    if not is_sparse_any(t):
+                        for i, s in enumerate(t.stride()):
+                            match_symbol(
+                                s,
+                                lambda: graph.call_function(
+                                    torch.ops.aten.sym_stride.int, (node, i)
+                                ),
+                            )
+                        match_symbol(
+                            t.storage_offset(),
+                            lambda: graph.call_function(
+                                torch.ops.aten.sym_storage_offset.default, (node,)
+                            ),
+                        )
+
+            # Handle asserts that aren't associated with any symbol.  This
+            # doesn't really have to be in the loop as it will only run once,
+            # it just needs to happen right after the placeholders.
+            # insert this after placeholders & added sym nodes, and before non-placeholders.
+            if node == first_non_placeholder:
+                add_runtime_asserts(ras_by_symbol.pop(None, []))  # type: ignore[call-overload]
+
+            # deduplicate asserts already present in graph, and remove trivial asserts
+            if node.target in (
+                torch._check,
+                torch.ops.aten._assert_scalar.default,
+            ):
+                cond = node.args[0] if node.args else node.kwargs.get("cond")
+                if (
+                    cond == True  # noqa: E712
+                    or (assert_expr := _get_sym_val(cond)) in expr_to_proxy
+                    and assert_expr in added_asserts
+                ):
+                    arg = cond
+                    gm.graph.erase_node(node)
+                    if isinstance(arg, fx.Node) and not arg.users:
+                        gm.graph.erase_node(arg)
+                else:
+                    added_asserts.add(assert_expr)  # type: ignore[arg-type]
+
+            # hash cons, replace function calls that return torch.SymInts with direct references to
+            # FX nodes built up to reify the sympy expression.
+            if (
+                node.op != "placeholder"
+                and (sym_expr := _get_sym_val(node)) is not None
+            ):
+                # this guards against deleting calls like item() that produce new untracked symbols
+                def has_new_untracked_symbols():
+                    # pyrefly: ignore [missing-attribute]
+                    for symbol in sym_expr.free_symbols:
+                        if symbol not in expr_to_proxy:
+                            return True
+                    return False
+
+                # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
+                # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
+                # (is backed), but produces an unbacked symbol. In this case keep the node alive.
+                resolved_unbacked_bindings = resolve_unbacked_bindings(
+                    shape_env, node.meta.get("unbacked_bindings", {})
+                )
+
+                def has_new_unbacked_bindings():
+                    assert resolved_unbacked_bindings is not None
+                    for key in resolved_unbacked_bindings:
+                        if key not in expr_to_proxy:
+                            return True
+                    return False
+
+                # maybe re-reify expression, replace current node
+                if (
+                    sym_expr in expr_to_proxy
+                    or (  # example value is redundant
+                        _is_intermediate_tensor_sym_call(node)
+                        # shape call on intermediate tensor, turn into computation on input shapes
+                        and not has_new_untracked_symbols()
+                    )
+                ) and not has_new_unbacked_bindings():
+                    if _is_intermediate_tensor_sym_call(
+                        node
+                    ):  # reify from input shapes
+                        with _set_node_metadata_hook(
+                            gm,
+                            functools.partial(
+                                _node_metadata_hook,
+                                stack_trace=node.meta.get("stack_trace"),
+                                nn_module_stack=node.meta.get("nn_module_stack"),
+                            ),
+                        ):
+                            expr_to_proxy[sym_expr] = _sympy_interp(
+                                expr_to_proxy,
+                                sym_expr,
+                            )  # type: ignore[arg-type]
+                        # won't try DCE-ing tensor compute here
+                    hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
+                    node.replace_all_uses_with(hash_node)
+                    gm.graph.erase_node(node)
+                    log.debug(
+                        "CSE node %s -> %s for expr %s",
+                        node,
+                        hash_node,
+                        sym_expr,
+                    )
+
+                # store node in hash cons, don't delete/replace
+
+                elif sym_expr not in expr_to_proxy and not isinstance(
+                    sym_expr,
+                    (sympy.Number, sympy.logic.boolalg.BooleanAtom),
+                ):  # don't hash cons primitives
+                    expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer)  # type: ignore[arg-type]
+
+            # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
+            # so calls before that are redundant.
+            if node.target in (
+                torch.ops.aten.sym_constrain_range.default,
+                torch.ops.aten.sym_constrain_range_for_size.default,
+            ):
+                gm.graph.erase_node(node)
+
+            defs = []
+
+            # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
+            # equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
+            # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
+            # information about the old symbol when we re-export, raising errors on data-dependent guards.
+            # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
+            if unbacked_bindings := resolve_unbacked_bindings(
+                shape_env, node.meta.get("unbacked_bindings")
+            ):
+                for s, keypath in unbacked_bindings.items():
+                    defs.append(s)
+
+                    # TODO: some CSE when generating these nodes can probably
+                    # help reduce graph size and improve compile time
+                    def go(node, keypath):
+                        if keypath == ():
+                            return node
+                        if (
+                            len(keypath) >= 2
+                            and isinstance(keypath[0], CallMethodKey)
+                            and isinstance(keypath[1], pytree.SequenceKey)
+                        ):
+                            if keypath[0].name == "size":
+                                return go(
+                                    graph.call_function(
+                                        torch.ops.aten.sym_size.int,
+                                        (node, keypath[1].idx),
+                                    ),
+                                    keypath[2:],
+                                )
+                            if keypath[0].name == "stride":
+                                return go(
+                                    graph.call_function(
+                                        torch.ops.aten.sym_stride.int,
+                                        (node, keypath[1].idx),
+                                    ),
+                                    keypath[2:],
+                                )
+
+                            return go(
+                                graph.call_method(
+                                    keypath[0].name, (node, keypath[1].idx)
+                                ),
+                                keypath[2:],
+                            )
+                        elif isinstance(keypath[0], CallMethodKey):
+                            if keypath[0].name == "storage_offset":
+                                return go(
+                                    graph.call_function(
+                                        torch.ops.aten.sym_storage_offset.default,
+                                        (node,),
+                                    ),
+                                    keypath[1:],
+                                )
+
+                            return go(
+                                graph.call_method(keypath[0].name, (node,)), keypath[1:]
+                            )
+                        elif isinstance(keypath[0], pytree.SequenceKey):
+                            return go(
+                                graph.call_function(
+                                    operator.getitem, (node, keypath[0].idx)
+                                ),
+                                keypath[1:],
+                            )
+                        elif isinstance(keypath[0], ConvertIntKey):
+                            return go(
+                                graph.call_function(
+                                    cast_symbool_to_symint_guardless, (node,)
+                                ),
+                                keypath[1:],
+                            )
+                        elif isinstance(keypath[0], DivideByKey):
+                            # TODO: need to assert divisibility
+                            return go(
+                                graph.call_function(
+                                    operator.floordiv, (node, keypath[0].divisor)
+                                ),
+                                keypath[1:],
+                            )
+                        elif isinstance(keypath[0], InnerTensorKey):
+                            return go(
+                                graph.call_function(
+                                    getattr, (node, keypath[0].inner_name)
+                                ),
+                                keypath[1:],
+                            )
+                        else:
+                            raise AssertionError(f"unrecognized keypath {keypath}")
+
+                    if s not in expr_to_proxy:
+                        with _set_node_metadata_hook(gm, _node_metadata_hook):
+                            expr_to_proxy[s] = fx.Proxy(
+                                go(node, keypath), tracer=tracer
+                            )
+                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
+
+            for i0 in defs:
+                ras = ras_by_symbol.pop(i0, [])
+                # Before we perform any asserts, first apply range
+                # refinement.  This is important, because if we are going
+                # to retrace the graph (and we typically are if we send
+                # the graph to AOTAutograd), we need to make sure we apply
+                # range refinement (ala _check_is_size) first, BEFORE we
+                # run any of the asserts.  Otherwise, we may decide to
+                # perform substitutions based on the asserts which we then
+                # can't back out, because value ranges can only be applied
+                # to asserts.)
+                #
+                # A perhaps better long term plan is to avoid this order
+                # dependence by making it possible to refine ranges on
+                # arbitrary expressions, not just symbols.  But it is not
+                # so easy to make use of this information, see
+                # https://twitter.com/ezyang/status/1745801370299482492
+                # We actually made an attempt at this in
+                # https://github.com/pytorch/pytorch/pull/119043
+                # which didn't work.
+                #
+                # Another ideas for how to do this:
+                # - Have bound_sympy be the source of truth of the ranges of any expression
+                # - Cache intermediate results for every subexpression of bound_sympy
+                # - This cache should be possible to edit to refine ranges
+                #
+                # One issue with this proposal is that if
+                # we have a bound on 2x, we are not going to be able to
+                # apply it for 4x.  Similarly, we may have bounds for an
+                # equivalent expression that we are not applying because
+                # it's not a perfect match (e.g. x < y vs y > x)".
+                #
+                # The first issue we already have it and it's impossible
+                # to solve in general, so any implementation on a best
+                # effort basis should do.
+                #
+                # The second issue is a preexisting one. It can be mitigated
+                # with a normalization algorithm. In general, it may also
+                # be on a best effort basis, but since our grammar is not
+                # terribly difficult, chances are we could even fully
+                # normalize SymPy expressions... who knows.
+                if i0 in constrained_unbacked_symbols:
+                    continue  # constrain symbol just once
+
+                if i0 in shape_env.size_like:
+                    if export:
+                        graph.call_function(
+                            torch.ops.aten.sym_constrain_range_for_size.default,
+                            (expr_to_proxy[i0].node,),
+                        )
+                    else:
+                        graph.call_function(
+                            torch._check_is_size, (expr_to_proxy[i0].node,)
+                        )
+
+                vr = shape_env.var_to_range[i0]
+                if vr.is_int and vr.upper == sys.maxsize - 1:
+                    # treat upper bound == sys.maxsize - 1 for int symbols as +oo
+                    # to avoid redundant runtime assert
+                    vr = ValueRanges(vr.lower, int_oo)
+                if not shape_env._default_unspecified_value_range().issubset(vr):
+                    # The runtime range is constrained, so add a runtime
+                    # assert and also explicitly refine the range
+                    # (refinement should not be necessary once runtime
+                    # asserts cause refinement, but that's NYI)
+                    def convert(s):
+                        if s in (int_oo, -int_oo):
+                            return None
+                        try:
+                            return int(s)
+                        except TypeError:
+                            return None
+
+                    if (
+                        expr_to_proxy[i0].node.target
+                        is not cast_symbool_to_symint_guardless
+                    ):
+                        # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
+                        # raises AOTAutograd errors on cast_symbool_to_symint_guardless
+
+                        with _set_node_metadata_hook(
+                            gm,
+                            functools.partial(
+                                _node_metadata_hook,
+                                stack_trace=node.meta.get("stack_trace"),
+                                nn_module_stack=node.meta.get("nn_module_stack"),
+                                # nodes added in `apply_runtime_assertion_pass` will have the same annotation
+                                # as the input node to the assertion
+                                custom=node.meta.get("custom"),
+                            ),
+                        ):
+                            if (min_val := convert(vr.lower)) is not None:
+                                ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
+                                graph.call_function(
+                                    torch.ops.aten._assert_scalar.default,
+                                    (
+                                        ge,
+                                        f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
+                                    ),
+                                )
+                                added_asserts.add(i0 >= min_val)
+                            if (max_val := convert(vr.upper)) is not None:
+                                le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
+                                graph.call_function(
+                                    torch.ops.aten._assert_scalar.default,
+                                    (
+                                        le,
+                                        f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
+                                    ),
+                                )
+                                added_asserts.add(i0 <= max_val)
+
+                constrained_unbacked_symbols.add(i0)
+                add_runtime_asserts(ras)
+
+    # delete unused reified symbols
+    for expr, proxy in expr_to_proxy.items():
+        if (
+            isinstance(expr, sympy.Symbol)
+            and proxy.node.op != "placeholder"  # keep placeholders intact
+            and not proxy.node.users
+        ):
+            log.debug("deleting unused reified symbol for %s", expr)
+            gm.graph.erase_node(proxy.node)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ea218356138de640c7fb7a74fb2efbcb4b21e5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py
@@ -0,0 +1,230 @@
+# mypy: ignore-errors
+
+import traceback
+from typing import Any, NamedTuple, Optional
+
+import torch
+import torch.fx
+from torch._dispatch.python import enable_python_dispatcher
+from torch._guards import detect_fake_mode
+from torch._prims_common import is_contiguous_for_memory_format_or_false
+from torch._subclasses.meta_utils import is_sparse_any
+from torch.fx._compatibility import compatibility
+from torch.fx.node import map_aggregate, Node
+
+
+__all__ = ["TensorMetadata", "ShapeProp"]
+
+
+@compatibility(is_backward_compatible=True)
+class TensorMetadata(NamedTuple):
+    # TensorMetadata is a structure containing pertinent information
+    # about a tensor within a PyTorch program.
+
+    # General Tensor metadata
+    shape: torch.Size
+    dtype: torch.dtype
+    requires_grad: bool
+    stride: tuple[int, ...]
+    memory_format: Optional[torch.memory_format]
+
+    # Quantization metadata
+    is_quantized: bool
+    qparams: dict[str, Any]
+
+
+# When include_contiguity is True, we will set contiguity when its always true for the tensor.
+# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
+# In such situation contiguity is not set. We could also make it a tri-state i.e: (def_contiguous,
+# def_not_contiguous and unknown).
+def _extract_tensor_metadata(
+    result: torch.Tensor, include_contiguity=True
+) -> TensorMetadata:
+    """
+    Extract a TensorMetadata NamedTuple describing `result`.
+    """
+    shape = result.shape
+    dtype = result.dtype
+    requires_grad = result.requires_grad
+    stride = result.stride() if not is_sparse_any(result) else ()
+
+    memory_format = None
+
+    if include_contiguity and not is_sparse_any(result):
+        memory_formats = (
+            torch.contiguous_format,
+            torch.channels_last,
+            torch.channels_last_3d,
+        )
+        for query_format in memory_formats:
+            if is_contiguous_for_memory_format_or_false(
+                result, memory_format=query_format
+            ):
+                memory_format = query_format
+                break
+
+    is_quantized = result.is_quantized
+    qparams: dict[str, Any] = {}
+    if is_quantized:
+        qscheme = result.qscheme()
+        qparams["qscheme"] = qscheme
+        if qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric):
+            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
+            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
+        elif qscheme in (
+            torch.per_channel_affine,
+            torch.per_channel_affine_float_qparams,
+            torch.per_channel_symmetric,
+        ):
+            # In this branch, scale and zero_point are expected to be tensors,
+            # we store the values as immutable_list in TensorMetadata for
+            # easier serialization downstream
+            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
+            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
+            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]
+
+    return TensorMetadata(
+        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams
+    )
+
+
+@compatibility(is_backward_compatible=True)
+class ShapeProp(torch.fx.Interpreter):
+    """
+    Execute an FX graph Node-by-Node and
+    record the shape and type of the result
+    into the corresponding node.
+
+    Example:
+         In this example, we record the shape
+         and data type of a module given
+         an example input ``torch.randn(50, D_in)``.
+         We print the name, shape and dtype of each node.
+
+        class TwoLayerNet(torch.nn.Module):
+            def __init__(self, D_in, H, D_out):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(D_in, H)
+                self.linear2 = torch.nn.Linear(H, D_out)
+            def forward(self, x):
+                h_relu = self.linear1(x).clamp(min=0)
+                y_pred = self.linear2(h_relu)
+                return y_pred
+        N, D_in, H, D_out = 64, 1000, 100, 10
+        x = torch.randn(N, D_in)
+        y = torch.randn(N, D_out)
+        model = TwoLayerNet(D_in, H, D_out)
+        gm = torch.fx.symbolic_trace(model)
+        sample_input = torch.randn(50, D_in)
+        ShapeProp(gm).propagate(sample_input)
+
+        for node in gm.graph.nodes:
+            print(node.name, node.meta['tensor_meta'].dtype,
+                node.meta['tensor_meta'].shape)
+
+        The output of this code is:
+
+        x torch.float32 torch.Size([50, 1000])
+        linear1 torch.float32 torch.Size([50, 100])
+        clamp_1 torch.float32 torch.Size([50, 100])
+        linear2 torch.float32 torch.Size([50, 10])
+        output torch.float32 torch.Size([50, 10])
+
+    Args:
+         module (GraphModule): The module to be executed
+         fake_mode (FakeTensorMode): A fake mode for copying the gm
+
+    """
+
+    def __init__(self, gm, fake_mode=None):
+        super().__init__(gm)
+        if fake_mode is None:
+            fake_mode = detect_fake_mode()
+        if fake_mode is not None:
+            from torch._dynamo.utils import deepcopy_to_fake_tensor
+
+            # Note:
+            # We need fake execution cause the inputs are fake, however, we cannot fakify the module
+            # - because we need to write to the tensor_meta of the real module. So we fakify to
+            # produce a result (L131 below), to extract tensor meta, and then keep going.
+            #
+            # If we were to fakify, we would write to the wrong node, and then downstream fusion
+            # would be missing the tensor_meta.
+            #
+            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
+            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
+            self.fake_mode = fake_mode
+        else:
+            self.fake_module = None
+            self.fake_mode = None
+
+        self.real_module = self.module
+
+    def run_node(self, n: Node) -> Any:
+        from torch.fx.experimental.symbolic_shapes import (
+            compute_unbacked_bindings,
+            rebind_unbacked,
+        )
+
+        try:
+            if self.fake_module is not None:
+                # Hacky swap. Alternatively, we could do this with overriding
+                # call_module and get_attr.
+                self.module = self.fake_module
+            try:
+                if self.fake_mode is not None:
+                    with self.fake_mode, enable_python_dispatcher():
+                        result = super().run_node(n)
+                        rebind_unbacked(self.fake_mode.shape_env, n, result)
+                else:
+                    result = super().run_node(n)
+            finally:
+                self.module = self.real_module
+        except Exception as e:
+            traceback.print_exc()
+            raise RuntimeError(
+                f"ShapeProp error for: node={n.format_node()} with meta={n.meta}"
+            ) from e
+
+        found_tensor = False
+
+        def extract_tensor_meta(obj):
+            if isinstance(obj, torch.Tensor):
+                nonlocal found_tensor
+                found_tensor = True
+                return _extract_tensor_metadata(obj)
+            else:
+                return obj
+
+        meta = map_aggregate(result, extract_tensor_meta)
+        if found_tensor:
+            n.meta["tensor_meta"] = meta
+
+        if self.fake_mode:
+            if (shape_env := self.fake_mode.shape_env) and (
+                symbol_to_path := compute_unbacked_bindings(shape_env, result)
+            ):
+                n.meta["unbacked_bindings"] = symbol_to_path
+
+        n.meta["type"] = type(result)
+        return result
+
+    def propagate(self, *args):
+        """
+        Run `module` via interpretation and return the result and
+        record the shape and type of each node.
+
+        Args:
+            *args (Tensor): the sample input.
+
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        if self.fake_mode is not None:
+            fake_args = [
+                self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+                for t in args
+            ]
+        else:
+            fake_args = args
+        return super().run(*fake_args)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/split_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/split_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b244750f33dc5f5a7b233afc70f4b1e1f26cd8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/split_module.py
@@ -0,0 +1,656 @@
+# mypy: allow-untyped-defs
+import inspect
+import logging
+from collections import OrderedDict
+from collections.abc import Callable
+from typing import Any, Optional
+
+import torch
+from torch.fx._compatibility import compatibility
+from torch.fx._utils import lazy_format_graph_code
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node
+
+
+__all__ = ["Partition", "split_module"]
+log = _LOGGER = logging.getLogger(__name__)
+
+
+@compatibility(is_backward_compatible=True)
+class Partition:
+    def __init__(self, name: str):
+        self.name: str = name
+        self.submod_name = f"submod_{name}"
+        self.node_names: list[str] = []
+        self.inputs: dict[str, None] = {}
+        self.outputs: dict[str, None] = {}
+        self.dependencies: dict[str, None] = {}
+        self.dependents: dict[str, None] = {}
+        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+        self.environment: dict[Node, Node] = {}
+        self.targets: dict[str, Any] = {}
+
+    def __repr__(self) -> str:
+        return (
+            f"name: {self.name},\n"
+            f" nodes: {self.node_names},\n"
+            f" inputs: {self.inputs},\n"
+            f" outputs: {self.outputs},\n"
+            f" partitions depended on: {self.dependencies},\n"
+            f" partition dependents: {self.dependents}"
+        )
+
+
+def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any:
+    attr_val = mod
+    for atom in qualname.split("."):  # type: ignore[union-attr]
+        if not hasattr(attr_val, atom):
+            raise AttributeError(f"Node target {qualname} not found!")
+        attr_val = getattr(attr_val, atom)
+    return attr_val
+
+
+# Creates subgraphs out of main graph
+@compatibility(is_backward_compatible=True)
+def split_module(
+    m: GraphModule,
+    root_m: torch.nn.Module,
+    split_callback: Callable[[Node], int],
+    qualname_map: Optional[dict[str, str]] = None,
+    keep_original_order: Optional[bool] = False,
+    keep_original_node_name: Optional[bool] = False,
+    keep_original_input_name: bool = True,
+    *,
+    partition_affix: Optional[str] = None,
+):
+    """
+    Creates subgraphs out of main graph
+
+    Args:
+        m (GraphModule): Graph module to split
+        root_m (torch.nn.Module): root nn module. Not currently used. Included
+            because the root nn module is usually transformed via
+            torch.fx._symbolic_trace.symbolic_trace (see example below)
+        split_callback (Callable[[Node], int]): Callable function
+            that maps a given Node instance to a numeric partition identifier.
+            split_module will use this function as the policy for which operations
+            appear in which partitions in the output Module.
+        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
+            mapping from new target names in the module after split to old target
+            names in the original module.
+        keep_original_order: Optional[bool]: keep the original order of the GraphModule
+            or use the Topological order of the new constructed GraphModule
+        keep_original_node_name: Optional[bool]: If the partitioned graphs should
+            have the same node names as the original graph.
+        keep_original_input_name: bool: If the partitioned graphs should
+            have the same input names as the original graph.
+        partition_affix: Optional[str]: If specified, the submodules' names will contain
+            the affix, e.g. "submod__".
+
+    Returns:
+        GraphModule: the module after split.
+
+    Example:
+
+        This is a sample setup:
+
+            import torch
+            from torch.fx._symbolic_trace import symbolic_trace
+            from torch.fx.graph_module import GraphModule
+            from torch.fx.node import Node
+            from torch.fx.passes.split_module import split_module
+
+            class MyModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.param = torch.nn.Parameter(torch.rand(3, 4))
+                    self.linear = torch.nn.Linear(4, 5)
+
+                def forward(self, x, y):
+                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
+                    w = self.linear(y).clamp(min=0.0, max=1.0)
+                    return z + w
+
+            # symbolically trace model
+            my_module = MyModule()
+            my_module_traced = symbolic_trace(my_module)
+
+            # random mod partitioning
+            partition_counter = 0
+            NPARTITIONS = 3
+
+            def mod_partition(node: Node):
+                global partition_counter
+                partition = partition_counter % NPARTITIONS
+                partition_counter = (partition_counter + 1) % NPARTITIONS
+                return partition
+
+            # split module in module with submodules
+            module_with_submodules = split_module(
+                my_module_traced, my_module, mod_partition
+            )
+
+        Output looks like this. Original graph is broken into partitions
+
+            > print(module_with_submodules)
+            GraphModule(
+                (submod_0): GraphModule(
+                    (linear): Linear(in_features=4, out_features=5, bias=True)
+                )
+                (submod_1): GraphModule(
+                    (linear): Linear(in_features=4, out_features=5, bias=True)
+                )
+                (submod_2): GraphModule()
+            )
+
+            def forward(self, x, y):
+                param = self.param
+                submod_0 = self.submod_0(x, param, y);  x = param = y = None
+                getitem = submod_0[0]
+                getitem_1 = submod_0[1];  submod_0 = None
+                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
+                getitem_2 = submod_1[0]
+                getitem_3 = submod_1[1];  submod_1 = None
+                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
+                return submod_2
+
+        Output of split module is the same as output of input traced module.
+        This is an example within a test setting:
+
+            > orig_out = my_module_traced(x, y)
+            > submodules_out = module_with_submodules(x, y)
+            > self.assertEqual(orig_out, submodules_out)
+            True
+    """
+
+    log.debug(
+        "%s",
+        lazy_format_graph_code("pre split_module", m, colored=True),
+    )
+
+    def construct_graph(
+        node: Node,
+        base_mod_env: dict[str, Node],
+        base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
+    ):
+        if node.op == "placeholder":
+            default_value = (
+                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+            )
+            if keep_original_node_name:
+                args = (
+                    () if default_value is inspect.Signature.empty else (default_value,)
+                )
+                base_mod_env[node.name] = base_mod_graph.create_node(
+                    "placeholder",
+                    node.name,
+                    args=args,  # type: ignore[arg-type]
+                    type_expr=node.type,
+                )
+            else:
+                base_mod_env[node.name] = base_mod_graph.placeholder(
+                    node.target,  # type: ignore[arg-type]
+                    type_expr=node.type,
+                    default_value=default_value,
+                )
+            base_mod_env[node.name].meta = node.meta.copy()
+        elif node.op == "get_attr":
+            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)  # type: ignore[arg-type]
+            base_mod_env[node.name].meta = node.meta.copy()
+            assert isinstance(node.target, str)
+            attr_val = _get_attr_from_qualname(m, node.target)
+            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
+        return base_mod_env, base_mod_attrs
+
+    import sympy
+
+    partitions: dict[str, Partition] = {}
+    orig_nodes: dict[str, Node] = {}
+    symbol_to_node: dict[sympy.Symbol, Node] = {}
+
+    def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
+        from torch.fx.experimental.symbolic_shapes import free_symbols
+
+        defined = getattr(def_node, "_fx_partition", None)
+        used = getattr(use_node, "_fx_partition", None)
+
+        log.debug(
+            "record_cross_partition_use %s (%s) %s (%s)",
+            def_node.name,
+            defined,
+            use_node.name if use_node is not None else "-",
+            used,
+        )
+
+        if defined != used:
+            if defined is not None:
+                def_partition = partitions[defined]
+                def_partition.outputs.setdefault(def_node.name)
+                if used is not None:
+                    def_partition.dependents.setdefault(used)
+
+            if used is not None:
+                use_partition = partitions[used]
+                use_partition.inputs.setdefault(def_node.name)
+                # We have made def_node an input to the use_partition.  If
+                # this input has symbolic symbols in its size, those also must
+                # be made as inputs to the partition
+                if (def_val := def_node.meta.get("example_value")) is not None:
+                    for s in sorted(free_symbols(def_val), key=str):
+                        s_node = symbol_to_node[s]
+                        use_partition.inputs.setdefault(s_node.name)
+                        if symbol_to_node[s].op != "placeholder":
+                            # If the node that defines the symbol is not a
+                            # placeholder, we must make it an output of the
+                            # partition.  Note that this may be in a different
+                            # partition than defined!  Although, this doesn't
+                            # really make a difference for correctness, since
+                            # defined is guaranteed to have the symbol in
+                            # scope and can return it; you just get less
+                            # optimal codegen in this case.
+                            s_defined = getattr(s_node, "_fx_partition", None)
+                            if s_defined is not None:
+                                s_def_partition = partitions[s_defined]
+                                s_def_partition.outputs.setdefault(s_node.name)
+                                s_def_partition.dependents.setdefault(used)
+                                use_partition.dependencies.setdefault(s_defined)
+                if defined is not None:
+                    use_partition.dependencies.setdefault(defined)
+
+    def instantiate_node_partition_mapping(node):
+        partition_idx = split_callback(node)
+        partition_name = str(partition_idx)
+        if partition_affix is not None:
+            # For example, if user specifies partition_affix = "pp", then the
+            # partition name will be "pp_0", "pp_1", etc
+            partition_name = "_".join([partition_affix, partition_name])
+
+        log.debug(
+            "instantiate_node_partition_mapping %s (%s)", node.name, partition_name
+        )
+
+        # add node to partitions
+        partition = partitions.get(partition_name)
+        if partition is None:
+            partitions[partition_name] = partition = Partition(partition_name)
+
+        partition.node_names.append(node.name)
+        node._fx_partition = partition_name
+
+    # Global State Nodes are nodes which by their global state effects,
+    # "taint" all downstream nodes while they are active.
+    GLOBAL_STATE_NODES = [
+        torch.amp._enter_autocast,
+        torch.amp._exit_autocast,
+        torch._C._set_grad_enabled,
+    ]
+
+    # For grad regions:
+    # ------------------------
+    # 1. first region: we do nothing
+    # 2. subsequent regions: we insert the set_grad at the beginning
+    grad_regions: OrderedDict[Node, set[int]] = OrderedDict()
+
+    # For autocast regions:
+    # ------------------------
+    # 1. first region: we will only insert the _exit at the end
+    # 2. intermediate regions: we will insert both the
+    #    _enter at the beginning and _exit at the end
+    # 3. last region: we will only insert _enter at the beginning
+    # We will do so in the order in which the autocasts were instantiated.
+    autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
+    autocast_exits: dict[Node, Optional[Node]] = {}
+
+    active_grad = None
+    active_autocasts = set()
+
+    for node in m.graph.nodes:
+        # This will prefer placeholder bindings, because those come first.
+        # This is a little dangerous though: it is possible that an unbacked
+        # symbol is used without any binding site for it, in which case we
+        # will get a KeyError not able to find it.  I'd like to fix this by
+        # having passes.runtime_assert establish some invariants that I can
+        # rely on later, but this needs some extra work.  Quick fix first.
+        # See https://github.com/pytorch/pytorch/issues/130534
+        if (
+            (val := node.meta.get("example_value")) is not None
+            and isinstance(val, (torch.SymInt, torch.SymFloat))
+            and isinstance(s0 := val.node.expr, sympy.Symbol)
+            and s0 not in symbol_to_node
+        ):
+            symbol_to_node[val.node.expr] = node
+
+        if node.op in ["placeholder", "get_attr", "output"]:
+            continue
+
+        instantiate_node_partition_mapping(node)
+
+        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
+            if node.target is torch._C._set_grad_enabled:
+                assert len(node.args) == 1
+                assert isinstance(node.args[0], bool)
+                active_grad = node
+                grad_regions[active_grad] = set({split_callback(node)})
+            elif node.target is torch.amp._enter_autocast:
+                # Should all be python constants
+                assert all(not isinstance(arg, Node) for arg in node.args)
+                active_autocasts.add(node)
+                autocast_regions[node] = set({split_callback(node)})
+                autocast_exits[node] = None
+            elif node.target is torch.amp._exit_autocast:
+                assert len(node.args) == 1
+                autocast_regions[node.args[0]].add(split_callback(node))
+                active_autocasts.remove(node.args[0])
+                autocast_exits[node.args[0]] = node
+
+        if active_grad is not None:
+            grad_regions[active_grad].add(split_callback(node))
+
+        for a in active_autocasts:
+            autocast_regions[a].add(split_callback(node))
+
+    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
+
+    # pyrefly: ignore [bad-assignment]
+    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
+    # pyrefly: ignore [bad-assignment]
+    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
+
+    if _LOGGER.isEnabledFor(logging.DEBUG):
+        _LOGGER.debug("autocast_regions: %s", autocast_regions)
+        _LOGGER.debug("grad_regions: %s", grad_regions)
+
+    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
+
+    # split nodes into partitions
+    highest_partition = -1
+    for node in m.graph.nodes:
+        orig_nodes[node.name] = node
+
+        # TODO currently placeholders/parameters aren't put into random partitions,
+        # rather they're added to the graphs where they are used down below
+        if node.op in ["placeholder", "get_attr"]:
+            continue
+        if node.op == "output":
+            torch.fx.graph.map_arg(
+                node.args[0], lambda n: record_cross_partition_use(n, None)
+            )
+            continue
+
+        if assert_monotonically_increasing:
+            pid = split_callback(node)
+            assert highest_partition <= pid, (
+                "autocast or set_grad_enabled require monotonically increasing partitions:"
+                f"highest: {highest_partition}, this node's: {pid}"
+            )
+            highest_partition = pid
+
+        # do not capture cross-partition dependencies for global state nodes as they will be
+        # self-contained - their setup and unwind will be isolated to each partition submodule.
+        if node.target not in GLOBAL_STATE_NODES:
+            torch.fx.graph.map_arg(
+                node.args, lambda def_node: record_cross_partition_use(def_node, node)
+            )
+            torch.fx.graph.map_arg(
+                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
+            )  # noqa: B950
+
+    original_partition_order = list(partitions.keys())
+    # find partitions with no dependencies
+    root_partitions: list[str] = []
+    for partition_name, partition in partitions.items():
+        if not len(partition.dependencies):
+            root_partitions.append(partition_name)
+
+    # check partitions for circular dependencies and create topological partition ordering
+    sorted_partitions: list[str] = []
+    while root_partitions:
+        root_partition = root_partitions.pop()
+        sorted_partitions.append(root_partition)
+        for dependent in partitions[root_partition].dependents:
+            partitions[dependent].dependencies.pop(root_partition)  # noqa: B909
+            if not partitions[dependent].dependencies:
+                root_partitions.append(dependent)
+    if len(sorted_partitions) != len(partitions):
+        raise RuntimeError("cycle exists between partitions!")
+
+    # Enter prelude
+    for regions_mapping in [autocast_regions, grad_regions]:
+        for node, regions in regions_mapping.items():
+            assert len(regions) > 0
+            # pyrefly: ignore [index-error]
+            partitions[str(regions[0])].environment[node] = node
+            # pyrefly: ignore [index-error]
+            for r in regions[1:]:
+                partition = partitions[str(r)]
+                new_node = partition.graph.create_node(
+                    op=node.op,
+                    target=node.target,
+                    args=tuple(arg for arg in node.args),
+                    kwargs={},
+                    type_expr=node.type,
+                )
+                new_node.meta = (
+                    node.meta.copy()
+                )  # is it really a good idea to copy this?
+                partition.environment[node] = new_node
+
+    # add placeholders to partition inputs
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+        new_inputs: dict[str, None] = {}
+
+        counter = 0
+
+        for inp in partition.inputs:
+            orig_node = orig_nodes[inp]
+            # We don't pass in get_attr nodes as inputs to the partition, but
+            # instead set them as targets and use getattr within the module
+
+            def add_placeholder():
+                if keep_original_input_name:
+                    name = inp
+                else:
+                    nonlocal counter
+                    name = f"arg_{counter}"
+                    counter += 1
+                placeholder = partition.graph.placeholder(
+                    name,
+                    type_expr=orig_nodes[inp].type,
+                )
+                new_inputs[inp] = None
+                return placeholder
+
+            if orig_node.op == "get_attr":
+                assert isinstance(orig_node.target, str)
+
+                orig_attr = _get_attr_from_qualname(m, orig_node.target)
+                if isinstance(orig_attr, torch.nn.Module):
+                    placeholder = partition.graph.get_attr(orig_node.target)
+                    partition.targets[orig_node.target] = orig_attr
+                else:
+                    placeholder = add_placeholder()
+            else:
+                placeholder = add_placeholder()
+            placeholder.meta = orig_nodes[inp].meta.copy()
+            partition.environment[orig_nodes[inp]] = placeholder
+        partition.inputs = new_inputs
+
+    # Transform nodes and collect targets for partition's submodule
+    for node in m.graph.nodes:
+        if hasattr(node, "_fx_partition"):
+            partition = partitions[node._fx_partition]
+
+            # swap out old graph nodes in kw/args with references to new nodes in this submodule
+            environment = partition.environment
+            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
+            gathered_kwargs = torch.fx.graph.map_arg(
+                node.kwargs, lambda n: environment[n]
+            )
+
+            if node.op not in ["call_module", "get_attr"]:
+                target = node.target
+            else:
+                target_attr = _get_attr_from_qualname(m, node.target)
+                target = node.target.replace(".", "_")
+                partition.targets[target] = target_attr
+                # Fill in the passed-in mapping from new qualname to old qualname
+                if qualname_map is not None:
+                    # When creating the split module later, the submodules will have
+                    # path prefix matching the corresponding partition's submod_name
+                    qualname = f"{partition.submod_name}.{target}"
+                    qualname_map[qualname] = node.target
+
+            assert isinstance(gathered_args, tuple)
+            assert isinstance(gathered_kwargs, dict)
+            name = node.name if keep_original_node_name else None
+            new_node = partition.graph.create_node(
+                op=node.op,
+                target=target,
+                args=gathered_args,
+                kwargs=gathered_kwargs,
+                type_expr=node.type,
+                name=name,
+            )
+            new_node.meta = node.meta.copy()
+            partition.environment[node] = new_node
+
+    # Exit epilogue
+    for regions_mapping in [autocast_regions]:
+        for node in reversed(regions_mapping):
+            regions = regions_mapping[node]
+            assert len(regions) > 0
+            # pyrefly: ignore [index-error]
+            for r in regions[:-1]:
+                partition = partitions[str(r)]
+                exit_node = autocast_exits[node]
+                assert exit_node is not None, "Missing exit node"
+                new_node = partition.graph.create_node(
+                    op=exit_node.op,
+                    target=exit_node.target,
+                    args=(partition.environment[node],),
+                    kwargs={},
+                    type_expr=exit_node.type,
+                )
+                new_node.meta = (
+                    exit_node.meta.copy()
+                )  # is it really a good idea to copy this?
+
+    # original module environment dict mapping node names to nodes
+    orig_mod_env: dict[str, Node] = {}
+    # Set up values to construct base module
+    base_mod_env: dict[str, Node] = {}
+    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+    base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
+    if not keep_original_order:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
+    else:
+        # Go through the graph to construct the mapping dict
+        for node in m.graph.nodes:
+            orig_mod_env[node.name] = node
+
+    # Do some things iterating over the partitions in topological order again:
+    # 1) Finish off submodule Graphs by setting corresponding outputs
+    # 2) Construct GraphModules for each submodule
+    # 3) Construct the base graph by emitting calls to those submodules in
+    #    topological order or original order specified by keep_original_order
+
+    construct_order_partitions = (
+        sorted_partitions if not keep_original_order else original_partition_order
+    )
+
+    already_constructed_attr_nodes = set()
+
+    # We actually need to insert the placeholder nodes in the original order
+    # otherwise graph signature will be wrong.
+    original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
+
+    for partition_name in construct_order_partitions:
+        partition = partitions[partition_name]
+
+        # Set correct output values
+        output_vals = tuple(
+            partition.environment[orig_nodes[name]] for name in partition.outputs
+        )
+
+        # skip output node generation if there are no output values
+        num_output_vals = len(output_vals)
+        if num_output_vals == 1:
+            partition.graph.output(output_vals[0])
+        elif num_output_vals > 1:
+            partition.graph.output(output_vals)
+        else:
+            # Invariant - Graph should always have an output node.
+            partition.graph.output(())
+
+        if keep_original_order:
+            # first get the attr nodes required by this partition
+            orig_mod_attr_nodes: list[Node] = [
+                orig_mod_env[key]
+                for key in partition.inputs
+                if key not in original_order
+            ]
+
+            for node in original_order:
+                if node in already_constructed_attr_nodes:
+                    continue  # already added this attr to the base graph
+                base_mod_env, _based_mod_attrs = construct_graph(
+                    node, base_mod_env, base_mod_attrs
+                )
+                already_constructed_attr_nodes.add(node)
+
+            # Construct GraphModule for this partition
+            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
+                if node in already_constructed_attr_nodes:
+                    continue
+                base_mod_env, base_mod_attrs = construct_graph(
+                    node, base_mod_env, base_mod_attrs
+                )
+                already_constructed_attr_nodes.add(node)
+
+        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
+            partition.targets, partition.graph
+        )  # noqa: B950
+
+        # Emit call in base graph to this submodule
+        output_val = base_mod_graph.call_module(
+            partition.submod_name,
+            tuple(base_mod_env[name] for name in partition.inputs),
+        )
+
+        num_outputs = len(partition.outputs)
+        if num_outputs > 1:
+            # Unpack multiple return values from submodule
+            output_val_proxy = torch.fx.proxy.Proxy(output_val)
+            for i, output_name in enumerate(partition.outputs):
+                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
+        elif num_outputs == 1:
+            base_mod_env[next(iter(partition.outputs))] = output_val
+
+    # When keep_original_order=True and if the graph doesn't have any
+    # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
+    # are never populated.
+    # For this case, we call `construct_graph` here which takes care of updating them.
+    if keep_original_order and not base_mod_env:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
+    # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
+    for node in m.graph.nodes:
+        if node.op == "output":
+            base_mod_graph.output(
+                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
+            )  # noqa: B950
+
+    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    log.debug(
+        "%s",
+        lazy_format_graph_code("post split_module", ret, colored=True),
+    )
+    return ret
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/split_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/split_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..88da7ac7c4f55fb5cf1c22546d09ceb3b406d6fb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/split_utils.py
@@ -0,0 +1,312 @@
+# mypy: allow-untyped-defs
+import copy
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import map_arg
+from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
+
+from .tools_common import NodeList
+
+
+__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
+
+
+@compatibility(is_backward_compatible=False)
+def getattr_recursive(obj, name):
+    for layer in name.split("."):
+        if isinstance(obj, torch.nn.ModuleList):
+            if hasattr(obj, "_modules") and layer in obj._modules:
+                obj = obj._modules[layer]
+            else:
+                return None
+        elif hasattr(obj, layer):
+            obj = getattr(obj, layer)
+        else:
+            return None
+    return obj
+
+
+@compatibility(is_backward_compatible=False)
+def setattr_recursive(obj, attr, value):
+    if "." not in attr:
+        setattr(obj, attr, value)
+    else:
+        layer = attr.split(".")
+        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class Component:
+    """
+    A component serves as a container for a subgraph we want to create afterwards.
+    """
+
+    graph: torch.fx.Graph
+    order: int
+    name: str
+
+    # Stores the placeholder nodes in `graph`.
+    input_placeholders: list = field(default_factory=list)
+
+    # Store the nodes in original graph that are placeholder in `graph`.
+    orig_inputs: list = field(default_factory=list)
+
+    # Store the nodes in original graph that are outputs in `graph`.
+    orig_outputs: list = field(default_factory=list)
+
+    # Mapping from get_attr node in original graph to get_attr node in `graph`.
+    getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
+    constructor_args: list[str] = field(default_factory=list)
+    gm: Optional[torch.fx.GraphModule] = None
+
+
+@compatibility(is_backward_compatible=False)
+def split_by_tags(
+    gm: torch.fx.GraphModule,
+    tags: list[str],
+    return_fqn_mapping: bool = False,
+    return_tuple: bool = False,
+    GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
+) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
+    """
+    Splits a GraphModule using tags on its graph nodes. We honor the order of
+    tags. For example, we have tags = ["a", "b", "c"], the function will create
+    the initial submodules in the order of "a", "b", "c".
+
+    To set a tag:
+    gm.graph.nodes[idx].tag = "mytag"
+
+    This will result in all nodes with the same tag being extracted and placed in their
+    own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder
+    and output nodes are created when needed while get_attr nodes get copied to submodules
+    where they are used.
+
+    Given the following module def:
+
+    class SimpleModule(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear1 = torch.nn.Linear(...)
+            self.linear2 = torch.nn.Linear(...)
+            self.linear3 = torch.nn.Linear(...)
+
+        def forward(self, in1, in2):
+            r1 = self.linear1(in1)
+            r2 = self.linear2(in2)
+            r3 = torch.cat([r1, r2])
+            return self.linear3(r3)
+
+    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
+
+    ro:
+    def forward(self, in1):
+        self = self.root
+        linear1 = self.linear1(in1)
+        return linear1
+
+    main:
+    def forward(self, in2, linear1):
+        self = self.root
+        linear2 = self.linear2(in2)
+        cat_1 = torch.cat([linear1, linear2])
+        linear3 = self.linear3(cat_1)
+        return linear3
+
+    main:
+    def forward(self, in1, in2):
+        self = self.root
+        ro_0 = self.ro_0(in1)
+        main_1 = self.main_1(in2, ro_0)
+        return main_1
+
+    Returns:
+        split_gm: torch fx graph after split
+        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
+            after split for call_module and get_attr.
+    """
+
+    def flatten(x: torch.fx.node.Argument) -> NodeList:
+        """
+        Stores nodes in x to a list and returns the list.
+        """
+        r: NodeList = []
+        map_arg(x, r.append)
+        return r
+
+    # Mapping from node in original module to node in created submodule.
+    node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Mapping from node in original module or created submodules to
+    # corresponding component.
+    node_to_component: dict[torch.fx.Node, Component] = {}
+
+    # Mapping from tag to the corresponding component.
+    tag_to_component: dict[str, Component] = {}
+
+    # Stores all components.
+    all_components: list[Component] = []
+
+    # Stores nodes that will be used in main graph.
+    used_in_main: dict[torch.fx.Node, None] = {}
+
+    # Main graph after split.
+    main_g = torch.fx.Graph()
+
+    # Mapping from node in original module to node in main graph after split.
+    main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Output node of original module.
+    output_node: Optional[torch.fx.Node] = None
+
+    # Create a component for each tag, we don't expect to create other components afterwards.
+    for tag in tags:
+        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
+        all_components.append(comp)
+        tag_to_component[tag] = comp
+
+    # Traverse the nodes in original graph and take care of them.
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            if output_node is not None:
+                raise RuntimeError("Multiple output nodes in graph!")
+            output_node = node
+            continue
+
+        # Placeholders in the original graph get copied to main graph.
+        if node.op == "placeholder":
+            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
+            main_remapping[node].meta = copy.copy(node.meta)
+            continue
+
+        # Get_attr nodes are ignored because we are not tagging them.
+        # Instead, we copy them directly to the submodules use them afterwards.
+        if node.op == "get_attr":
+            continue
+
+        # Now we process callable nodes which are nodes with op of call_module,
+        # call_function or call_method. Every callable nodes should be tagged.
+        assert hasattr(node, "tag"), f"Node does not have tag: {node.format_node()}"
+
+        upstream_components = [
+            node_to_component[x]
+            for x in flatten(node.args) + flatten(node.kwargs)
+            if x.op not in {"placeholder", "get_attr"}
+        ]
+
+        comp = tag_to_component[node.tag]
+        node_to_component[node] = comp
+
+        # Max order of upperstream components.
+        mx = max((c.order for c in upstream_components), default=0)
+
+        # Expect the component for `node` has higher order then its upstream components.
+        assert comp.order >= mx, (
+            f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
+        )
+
+        # Map a input of `node` to nodes in the component's graph.
+        def remap_func(x):
+            # If input is a get_attr node, copy it to current component's graph.
+            # Returns the get_attr node in current component's graph.
+            if x.op == "get_attr":
+                if x not in comp.getattr_maps:
+                    comp.getattr_maps[x] = comp.graph.get_attr(
+                        x.target, type_expr=x.type
+                    )
+                    comp.getattr_maps[x].meta = copy.copy(x.meta)
+                return comp.getattr_maps[x]
+
+            # If input is not a placeholder, it should have been put into a component
+            # already. If it's the current component then we return the corresponding
+            # node in the component.
+            if x.op != "placeholder" and node_to_component[x] == comp:
+                return node_remapping[x]
+
+            # If input is a placeholder or it's in other components, we want to make it
+            # as a placeholder in current component's graph.
+            if x not in comp.orig_inputs:
+                comp.orig_inputs.append(x)
+                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
+                placeholder.meta = copy.copy(x.meta)
+                comp.input_placeholders.append(placeholder)
+                used_in_main[x] = None
+
+            return comp.input_placeholders[comp.orig_inputs.index(x)]
+
+        n = comp.graph.node_copy(node, remap_func)
+        n.tag = node.tag  # type: ignore[attr-defined]
+        node_remapping[node] = n
+        node_to_component[n] = comp
+
+    if output_node is None:
+        raise RuntimeError("Graph had no output node!")
+
+    for x in flatten(output_node.args[0]):
+        if x.op == "get_attr":
+            # We don't need components mapping for nodes of type "get_attr"
+            # that are consumed by the output. Only need to make sure we create
+            # corresponding counterparts in the resulting graph.
+            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
+        else:
+            # All component results consumed by the output node should be
+            # marked as "used in main".
+            used_in_main[x] = None
+
+    # If a node is used in main graph then we mark it as an output in the component
+    # it belongs to.
+    for n in used_in_main:
+        if n.op != "placeholder":
+            node_to_component[n].orig_outputs.append(n)
+
+    # Now we create a graphmodule for each component.
+    orig_to_split_fqn_mapping: dict[str, str] = {}
+    for comp in all_components:
+        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
+
+        if return_tuple:
+            comp.graph.output(outs)
+        else:
+            # Take care of the args of FX output node. If there's a single
+            # output then the output node args is like (output_single), else
+            # if there're multiple outputs then the output node args is like
+            # ((output_0, output_1, ...)).
+            comp.graph.output(outs[0] if len(outs) == 1 else outs)
+
+        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
+            gm, subgraph=comp.graph, comp_name=comp.name
+        )
+        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
+
+        # Create a call_module node in main graph.
+        main_node = main_g.call_module(
+            comp.name,
+            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
+            kwargs=None,
+        )
+
+        if len(outs) == 1 and not return_tuple:
+            main_remapping[comp.orig_outputs[0]] = main_node
+        else:
+            for i, o in enumerate(comp.orig_outputs):
+                # Use Proxy to record getitem access.
+                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]
+
+    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
+    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
+    main_g._codegen = gm.graph._codegen
+
+    # If the output nodes consumes get_attr directly in the original graph,
+    # then we need to make sure get_attr is copied to the new graph.
+    for x in flatten(output_node.args[0]):
+        if x.op == "get_attr":
+            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]
+
+    result_gm = GraphModuleCls(main_root, main_g)
+    if return_fqn_mapping:
+        return result_gm, orig_to_split_fqn_mapping
+
+    return result_gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d90f9d55cfdb194e2d2a0577a84b5fd9d7f0262
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py
@@ -0,0 +1,1121 @@
+# mypy: allow-untyped-defs
+import argparse
+import copy
+import json
+import logging
+import os
+from collections import defaultdict
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+from typing import Any, Literal, NamedTuple, Optional
+
+import torch
+from torch._logging import trace_structured
+from torch.fx._compatibility import compatibility
+from torch.fx.node import map_arg
+from torch.fx.passes.graph_manipulation import get_size_of_node
+
+from .graph_drawer import FxGraphDrawer
+from .operator_support import get_node_target, OperatorSupportBase
+from .shape_prop import ShapeProp
+from .split_utils import split_by_tags
+from .tools_common import (
+    CALLABLE_NODE_OPS,
+    FxNetAccFusionsFinder,
+    is_node_output_tensor,
+    NodeList,
+    NodeSet,
+    Tensors,
+)
+
+
+__all__ = [
+    "FxNetAccNodesFinder",
+    "FxNetSplitterInternalError",
+    "Subgraph",
+    "SplitResult",
+    "generate_inputs_for_submodules",
+    "NodeEvent",
+    "NodeEventTracker",
+]
+_LOGGER = logging.getLogger(__name__)
+
+DEFAULT_MIN_ACC_MODULE_SIZE = 1
+DEFAULT_SKIP_FUSION = False
+DEFAULT_ALLOW_NON_TENSOR = False
+
+# ENV var and constants for node tracker
+
+TRACKER_DUMP_PATH = "_fx_net_tracker"
+NODES_SUFFIX = "_nodes.txt"
+ALL_SUFFIX = "_all.txt"
+
+ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
+ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
+ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
+    "FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
+)
+
+DUMP_PREFIX = os.environ.get(
+    ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
+)
+
+"""
+Different modes of the event tracker for local debugging:
+"0": No local dumps. Information available by setting breakpoints and visually inspect in pdb.
+"1": Dump all events to DUMP_PREFIX_all.txt
+"2": In addition to events dump, track nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES
+     recursively and dump to DUMP_PREFIX_nodex.txt
+"3": In addition to events dump, track all nodes with more than 1 event recursively and dump to DUMP_PREFIX_nodex.txt
+In addition to the above local dumps, tracker is always enabled and dumps via trace_structured.
+"""
+TRACKER_MODE: Literal["0", "1", "2", "3"] = os.environ.get(
+    ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
+)  # type: ignore[assignment]
+
+
+class _SplitterSettingBase:
+    def __init__(
+        self,
+        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
+        skip_fusion=DEFAULT_SKIP_FUSION,
+        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
+        max_acc_splits: int = -1,
+    ):
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--min-acc-module-size",
+            "--min_acc_module_size",
+            required=False,
+            type=int,
+            help="Minimum size limit of an accelerator subgraph.",
+        )
+        parser.add_argument(
+            "--max-acc-splits",
+            "--max_acc_splits",
+            required=False,
+            type=int,
+            help="Enforce a maximum number of split subgraphs.",
+        )
+        parser.add_argument(
+            "--skip-fusion",
+            "--skip_fusion",
+            default=False,
+            action="store_true",
+            help="If true then no fusion groups. Fusion group is used to "
+            "enforce no non-tensor data flow between submodules. If we don't "
+            "have this constrain, setting this to false is recommended as it "
+            "can reduce overhead.",
+        )
+        parser.add_argument(
+            "--allow-non-tensor",
+            "--allow_non_tensor",
+            default=False,
+            action="store_true",
+            help="For some backends non-tensor data flow between cpu and them "
+            "are not allowed. Therefore, if a node supported by accelerator but "
+            "it has non-tensor inputs or outputs to a cpu node we would want to "
+            "consider it as a cpu node during splitting. However, for some backends "
+            "we might not care about non-tensor data flow and we can set this option "
+            "to true to disable the functionality that prevent non-tensor data flow.",
+        )
+        args, _unknown = parser.parse_known_args()
+
+        self.min_acc_module_size: int = (
+            args.min_acc_module_size
+            if args.min_acc_module_size
+            else min_acc_module_size
+        )
+        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
+        self.allow_non_tensor: bool = (
+            args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
+        )
+        self.max_acc_splits: int = max_acc_splits
+
+
+@compatibility(is_backward_compatible=False)
+class NodeEvent:
+    """
+    An event in graph split that happened on a node.
+    source: Subject of the event
+    desc: readable description
+    dep: Optional dependency, usually the node that caused the event.
+    """
+
+    def __init__(
+        self, source: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None
+    ):
+        self.source = source
+        self.desc = desc
+        self.dep = dep
+
+    def to_str(self):
+        # source: The name of the subject of the event.
+        # desc: description of the event, in the format of |
+        # dep: The name of the cause of this event, which is another node, or #
+        # if it's caused by the subject node
+        return f"{self.source.name}: {self.desc} {self.dep.name if self.dep else '#'}"
+
+
+@compatibility(is_backward_compatible=False)
+class NodeEventTracker:
+    """
+    Tracks node events during the splitter execution.
+    """
+
+    def __init__(self, tracker_mode, dump_prefix):
+        self.tracker_mode = tracker_mode
+        self.dump_prefix = dump_prefix
+        # list of events
+        self.events = []
+        # dict from node name to event index
+        self.node_events = {}
+        self.writer = print
+
+    def add(self, node: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
+        """
+        Add a new event to the tracker.
+        """
+        event = NodeEvent(node, desc, dep)
+        self.events.append(event)
+        if node.name not in self.node_events:
+            self.node_events[node.name] = []
+        self.node_events[node.name].append(len(self.events) - 1)
+
+    def print_node(self, node_name, recursive=False, tab="", writer=None):
+        """
+        Print a node and its events.
+        @param recursive: if True, print nodes that caused the events on this current node.
+        @param tab: Indentation for dependencies.
+        @param writer: function to write to file. If None, use print.
+        """
+        if not writer:
+            writer = self.writer
+        for idx in self.node_events.get(node_name, []):
+            event = self.events[idx]
+            writer(tab + event.to_str())
+            if recursive and event.dep is not None:
+                self.print_node(
+                    event.dep.name, recursive=True, tab="| " + tab, writer=writer
+                )
+
+    def to_dict(self):
+        """
+        Create dict dump on all events.
+        """
+        ret: dict[str, list[str]] = {}
+        for name in self.node_events:
+            ret[name] = []
+            for idx in self.node_events.get(name, []):
+                event = self.events[idx]
+                ret[name].append(event.to_str())
+        return ret
+
+    def print_all(self, writer=None):
+        """
+        Print all nodes in a list.
+        @param writer: function to write to file. If None, use print.
+        """
+        if not writer:
+            writer = self.writer
+        for name in self.node_events:
+            writer(f"Node: {name}:")
+            self.print_node(name, recursive=False, tab="  ", writer=writer)
+
+    def dump(self):
+        """
+        Function to be invoked at the end of the finder execution to printout tracked events specified by the mode.
+        """
+        # dump via trace_structured
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_net_acc_splitter_finder_events",
+                "encoding": "json",
+            },
+            payload_fn=lambda: json.dumps(self.to_dict()),
+        )
+
+        def writeln(f):
+            def fn(x):
+                return f.write(x + "\n")
+
+            return fn
+
+        # Mode 0: no local dump
+        # Mode >=1: Dump all events to file
+        if self.tracker_mode >= 1:
+            with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
+                self.print_all(writeln(f))
+
+        def dump_selected_nodes(nodes):
+            with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
+                for node_name in nodes:
+                    writeln(f"===== Tracking node {node_name} =====")
+                    self.print_node(
+                        node_name, recursive=True, tab="|-", writer=writeln(f)
+                    )
+                    writeln(f"===== End of tracking node {node_name} =====")
+
+        # Mode 2: Dump specific nodes in recursive manner.
+        # Mode 3: Dump all nodes with more than 1 event in recursive manner.
+        if self.tracker_mode == 2 or self.tracker_mode == 3:
+            nodes = (
+                os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "").split(
+                    ","
+                )
+                if self.tracker_mode == 2
+                else [
+                    name for name, events in self.node_events.items() if len(events) > 1
+                ]
+            )
+            dump_selected_nodes(nodes)
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetAccNodesFinder:
+    """
+    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
+    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
+
+    I.e. if we have a chain:
+
+    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
+
+    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
+
+    This behavior can be turned off by passing allow_non_tensor=True.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        operator_support: OperatorSupportBase,
+        allow_non_tensor: bool,
+    ):
+        self.module = module
+        self.operator_support = operator_support
+        self.allow_non_tensor = allow_non_tensor
+        self.acc_nodes: NodeSet = set()
+
+        self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)
+
+    def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
+        """
+        Transitively excludes nodes from ACC supported set.
+        For every node in the worklist:
+        - removes its downstream ACC nodes from ACC supported set,
+        - if any downstream ACC node produces non-tensor output,
+          then it gets added into the worklist.
+        """
+        while cpu_worklist:
+            node = cpu_worklist.pop(0)
+
+            for user in node.users:
+                if user in self.acc_nodes:
+                    self.acc_nodes.remove(user)
+                    self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
+                    if not is_node_output_tensor(user):
+                        self.tracker.add(user, "new_cpu_node|non_tensor_output")
+                        cpu_worklist.append(user)
+
+    def reduce_acc_nodes_non_tensor_input(self):
+        """
+        Excludes nodes from ACC supported set that have direct
+        upstream CPU nodes that produce non-tensor outputs.
+        """
+        non_tensor_cpu_nodes: NodeList = []
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+            if node in self.acc_nodes:
+                continue
+            if is_node_output_tensor(node):
+                continue
+            self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
+            non_tensor_cpu_nodes.append(node)
+
+        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
+
+    def reduce_acc_nodes_non_tensor_output(self):
+        """
+        Excludes nodes from ACC supported set that produce non-tensor
+        outputs and have downstream CPU nodes.
+        """
+        while True:
+            new_cpu_nodes: NodeList = []
+
+            for acc_node in self.acc_nodes:
+                if is_node_output_tensor(acc_node):
+                    continue
+                for user in acc_node.users:
+                    if user not in self.acc_nodes:
+                        new_cpu_nodes.append(acc_node)
+                        self.tracker.add(
+                            acc_node, "acc_del|non_tensor_output_with_cpu_user", user
+                        )
+                        break
+
+            if not new_cpu_nodes:
+                break
+
+            for new_cpu_node in new_cpu_nodes:
+                self.acc_nodes.remove(new_cpu_node)
+
+            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)
+
+    def __call__(self) -> NodeSet:
+        submodules = dict(self.module.named_modules())
+        self.acc_nodes = set()
+        for n in self.module.graph.nodes:
+            if n.op not in CALLABLE_NODE_OPS:
+                self.tracker.add(n, "init_cpu|not_callable")
+                continue
+            if not self.operator_support.is_node_supported(submodules, n):
+                self.tracker.add(n, "init_cpu|operator_support")
+                continue
+
+            self.tracker.add(n, "init_acc|callable_and_operator_supported")
+            self.acc_nodes.add(n)
+
+        if not self.allow_non_tensor:
+            self.reduce_acc_nodes_non_tensor_input()
+            self.reduce_acc_nodes_non_tensor_output()
+        self.tracker.dump()
+        return self.acc_nodes
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetSplitterInternalError(Exception):
+    pass
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class Subgraph:
+    is_acc: bool
+    nodes: NodeList
+    device_ordinal: Optional[int] = None
+
+
+@compatibility(is_backward_compatible=False)
+class SplitResult(NamedTuple):
+    """
+    Stores the results of the splitter.
+
+    Attributes:
+        split_module: root module after splitting.
+        submodule_inputs: a dict that maps submodule name to its inputs.
+        non_acc_submodule_prefix: the prefix for non acc submodules. For
+            acc submodule the prefix is always "_run_on_acc_".
+    """
+
+    split_module: torch.fx.GraphModule
+    submodule_inputs: dict[str, Any]
+    non_acc_submodule_prefix: str
+
+
+@compatibility(is_backward_compatible=False)
+def generate_inputs_for_submodules(
+    model: torch.nn.Module,
+    inputs: Sequence[Any],
+    target_submodules: Iterable[str],
+    deepcopy: bool = False,
+) -> dict[str, Any]:
+    """
+    Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
+    function doesn't work.
+
+    Args:
+        model: root model.
+        inputs: inputs to the root model.
+        target_submodules: submodules that we want to generate inputs for.
+
+    Returns:
+        A dict that maps from submodule name to its inputs.
+    """
+
+    handles = []
+    results = {}
+    submodule_to_names = {mod: name for name, mod in model.named_modules()}
+
+    def pre_forward(module, module_inputs):
+        results[submodule_to_names[module]] = (
+            copy.deepcopy(module_inputs) if deepcopy else module_inputs
+        )
+
+    for name, mod in model.named_modules():
+        if name in target_submodules:
+            if not isinstance(mod, torch.jit.ScriptModule):
+                handles.append(mod.register_forward_pre_hook(pre_forward))
+
+    def clean_up_handles():
+        for h in handles:
+            h.remove()
+
+    try:
+        with torch.no_grad():
+            model(*inputs)
+    except Exception as e:
+        clean_up_handles()
+        raise e
+
+    clean_up_handles()
+    return results
+
+
+class _SplitterBase:
+    """
+    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
+    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
+    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
+
+    Given the following graph:
+          ==> b ==>
+        //         \\
+       a             d
+        \\         //
+          ==> c ==>
+
+    class SimpleModule(torch.nn.Module):
+        def forward(self, a):
+            b = torch.sin(a)
+            c = torch.cos(a)
+            d = b + c
+            return d
+
+    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
+    we will get the following split result:
+
+    main:
+    def forward(self, a):
+        run_on_acc_0_0 = self._run_on_acc_0_0(a)
+        getitem = run_on_acc_0_0[0]
+        getitem_1 = run_on_acc_0_0[1]
+        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
+        return run_on_cpu_1_1
+
+    _run_on_acc_0_0:
+    def forward(self, a):
+        sin_1 = torch.sin(a)
+        cos_1 = torch.cos(a)
+        return (sin_1, cos_1)
+
+    _run_on_cpu_1_1:
+    def forward(self, sin_1, cos_1):
+        add_1 = sin_1 + cos_1
+        return add_1
+    """
+
+    # PCIe bandwidth for the backend, default to 100 GB/s
+    PCIe_BW = 100 * 2**30
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Sequence[Any],
+        operator_support: OperatorSupportBase,
+        settings: _SplitterSettingBase,
+        non_acc_submodule_name: str = "_run_on_cpu_",
+        return_tuple: bool = False,
+        nodes_finder: Optional[FxNetAccNodesFinder] = None,
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+        """
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+        ShapeProp(self.module).propagate(*sample_input)
+
+        self.settings = settings
+        self.operator_support = operator_support
+        self.sample_input = sample_input
+        if nodes_finder is None:
+            nodes_finder = FxNetAccNodesFinder(
+                self.module, self.operator_support, self.settings.allow_non_tensor
+            )
+        self.acc_nodes = nodes_finder()
+
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
+        self.non_acc_submodule_name = non_acc_submodule_name
+        self._node_submodule_map: dict[str, str] = {}
+        self._return_tuple = return_tuple
+
+        self.tags: list[str] = []
+
+    # ===============================================================
+    # Helpers for ctor and initial state
+    # ===============================================================
+
+    def get_node_submodule_map(self) -> dict[str, str]:
+        """Returns a map from node name to submodule name, e.g.
+        node: main_module_impl_impl_over_arch_unary_multiple_embedding
+          _pooling_embedding_pooling_sparse_entity_equivalence_key
+          _proxy_embedding_bag
+        maps to submodule name of: _run_on_acc_1
+        """
+        return self._node_submodule_map
+
+    def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
+        """
+        Builds a graph of node dependencies. Leaf nodes don't have any
+        dependencies and the "output" node doesn't have nodes depending on it.
+
+        Resulting graph has only direct dependencies, i.e. there are no
+        transitive dependencies.
+        """
+        deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            for user in node.users:
+                if user.op != "output":
+                    deps[user].add(node)
+        return deps
+
+    def update_deps_for_fusions(self):
+        """
+        Updates graph of dependencies so that:
+        - nodes from the same fusion depend on the same set of outer nodes,
+        - outer nodes depending on a fusion depend on all nodes in that fusion.
+        """
+        for node in self.fusions:
+            fusion = self.fusions[node]
+            for fused_neighbor in fusion:
+                self.deps[node].update(self.deps[fused_neighbor] - fusion)
+
+                for user in fused_neighbor.users:
+                    if user not in fusion:
+                        self.deps[user].add(node)
+
+    # ===============================================================
+    # Helpers for preview
+    # ===============================================================
+
+    def _lower_model_to_backend(
+        self, mod: torch.fx.GraphModule, inputs: Tensors
+    ) -> torch.nn.Module:
+        """
+        Lower the model to a backend.
+        """
+
+        return mod
+
+    def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
+        """
+        When an error occurs during lowering or running the lowered mod, we use this
+        function to find culprits in the `mod` that causes the error.
+        """
+
+        return "Unable to find a culprit because _find_culprit() function is not implemented."
+
+    def _draw_graph_based_on_node_support(
+        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
+    ):
+        color_map = {
+            "default": "AliceBlue",
+            "supported": "chartreuse1",
+            "unsupported": "crimson",
+        }
+
+        class CustomDrawer(FxGraphDrawer):
+            def _get_node_style(self, node):
+                template = super()._get_node_style(node)
+                if node in supported_nodes:
+                    template["fillcolor"] = color_map["supported"]
+                elif node.op in CALLABLE_NODE_OPS:
+                    template["fillcolor"] = color_map["unsupported"]
+                else:
+                    template["fillcolor"] = color_map["default"]
+
+                return template
+
+        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
+        dot_graph = drawer.get_main_dot_graph()
+        # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
+        dot_graph.write_raw("node_support.dot")  # type: ignore[attr-defined]
+
+    def node_support_preview(self, dump_graph: bool = False):
+        submodules = dict(self.module.named_modules())
+
+        supported_nodes: NodeList = []
+        supported_node_types = defaultdict(set)
+        unsupported_node_types = defaultdict(set)
+
+        def get_dtype(arg):
+            tensor_meta = arg.meta.get("tensor_meta")
+            return getattr(tensor_meta, "dtype", None)
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            target = get_node_target(submodules, node)
+
+            # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
+            arg_dtypes = [
+                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
+                for arg in node.args
+            ]
+
+            # Find last non-None element. If all elements are None, return max_len.
+            last_index = len(arg_dtypes) - next(
+                (
+                    i
+                    for i, dtype in enumerate(reversed(arg_dtypes))
+                    if dtype is not None
+                ),
+                len(arg_dtypes),
+            )
+
+            # Strip None elements at the end.
+            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
+            kwarg_dtypes_tuple = tuple(
+                (k, get_dtype(arg))
+                for k, arg in node.kwargs.items()
+                if isinstance(arg, torch.fx.Node)
+            )
+
+            if self.operator_support.is_node_supported(submodules, node):
+                supported_nodes.append(node)
+                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
+            else:
+                unsupported_node_types[target].add(
+                    (arg_dtypes_tuple, kwarg_dtypes_tuple)
+                )
+
+        if dump_graph:
+            self._draw_graph_based_on_node_support(self.module, supported_nodes)
+
+        reports = "\nSupported node types in the model:\n"
+        for t, dtypes in supported_node_types.items():
+            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
+                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
+
+        reports += "\nUnsupported node types in the model:\n"
+        for t, dtypes in unsupported_node_types.items():
+            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
+                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
+
+        print(reports)
+
+        # Return reports for testing purpose
+        return reports
+
+    def split_preview(self, dump_graph: bool = False):
+        reports = ""
+        subgraphs = self.put_nodes_into_subgraphs()
+        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
+        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
+        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
+        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
+
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
+        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
+        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
+        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
+
+        for i, subgraph in enumerate(subgraphs):
+            reports += (
+                f"_run_on_acc_{i}: "
+                if subgraph.is_acc
+                else f"{self.non_acc_submodule_name}{i}: "
+            )
+            reports += f"{len(subgraph.nodes)} node(s)\n"
+
+        self.tag(subgraphs)
+        split_mod = self.split(remove_tag=True)
+        split_mod.eval()
+
+        if dump_graph:
+            drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
+            dot_graphs = drawer.get_all_dot_graphs()
+            for name, dot_graph in dot_graphs.items():
+                # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
+                dot_graph.write_raw(f"{name}.dot")  # type: ignore[attr-defined]
+
+        max_qps: float = self.PCIe_BW
+        bottleneck_module = ""
+
+        for node in split_mod.graph.nodes:
+            if node.op == "call_module" and "acc" in node.target:
+                reports += f"\nProcessing acc submodule {node.target}\n"
+
+                submod = getattr(split_mod, node.target)
+
+                def get_submod_inputs(main_mod, submod, example_inputs):
+                    sub_inputs = None
+
+                    def get_inputs(self, inputs):
+                        nonlocal sub_inputs
+                        sub_inputs = inputs
+
+                    handle = submod.register_forward_pre_hook(get_inputs)
+                    main_mod(*example_inputs)
+                    handle.remove()
+                    return sub_inputs
+
+                submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
+                ShapeProp(submod).propagate(*submod_inputs)
+
+                total_input_bytes = 0
+                total_output_bytes = 0
+
+                reports += "Checking inputs...\n"
+                for n in submod.graph.nodes:
+                    if n.op == "placeholder":
+                        if not is_node_output_tensor(n):
+                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
+                        else:
+                            total_input_bytes += get_size_of_node(submod, n)[0]
+                    if n.op == "output":
+                        output_node = n
+
+                reports += "Checking outputs...\n"
+
+                def get_bytes(node: torch.fx.Node):
+                    nonlocal total_output_bytes
+                    nonlocal reports
+                    if not is_node_output_tensor(node):
+                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
+                    else:
+                        total_output_bytes += get_size_of_node(submod, node)[0]
+
+                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
+                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
+                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
+                reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"
+
+                if qps < max_qps:
+                    max_qps = qps
+                    bottleneck_module = node.target
+
+                try:
+                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
+                except RuntimeError:
+                    reports += "Run into an error during lowering!\n"
+                    reports += self._find_culprit(submod, submod_inputs)
+                    continue
+
+                try:
+                    lowered_submod(*submod_inputs)
+                except RuntimeError:
+                    reports += "Run into an error during inference!\n"
+                    reports += self._find_culprit(submod, submod_inputs)
+                else:
+                    reports += "Lowering and running succeed!\n"
+
+        reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
+        reports += f" bottleneck is submodule {bottleneck_module}."
+        print(reports)
+
+        # return the reports for testing purposes
+        return reports
+
+    # ===============================================================
+    # Helpers for extend_acc_subgraph() method
+    # ===============================================================
+
+    def find_reverse_deps(
+        self, tag_id: Optional[int] = None
+    ) -> dict[torch.fx.Node, NodeSet]:
+        """
+        Builds reversed topological node dependencies, if tag_id is specified,
+        we ignore nodes that are in later subgraph i.e. nodes have greater tag_id.
+        """
+        result: dict[torch.fx.Node, NodeSet] = defaultdict(set)
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            for user in node.users:
+                if user.op not in CALLABLE_NODE_OPS:
+                    continue
+
+                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
+                    result[node].add(user)
+
+        return result
+
+    def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
+        processed_node = set()
+
+        for node, fusion in self.fusions.items():
+            if node in processed_node:
+                continue
+
+            new_dep = set()
+
+            # Create a new dependency set which include all the
+            # dependencies of the nodes in the fusion group
+            for n in fusion:
+                new_dep.update(deps[n])
+
+            # Exclude nodes in the fusion
+            new_dep.difference_update(fusion)
+
+            # Update dependency
+            for n in fusion:
+                deps[n] = new_dep
+
+                for arg in n.all_input_nodes:
+                    if arg not in fusion:
+                        deps[arg].update(fusion)
+
+                processed_node.add(n)
+
+    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
+        """
+        Finds parent nodes of the `tag` subgraph.
+
+        Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph
+        and is not a placeholder, we consider it as the parent node of the subgraph.
+        """
+        parent_nodes = set()
+
+        for node in self.module.graph.nodes:
+            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
+                for arg in node.all_input_nodes:
+                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
+                        parent_nodes.add(arg)
+
+        return parent_nodes
+
+    def extend_acc_subgraph(self, tag: str):
+        """
+        Extend the acc subgraph with `tag` going the reversed topological direction.
+        """
+        # Dict that maps node to its users and ignore users that
+        # are in the subgraph that has greater tag
+        deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1]))
+        self.update_reverse_deps_for_fusions(deps)
+
+        # Parent nodes of the subgraph
+        parent_nodes = self.find_parent_nodes_of_subgraph(tag)
+
+        visited_nodes: NodeSet = set()
+
+        while parent_nodes:
+            node = None
+
+            # Find a acc node that depends on visited nodes only
+            for n in parent_nodes:
+                if deps[n] <= visited_nodes and n in self.acc_nodes:
+                    node = n
+                    break
+
+            if node is None:
+                break
+
+            # Put the node into `tag` subgraph
+            node.tag = tag  # type: ignore[attr-defined]
+            parent_nodes.remove(node)
+            visited_nodes.add(node)
+
+            # If node is in a fusion group, add all fusion buddies to parent nodes
+            if node in self.fusions:
+                for fusion_node in self.fusions[node]:
+                    if fusion_node not in visited_nodes:
+                        parent_nodes.add(fusion_node)
+
+            # Add inputs of the node to parent nodes
+            for arg in node.all_input_nodes:
+                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
+                    parent_nodes.add(arg)
+
+    # ===============================================================
+    # Helpers for split() method
+    # ===============================================================
+
+    def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
+        """
+        Finds nodes that consume module inputs or get_attr nodes.
+        """
+        starter_cpu_nodes: NodeSet = set()
+        starter_acc_nodes: NodeSet = set()
+        for node in self.module.graph.nodes:
+            # edge case, call_function, but with no dependencies
+            if node.op == "call_function" and len(node.all_input_nodes) == 0:
+                if node in self.acc_nodes:
+                    starter_acc_nodes.add(node)
+                else:
+                    starter_cpu_nodes.add(node)
+
+            if node.op not in {"placeholder", "get_attr"}:
+                continue
+
+            for user in node.users:
+                if user in self.acc_nodes:
+                    starter_acc_nodes.add(user)
+                else:
+                    starter_cpu_nodes.add(user)
+
+        return starter_cpu_nodes, starter_acc_nodes
+
+    def put_nodes_into_subgraphs(self) -> list[Subgraph]:
+        # We start graph traversal from leaf nodes
+        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
+        visited_nodes: NodeSet = set()
+
+        # Determine which subgraph to start from based on which subgraph has
+        # 0-dep node
+        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)
+
+        current_subgraph_nodes: NodeList = []
+
+        # Result accumulator
+        subgraphs: list[Subgraph] = []
+        while current_cpu_nodes or current_acc_nodes:
+            # Find the first node that should belong to the current subgraph and has all dependencies resolved
+            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
+            node = next(
+                (n for n in current_nodes if self.deps[n] <= visited_nodes),
+                None,
+            )
+
+            # If nothing was found, then it's time to flip the mode and start a new subgraph
+            if node is None:
+                if not current_subgraph_nodes:
+                    raise FxNetSplitterInternalError("Subgraph can't be empty")
+
+                subgraphs.append(
+                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
+                )
+                acc_subgraph = not acc_subgraph
+                current_subgraph_nodes = []
+                continue
+
+            current_nodes.remove(node)
+            visited_nodes.add(node)
+            current_subgraph_nodes.append(node)
+
+            # Add fusion buddies
+            if node in self.fusions:
+                if node in self.acc_nodes:
+                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
+                else:
+                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)
+
+            # Put depending nodes into the queue
+            for user in node.users:
+                if user.op not in CALLABLE_NODE_OPS:
+                    continue
+
+                # Add downstream nodes
+                if user in self.acc_nodes:
+                    current_acc_nodes.add(user)
+                else:
+                    current_cpu_nodes.add(user)
+
+        # Check if the last subgraph was not created
+        if current_subgraph_nodes:
+            subgraphs.append(
+                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
+            )
+
+        if not subgraphs:
+            raise FxNetSplitterInternalError("Couldn't create subgraphs")
+
+        return subgraphs
+
+    def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
+        """
+        This pass finds ACC submodules with less than specified size and merges
+        them with adjacent CPU submodules.
+        """
+        result: list[Subgraph] = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
+                    result.append(subgraph)
+                else:
+                    print(
+                        "Eliminating acc subgraph because it's smaller than the threshold: "
+                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                    )
+                    if result:
+                        result[-1].nodes.extend(subgraph.nodes)
+                    else:
+                        subgraph.is_acc = False
+                        result.append(subgraph)
+            else:
+                if result and not result[-1].is_acc:
+                    result[-1].nodes.extend(subgraph.nodes)
+                else:
+                    result.append(subgraph)
+        return result
+
+    def tag(self, subgraphs: list[Subgraph]):
+        self.tags = []
+        for subgraph in subgraphs:
+            tag = (
+                f"_run_on_acc_{len(self.tags)}"
+                if subgraph.is_acc
+                else f"{self.non_acc_submodule_name}{len(self.tags)}"
+            )
+            self.tags.append(tag)
+            for node in subgraph.nodes:
+                if hasattr(node, "tag"):
+                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")
+
+                node.tag = tag  # type: ignore[attr-defined]
+                self._node_submodule_map[node.name] = tag
+
+    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
+        split_module = split_by_tags(
+            self.module, self.tags, return_tuple=self._return_tuple
+        )
+        if remove_tag:
+            for node in self.module.graph.nodes:
+                if hasattr(node, "tag"):
+                    del node.tag
+        return split_module  # type: ignore[return-value]
+
+    def __call__(self) -> torch.fx.GraphModule:
+        subgraphs = self.put_nodes_into_subgraphs()
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
+        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
+        print(
+            f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs"
+        )
+        self.tag(subgraphs)
+        return self.split()
+
+    def generate_split_results(self) -> SplitResult:
+        split_module = self()
+        submodule_names = []
+        for name, _mod in split_module.named_children():
+            submodule_names.append(name)
+        if (
+            self.settings.max_acc_splits > 0
+            and len(submodule_names) > self.settings.max_acc_splits
+        ):
+            raise ValueError(
+                "Cannot fulfill max_acc_splits limit. "
+                "This may cause split fragmentation and "
+                "result in performance issues."
+            )
+
+        submodule_inputs = generate_inputs_for_submodules(
+            split_module, self.sample_input, submodule_names
+        )
+        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cb9f62fd357573c2bc3b5365e50c8b9366920d7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0bc3928f5e1b8991d85c14738ea3e3cce18a77ee
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..157dc4017eda576f10793ef46b78cd97b0f5074b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py
@@ -0,0 +1,56 @@
+import unittest
+
+from ..pass_manager import (
+    inplace_wrapper,
+    PassManager,
+    these_before_those_pass_constraint,
+    this_before_that_pass_constraint,
+)
+
+
+class TestPassManager(unittest.TestCase):
+    def test_pass_manager_builder(self) -> None:
+        passes = [lambda x: 2 * x for _ in range(10)]
+        pm = PassManager(passes)
+        pm.validate()
+
+    def test_this_before_that_pass_constraint(self) -> None:
+        passes = [lambda x: 2 * x for _ in range(10)]
+        pm = PassManager(passes)
+
+        # add unfulfillable constraint
+        pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))
+
+        self.assertRaises(RuntimeError, pm.validate)
+
+    def test_these_before_those_pass_constraint(self) -> None:
+        passes = [lambda x: 2 * x for _ in range(10)]
+        constraint = these_before_those_pass_constraint(passes[-1], passes[0])
+        pm = PassManager([inplace_wrapper(p) for p in passes])
+
+        # add unfulfillable constraint
+        pm.add_constraint(constraint)
+
+        self.assertRaises(RuntimeError, pm.validate)
+
+    def test_two_pass_managers(self) -> None:
+        """Make sure we can construct the PassManager twice and not share any
+        state between them"""
+
+        passes = [lambda x: 2 * x for _ in range(3)]
+        constraint = these_before_those_pass_constraint(passes[0], passes[1])
+        pm1 = PassManager()
+        for p in passes:
+            pm1.add_pass(p)
+        pm1.add_constraint(constraint)
+        output1 = pm1(1)
+        self.assertEqual(output1, 2**3)
+
+        passes = [lambda x: 3 * x for _ in range(3)]
+        constraint = these_before_those_pass_constraint(passes[0], passes[1])
+        pm2 = PassManager()
+        for p in passes:
+            pm2.add_pass(p)
+        pm2.add_constraint(constraint)
+        output2 = pm2(1)
+        self.assertEqual(output2, 3**3)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tools_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tools_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a8f0df8449749167c4ec3dedaf719d78fad577
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/tools_common.py
@@ -0,0 +1,390 @@
+# mypy: allow-untyped-defs
+import collections
+import heapq
+import operator
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.node import _get_qualified_name
+
+
+__all__ = [
+    "get_acc_ops_name",
+    "get_node_target",
+    "is_node_output_tensor",
+    "FxNetAccFusionsFinder",
+    "legalize_graph",
+    "stable_topological_sort",
+]
+
+Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
+TensorOrTensors = Union[torch.Tensor, Tensors]
+NodeList = list[torch.fx.Node]
+NodeSet = set[torch.fx.Node]
+Names = list[str]
+CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
+
+
+@compatibility(is_backward_compatible=False)
+def get_acc_ops_name(k):
+    if isinstance(k, str):
+        return k
+    elif k.__module__ and "acc_ops" in k.__module__:
+        return f"acc_ops.{k.__name__}"
+    else:
+        module = k.__module__.replace(
+            "torch._ops", "torch.ops"
+        )  # WAR for bug in how torch.ops assigns module
+        return f"{module if module else ''}.{k.__name__}"
+
+
+@compatibility(is_backward_compatible=False)
+def get_node_target(
+    submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
+) -> str:
+    """
+    Given a `node` returns its target typename.
+
+    For "call_method" node, return node.target which is the name of that method being called.
+    This could potential lead to conflict but should be okay because normally it's on a tensor.
+
+    For "call_function" node, return typename of node.target.
+
+    For "call_module" node, return typename of the module that node.target point to.
+
+    If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
+    "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
+    """
+
+    assert node.op in CALLABLE_NODE_OPS, (
+        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
+    )
+
+    if node.op == "call_module":
+        assert isinstance(node.target, str)
+        submod = submodules[node.target]
+        submod_type = getattr(submod, "_base_class_origin", type(submod))
+        return get_acc_ops_name(submod_type)
+    elif node.op == "call_function":
+        target: Any = node.target
+        return (
+            f"acc_ops.{target.__name__}"
+            if target.__module__ is not None and "acc_ops" in target.__module__
+            else _get_qualified_name(target)
+        )
+    else:
+        assert isinstance(node.target, str)
+        return node.target
+
+
+@compatibility(is_backward_compatible=False)
+def is_node_output_tensor(node: torch.fx.Node) -> bool:
+    """Checks if the node output produces a Tensor or not.
+
+    NOTE: This requires to run `ShapeProp` on the containing fx graph before
+    calling this function. This is because it works by checking the `type`
+    metadata on the node. This metadata is produced by the `ShapeProp`.
+    """
+    type_ = node.meta.get("type", None)
+    return type_ is not None and issubclass(type_, torch.Tensor)
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetAccFusionsFinder:
+    """
+    Finds groups of connected ACC nodes that pass non-tensor data between each other.
+    Such groups are called fusion groups.
+    """
+
+    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
+        self.module = module
+        self.nodes = list(module.graph.nodes)
+        self.acc_nodes = acc_nodes
+
+    @dataclass
+    class FusionGroup:
+        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
+        top_node_idx: int
+
+        # Nodes in this fusion group.
+        nodes: NodeSet
+
+        # Inputs to this fusion group.
+        inputs: NodeSet
+
+        # Nodes that in the fusion group that haven't been processed yet.
+        nodes_need_process: NodeSet
+
+        def add_node(self, node):
+            """
+            Add a node to fusion group.
+            """
+            if node in self.nodes:
+                return
+
+            self.nodes_need_process.add(node)
+            self.nodes.add(node)
+            self.inputs.discard(node)
+            self.inputs.update(
+                {
+                    n
+                    for n in node.all_input_nodes
+                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
+                }
+            )
+
+    def recursive_add_node(
+        self,
+        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
+        inputs: Union[NodeSet, NodeList],
+        visited: Optional[NodeSet] = None,
+    ):
+        """
+        Start from inputs and going reverse topological order. If any upstream node
+        is in the fusion group, add all the nodes in this path to fusion group.
+        """
+        for arg in inputs:
+            # skip the node if already seen
+            if visited is not None:
+                if arg in visited:
+                    continue
+                visited.add(arg)
+
+            # Skip placeholder and get_attr because they won't be in the fusion group.
+            if arg.op not in CALLABLE_NODE_OPS:
+                continue
+
+            # If the node has smaller idx, it's already an upstream node of the fusion
+            # group. We don't need to check it anymore.
+            if self.nodes.index(arg) < fusion_group.top_node_idx:
+                continue
+
+            # If the node is in the fusion group, return True.
+            if arg in fusion_group.nodes:
+                return True
+
+            # Check the upstream nodes of the node, if any of them is in the fusion group
+            # we'll add this node to fusion group and return True.
+            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
+                fusion_group.add_node(arg)
+                return True
+
+        return False
+
+    def __call__(self) -> dict[torch.fx.Node, NodeSet]:
+        result: dict[torch.fx.Node, NodeSet] = {}
+        acc_nodes = list(self.acc_nodes)
+
+        for node in acc_nodes:
+            if node in result:
+                continue
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+            if "tensor_meta" in node.meta:
+                continue
+            if node not in self.acc_nodes:
+                continue
+
+            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
+                top_node_idx=self.nodes.index(node),
+                nodes={node},
+                inputs=set(node.all_input_nodes),
+                nodes_need_process={node},
+            )
+            while fusion_group.nodes_need_process:
+                node = fusion_group.nodes_need_process.pop()
+                self.recursive_add_node(
+                    fusion_group,
+                    fusion_group.inputs,
+                    visited=set(),
+                )
+
+                # Optionally add downstream nodes
+                if "tensor_meta" not in node.meta:
+                    for user in node.users:
+                        if user.op not in CALLABLE_NODE_OPS:
+                            continue
+                        if user in fusion_group.nodes:
+                            continue
+
+                        fusion_group.add_node(user)
+                        self.recursive_add_node(
+                            fusion_group,
+                            fusion_group.inputs,
+                            visited=set(),
+                        )
+
+                # Add some upstream nodes
+                for arg in node.all_input_nodes:
+                    if arg.op not in CALLABLE_NODE_OPS:
+                        continue
+                    if "tensor_meta" in arg.meta:
+                        continue
+                    if arg in fusion_group.nodes:
+                        continue
+
+                    fusion_group.add_node(arg)
+                    fusion_group.top_node_idx = min(
+                        fusion_group.top_node_idx, self.nodes.index(arg)
+                    )
+                    self.recursive_add_node(
+                        fusion_group,
+                        fusion_group.inputs,
+                        visited=set(),
+                    )
+
+            if not (set(fusion_group.nodes) <= self.acc_nodes):
+                self.acc_nodes -= fusion_group.nodes
+            else:
+                for n in fusion_group.nodes:
+                    result[n] = fusion_group.nodes
+
+        return result
+
+
+@compatibility(is_backward_compatible=False)
+def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Replace the graph of the given GraphModule with one that contains the same nodes as the
+    original, but in topologically sorted order.
+
+    This is used by the merge_matmul transformation below, which disturbs the topologically sorted
+    order of its input GraphModule, so that this order is restored before further transformation.
+
+    Arguments:
+        gm: The graph module to topologically sort. It is modified in-place.
+
+    Returns:
+        The graph module in-place sorted
+
+    Warning:
+        This topological sort is NOT stable, it will NOT preserve the original node order.
+        If you need a stable topological sort, use stable_topological_sort instead.
+    """
+
+    # These operators are used for making runtime assertions before any
+    # data-dependent operators occur. We want to prioritize sorting these to
+    # ensure that these assertions appear before any data-dependent operations
+    # in the graph.
+    PRIORITIZED_OPS = [
+        operator.add,
+        operator.mul,
+        operator.sub,
+        operator.floordiv,
+        operator.truediv,
+        operator.mod,
+        operator.le,
+        operator.lt,
+        operator.ge,
+        operator.gt,
+        operator.eq,
+        operator.ne,
+        torch.ops.aten.sym_constrain_range.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
+        torch.ops.aten._assert_async.msg,
+        torch.ops.aten.scalar_tensor.default,
+        torch.ops.aten._assert_scalar.default,
+    ]
+
+    indeg = dict.fromkeys(gm.graph.nodes, 0)
+    new_graph = torch.fx.Graph()
+    # Track how many unfulfilled dependencies each node has
+    for node in gm.graph.nodes:
+        for user in node.users:
+            indeg[user] += 1
+    queue: collections.deque = collections.deque()
+    # Add all nodes with no dependencies to the queue
+    for node in gm.graph.nodes:
+        if indeg[node] == 0:
+            queue.append(node)
+    env: dict[torch.fx.Node, torch.fx.Node] = {}
+    # Pop nodes from the queue, and add nodes that have had all their
+    # dependencies fulfilled
+    while len(queue) > 0:
+        cur = queue.popleft()
+        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
+        for user in cur.users:
+            indeg[user] -= 1
+            if indeg[user] == 0:
+                if user.op == "call_function" and user.target in PRIORITIZED_OPS:
+                    queue.appendleft(user)
+                else:
+                    queue.append(user)
+    # If the new graph's size is not as large as the old one, then there must be
+    # a cycle (i.e. some node's dependencies were not satisfied.)
+    if len(new_graph.nodes) < len(gm.graph.nodes):
+        raise RuntimeError(
+            f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
+        )
+    new_graph._codegen = gm.graph._codegen
+    gm.graph = new_graph
+    return gm
+
+
+@compatibility(is_backward_compatible=False)
+def stable_topological_sort(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Replace the graph of the given GraphModule with one that contains the same nodes as the
+    original, but in topologically sorted order while preserving the original node order
+    as much as possible.
+
+    This function performs a stable topological sort where nodes appear in an order that:
+    1. Respects data dependencies (topological ordering)
+    2. Preserves the original node order when there are no dependency constraints
+
+    The algorithm uses Kahn's algorithm with a priority queue: nodes with all dependencies
+    satisfied are added to a min-heap, ordered by their original position. This ensures
+    we always process the earliest node in the original order among ready nodes.
+
+    Arguments:
+        gm: The graph module to topologically sort. It is modified in-place.
+
+    Returns:
+        The graph module in-place sorted
+    """
+    indeg = dict.fromkeys(gm.graph.nodes, 0)
+    new_graph = torch.fx.Graph()
+
+    # Build node to original index mapping
+    node_to_id: dict[torch.fx.Node, int] = {
+        node: idx for idx, node in enumerate(gm.graph.nodes)
+    }
+
+    # Track how many unfulfilled dependencies each node has
+    for node in gm.graph.nodes:
+        for user in node.users:
+            indeg[user] += 1
+
+    # Priority queue: (original_index, node)
+    # Use min-heap to always process the node with smallest original index
+    ready_queue: list[tuple[int, torch.fx.Node]] = []
+    for node in gm.graph.nodes:
+        if indeg[node] == 0:
+            heapq.heappush(ready_queue, (node_to_id[node], node))
+
+    env: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Process nodes
+    while ready_queue:
+        # Pop node with smallest original index
+        _, cur = heapq.heappop(ready_queue)
+        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
+
+        # Update in-degrees and add newly ready nodes
+        for user in cur.users:
+            indeg[user] -= 1
+            if indeg[user] == 0:
+                heapq.heappush(ready_queue, (node_to_id[user], user))
+
+    # Check if all nodes were processed
+    assert len(new_graph.nodes) == len(gm.graph.nodes), (
+        f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
+    )
+
+    new_graph._codegen = gm.graph._codegen
+    gm.graph = new_graph
+    return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee5e7e66868a0776609ff7ffff458f6a91ccf98a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py
@@ -0,0 +1 @@
+from .common import compare_graphs, HolderModule, lift_subgraph_as_module
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f0901c39b11fc61996c4ef0927057118b2e39a0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fd4fa63b62a9c695530ba2eb6d79d05adf229af
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b13e250230fb1ccc1cc6ac97390edeef54ac513f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94b1a0045ea20b50db8223ddc0e1e6416732421b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48a050ce84f7e84eecf55f3d488d31f4298695e0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a5eb942f0bb225982b3a45ef073158711f89637
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c97aa4093571604953f12f8ff4711fb401ca9c5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/common.py
@@ -0,0 +1,95 @@
+# mypy: allow-untyped-defs
+
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import Graph
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+from torch.nn import Module
+
+
+__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
+
+
+@compatibility(is_backward_compatible=False)
+class HolderModule(Module):
+    """
+    HolderModule is used to copy all the attributes from original module to submodules
+    that uses the attributes
+    """
+
+    def __init__(self, d):
+        super().__init__()
+        for k, v in d.items():
+            self.add_module(k, v)
+
+
+@compatibility(is_backward_compatible=False)
+def lift_subgraph_as_module(
+    gm: GraphModule,
+    subgraph: Graph,
+    comp_name: str = "",
+    class_name: str = "GraphModule",
+) -> tuple[GraphModule, dict[str, str]]:
+    """
+    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.
+
+    Args:
+        gm (GraphModule): parent graph module
+
+        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph
+
+        comp_name (str): name for the new component
+
+        class_name (str): name for the submodule
+
+    """
+
+    # Loop through all module calls (call_module) and param fetches (get_attr)
+    # in this component, creating HolderModules as necessary to match the path.
+    # e.g. if in the original module there's a get_attr node fetches "conv.weight".
+    # We create a HolderModule as root -> add a HolderModule named "conv" ->
+    # make "weight" a attribute of "conv" HolderModule and point to conv.weight in
+    # the original module.
+    submodule = HolderModule({})
+    orig_to_split_fqn_mapping: dict[str, str] = {}
+    for n in subgraph.nodes:
+        if n.op not in ("call_module", "get_attr"):
+            continue
+
+        target = n.target
+        assert isinstance(target, str)
+        target_name_parts = target.split(".")
+        curr = submodule
+        orig_gm = gm
+
+        for name in target_name_parts[:-1]:
+            if not hasattr(curr, name):
+                # pyrefly: ignore [missing-attribute]
+                curr.add_module(name, HolderModule({}))
+
+            curr = getattr(curr, name)
+            orig_gm = getattr(orig_gm, name)
+
+        leaf_node_name = target_name_parts[-1]
+        leaf_node = getattr(orig_gm, leaf_node_name)
+
+        orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
+        # Relies on custom __setattr__ magic.
+        setattr(curr, leaf_node_name, leaf_node)
+
+    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
+
+
+@compatibility(is_backward_compatible=False)
+def compare_graphs(left: Graph, right: Graph) -> bool:
+    """
+    Return True if two graphs are identical, i.e they
+        - have the same number of outputs in the same order
+        - have the same number of inputs in the same order
+        - have the same set of nodes, and identical connectivity
+    """
+
+    matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
+    matches = matcher.match(right)
+
+    return len(matches) > 0
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0571c92f61b765732d34f06ba09080dc11a66b60
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py
@@ -0,0 +1,294 @@
+import copy
+from queue import SimpleQueue
+from typing import Optional as _Optional
+
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import Graph
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node
+from torch.fx.passes.tools_common import (  # noqa: F401
+    legalize_graph,
+    NodeList,
+    NodeSet,
+    stable_topological_sort,
+)
+from torch.fx.passes.utils import lift_subgraph_as_module  # type: ignore[attr-defined]
+
+
+@compatibility(is_backward_compatible=False)
+def topo_sort(nodes: NodeList) -> NodeList:
+    # sort nodes according to the topological order
+    indegree_map = dict.fromkeys(nodes, 0)
+    candidates: SimpleQueue[Node] = SimpleQueue()
+
+    for node in nodes:
+        for n in node.all_input_nodes:
+            if n in indegree_map:
+                indegree_map[node] += 1
+        if indegree_map[node] == 0:
+            candidates.put(node)
+
+    sorted_nodes: NodeList = []
+    while not candidates.empty():
+        node = candidates.get()
+        sorted_nodes.append(node)
+
+        for n in node.users:
+            if n in indegree_map:
+                indegree_map[n] -= 1
+                if indegree_map[n] == 0:
+                    candidates.put(n)
+
+    assert len(nodes) == len(sorted_nodes), (
+        "topological sorted nodes doesn't have same length as input nodes"
+    )
+
+    return sorted_nodes
+
+
+@compatibility(is_backward_compatible=False)
+def validate_partition(partition: NodeList) -> bool:
+    # verify the partition doesn't form a dependency cycle in the original graph
+    # returns True for valid partition, False for invalid
+
+    partition_set = set(partition)
+
+    outputs: NodeList = []
+    for node in partition_set:
+        for user_node in node.users:
+            if user_node not in partition_set:
+                # external user node, need to expose as an output
+                outputs.append(user_node)
+
+    # Perform BFS on the partition outputs.
+    # If it reaches a node within the partition, then it found a cycle.
+    # This function takes the ownership of `root_nodes` and may modify it.
+    def bfs_find_cycle(root_nodes: NodeList) -> bool:
+        # Set used to exclude nodes that have already been visited.
+        # If a node has been visited, that node and all its children have
+        # been checked for cycles.
+        visited: NodeSet = set()
+
+        # Start with `root_nodes` and traverse through (toward child nodes)
+        # their connected sub-graph. Nodes in `visited` won't be added
+        # to `queue` again.
+        queue: NodeList = root_nodes
+        while queue:
+            current = queue.pop()
+            visited.add(current)
+            if current in partition_set:
+                # Started from partition's `output` nodes, and reached
+                # another node in partition. Cycle!
+                return True
+            for user_node in current.users:
+                if user_node in visited:
+                    continue
+                queue.append(user_node)
+        # `root_nodes` don't cause cycle.
+        return False
+
+    # Use all output nodes as roots to traverse
+    # the graph to check cycles.
+    if bfs_find_cycle(outputs):
+        return False
+
+    return True
+
+
+@compatibility(is_backward_compatible=False)
+def fuse_as_graphmodule(
+    gm: GraphModule,
+    nodes: NodeList,
+    module_name: str,
+    partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None,
+    *,
+    always_return_tuple: bool = False,
+) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
+    """
+    Fuse nodes in graph_module into a GraphModule.
+
+    Args:
+        gm (GraphModule): target graph_module
+
+        nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted
+
+        module_name: class name for the fused GraphModule
+
+        partition_lookup_table (Optional[Dict[Node, None]]): optional dict of nodes to speed up lookup
+
+        always_return_tuple (bool): whether to always return a tuple, even if there is only one output
+
+    Returns:
+        fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm`
+
+        original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm`
+
+        original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm`
+
+    """
+
+    # assumption: nodes are already sorted in topo order
+
+    for node in nodes:
+        assert node.graph.owning_module is gm, (
+            f"{node} doesn't belong to passed in graph module {gm._get_name()}"
+        )
+        assert not node._erased, f"{node} has been removed from owning graph"
+        assert node in gm.graph._find_nodes_lookup_table, (
+            f"{node} is not found in graph module {gm._get_name()}"
+        )
+
+    # validates partition doesn't introduce dependency circles in the graph
+    assert validate_partition(nodes), "Invalid partition, found dependency cycles"
+
+    # if no dict of partition nodes is provided, reconstruct it by nodes list to reduce lookup time
+    if partition_lookup_table is None:
+        partition_lookup_table = dict.fromkeys(nodes)
+
+    subgraph = Graph()
+
+    node_to_placeholder: dict[
+        Node, Node
+    ] = {}  # mapping of nodes from old graph to placeholder in new graph
+    node_map: dict[Node, Node] = {}  # mapping of nodes from old graph to new graph
+
+    # handles inputs through graph.node_copy's arg_transform functions
+    def remap_inputs(x: Node) -> Node:
+        if x.op == "get_attr":
+            # TODO: do we really need copy the get_attr node into the graph?
+            # do something here
+            pass
+
+        if x in partition_lookup_table:
+            # x is inside subgraph, return the copied node
+            # the node should have been copied already, as we are copying graph in the topological order
+            return node_map[x]
+
+        if x not in node_to_placeholder:
+            # x is not in subgraph, create a new placeholder for subgraph
+            placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
+            # copy all meta fields, even if some fields might be irrelevant for the placeholder node
+            placeholder_node.meta = copy.copy(x.meta)
+            node_to_placeholder[x] = placeholder_node
+
+        return node_to_placeholder[x]
+
+    # copy nodes in topological order
+    for node in nodes:
+        new_node = subgraph.node_copy(node, remap_inputs)
+        node_map[node] = new_node
+
+    # handles outputs
+    output_mapping: dict[Node, Node] = {}  # mapping from old output to new outputs
+
+    for node in nodes:
+        for user_node in node.users:
+            if user_node not in partition_lookup_table:
+                # external user node, need to expose as an output
+                output_mapping[node] = node_map[node]
+
+    # outs contain nodes in the new subgraph
+    outs = tuple(output_mapping.values())
+
+    if always_return_tuple:
+        # always return a tuple, even if there is only one output
+        subgraph.output(outs)
+    else:
+        # If there's a single output then return it directly, otherwise return a tuple.
+        subgraph.output(outs[0] if len(outs) == 1 else outs)
+
+    # lint to ensure correctness
+    subgraph.lint()  # type: ignore[no-untyped-call]
+    fused_gm: GraphModule
+    fused_gm, _ = lift_subgraph_as_module(
+        gm, subgraph, comp_name="", class_name=module_name
+    )
+
+    # sub_gm's input nodes in the original module
+    original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys())
+
+    # sub_gm's outputs node in the original module
+    original_outputs: tuple[Node, ...] = tuple(output_mapping.keys())
+
+    return fused_gm, original_inputs, original_outputs
+
+
+@compatibility(is_backward_compatible=False)
+def insert_subgm(
+    gm: GraphModule,
+    sub_gm: GraphModule,
+    orig_inputs: tuple[Node, ...],
+    orig_outputs: tuple[Node, ...],
+) -> GraphModule:
+    # add sub_gm into gm
+    submodule_name = sub_gm.__class__.__name__
+    gm.add_submodule(submodule_name, sub_gm)
+
+    def last_node(target_nodes: tuple[Node, ...]) -> Node | None:
+        for node in reversed(gm.graph.nodes):
+            if node in target_nodes:
+                return node
+        return None
+
+    last_output_node: Node | None = last_node(orig_outputs)
+    assert last_output_node is not None
+
+    # Create a call_module node in main graph.
+    with gm.graph.inserting_after(last_output_node):
+        module_node = gm.graph.call_module(
+            submodule_name, args=orig_inputs, kwargs=None
+        )
+        output_node = sub_gm.graph.output_node()
+
+    next_node = module_node.next
+    with gm.graph.inserting_before(next_node):
+        if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
+            # main_remapping[comp.orig_outputs[0]] = module_node
+            orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
+        else:
+            for i, orig_output in enumerate(orig_outputs):
+                # Use Proxy to record getitem access.
+                proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
+                orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
+
+            module_node.meta["val"] = tuple(
+                orig_output.meta.get("val", None) for orig_output in orig_outputs
+            )
+    return gm
+
+
+@compatibility(is_backward_compatible=False)
+def erase_nodes(gm: GraphModule, nodes: NodeList) -> None:
+    # erase original nodes in inversed topological order
+    for node in reversed(nodes):
+        gm.graph.erase_node(node)
+
+
+@compatibility(is_backward_compatible=False)
+def fuse_by_partitions(
+    gm: GraphModule,
+    partitions: list[dict[Node, _Optional[int]]],
+    prefix: str = "fused_",
+    always_return_tuple: bool = False,
+) -> GraphModule:
+    for partition_id, partition in enumerate(partitions):
+        sorted_nodes = topo_sort(list(partition))
+
+        submodule_name = prefix + str(partition_id)
+        sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
+            gm,
+            sorted_nodes,
+            submodule_name,
+            partition,
+            always_return_tuple=always_return_tuple,
+        )
+
+        insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
+
+        erase_nodes(gm, sorted_nodes)
+
+    stable_topological_sort(gm)
+    gm.graph.lint()
+
+    return gm
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f253cb292860de6ec8d3f8418d4e9d5033ca9c5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py
@@ -0,0 +1,447 @@
+# mypy: allow-untyped-defs
+import copy
+import logging
+import os
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Any, Union
+
+import torch
+from torch.fx import Graph, Node
+from torch.fx._compatibility import compatibility
+
+
+__all__ = ["SubgraphMatcher", "InternalMatch"]
+
+
+# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
+def _init_logger():
+    logger = logging.getLogger(__name__)
+
+    level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper()
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter("%(filename)s > %(message)s")
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+
+logger = _init_logger()
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class InternalMatch:
+    # Nodes from which the match was found
+    anchors: list[Node]
+    # Maps nodes in the pattern subgraph to nodes in the larger graph
+    nodes_map: dict[Node, Node] = field(default_factory=dict)
+
+    # nodes in target graph that are matched placeholder in pattern
+    placeholder_nodes: list[Node] = field(default_factory=list)
+
+    # nodes in matched subgraph returned by output
+    returning_nodes: list[Node] = field(default_factory=list)
+
+    # map from a string name to a node in the target graph
+    # only available if the matcher is `SubgraphMatcherWithNameNodesMap`
+    name_node_map: dict[str, Node] = field(default_factory=dict)
+
+    def __copy__(self):
+        return InternalMatch(
+            anchors=self.anchors,
+            nodes_map=self.nodes_map.copy(),
+            placeholder_nodes=self.placeholder_nodes.copy(),
+            returning_nodes=self.returning_nodes.copy(),
+        )
+
+
+@compatibility(is_backward_compatible=False)
+class SubgraphMatcher:
+    def __init__(
+        self,
+        pattern: Graph,
+        match_output: bool = False,
+        match_placeholder: bool = False,
+        remove_overlapping_matches: bool = True,
+        ignore_literals: bool = False,
+    ) -> None:
+        """
+        Args:
+            pattern: the targeted matching pattern, represented in fx.Graph.
+            match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
+                If False, output node is ignored during match.
+            match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
+                the targeted pattern. If False, placeholder nodes will be used a wildcard.
+            remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
+                will be returned.
+            ignore_literals: If True, will not check if literals are equal and
+                will instead treat them as wildcards.
+        """
+
+        self.pattern = pattern
+        self.match_output = match_output
+        self.match_placeholder = match_placeholder
+        self.remove_overlapping_matches = remove_overlapping_matches
+        self.ignore_literals = ignore_literals
+
+        if len(pattern.nodes) == 0:
+            raise ValueError(
+                "SubgraphMatcher cannot be initialized with an empty pattern"
+            )
+
+        for node in pattern.nodes:
+            if node.op != "output" and not node.is_impure():
+                assert len(node.users) > 0, (
+                    "SubgraphMatcher cannot be initialized with an pattern with dead code"
+                )
+
+        # TODO: assert pattern is a connected graph
+
+        self.pattern_placeholder_nodes = [
+            n for n in pattern.nodes if n.op == "placeholder"
+        ]
+        output_node = next(iter(reversed(pattern.nodes)))
+        # nodes returned by outputs
+        self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes
+
+        self.pattern_anchors: list[Node] = []
+        if match_output:
+            self.pattern_anchors = [output_node]
+        else:
+            # If a node has output_node as the ONLY user, then this node is a graph sink,
+            # and should be matched against as an anchor
+            self.pattern_anchors = [
+                n for n in output_node.all_input_nodes if len(n.users) == 1
+            ]
+
+    def _match_attributes(self, pn: Node, gn: Node) -> bool:
+        # Attributes matching is complicated. Right now we only support matching constant tensor
+        assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
+        assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."
+
+        pn_value = torch.fx.graph_module._get_attr(pn.graph.owning_module, pn.target)
+        gn_value = torch.fx.graph_module._get_attr(gn.graph.owning_module, gn.target)
+
+        if type(pn_value) is not type(gn_value):
+            return False
+
+        # Don't require exact match on tensor values.
+        if isinstance(pn_value, torch.Tensor):
+            return isinstance(gn_value, torch.Tensor)
+        else:
+            raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")
+        return False
+
+    def _nodes_are_equal(self, pn: Node, gn: Node, node_name_match: str = "") -> bool:
+        # if exact match for placeholder is not required, then use placeholder as a wildcard
+        if not self.match_placeholder and pn.op == "placeholder":
+            return True
+
+        if node_name_match and node_name_match in gn.name:
+            return True
+
+        if pn.op == gn.op:
+            if pn.op == "placeholder" or pn.op == "output":
+                return True
+            elif pn.op == "get_attr":
+                return self._match_attributes(pn, gn)
+            return pn.target == gn.target
+        return False
+
+    def _is_contained(self, nodes_map: dict[Node, Node]) -> bool:
+        # `lookup` represents all the nodes in `original_graph`
+        # that are part of `pattern`
+
+        # Placeholders can be used by other nodes in the graphs
+        lookup: dict[Node, Node] = {
+            gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder"
+        }
+
+        for gn, pn in lookup.items():
+            # nodes returned by output are allowed to be used in other areas of the graph
+            if pn in self.pattern_returning_nodes:
+                continue
+
+            for user in gn.users:
+                # If this node has users that were not in `lookup`, then it must leak out of the
+                # pattern subgraph
+                if user not in lookup:
+                    return False
+        return True
+
+    def _remove_overlapping_matches(
+        self, matches: list[InternalMatch]
+    ) -> list[InternalMatch]:
+        non_overlapping_matches: list[InternalMatch] = []
+        nodes_matched: set[Node] = set()
+
+        for match in matches:
+            found_overlap = False
+            for pn, gn in match.nodes_map.items():
+                if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
+                    found_overlap = True
+                    break
+
+            if not found_overlap:
+                non_overlapping_matches.append(match)
+                for pn, gn in match.nodes_map.items():
+                    if pn.op not in {"placeholder", "output"}:
+                        nodes_matched.add(gn)
+        return non_overlapping_matches
+
+    def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
+        assert not (isinstance(pn, Node) and isinstance(gn, Node)), (
+            "pn and gn cannot both be Node"
+        )
+
+        if isinstance(pn, Node) and not isinstance(gn, Node):
+            if pn.op == "placeholder":
+                # Check if we've already matched these nodes in the current
+                # traversal
+                if pn in match.nodes_map:
+                    return match.nodes_map[pn] == gn
+
+                match.nodes_map[pn] = gn
+                return True
+            else:
+                return False
+        elif not isinstance(pn, Node) and isinstance(gn, Node):
+            return False
+        else:
+            return type(gn) is type(pn) and gn == pn
+
+    def _match_nodes(
+        self, pn: Node, gn: Node, match: InternalMatch, node_name_match: str = ""
+    ) -> bool:
+        logger.info("  matching %s to %s", pn, gn)
+
+        assert isinstance(pn, Node) and isinstance(gn, Node), str(
+            f"pn and gn must be Node, pn: {pn}, gn: {gn}"
+        )
+
+        # Check if we've already matched these nodes in the current
+        # traversal
+        if pn in match.nodes_map:
+            return match.nodes_map[pn] == gn
+
+        # TODO: use a more efficient way to check if gn is matched before: two-way dict
+        if gn in match.nodes_map.values():
+            return False
+
+        if not self._nodes_are_equal(pn, gn, node_name_match):
+            return False
+
+        # Optimistically mark `pn` as a match for `gn`, and save a local copy of match
+        saved_match = copy.copy(match)
+        match.nodes_map[pn] = gn
+
+        # Placeholder is a wildcard and can be matched with any python object
+        # (including list/tuple)
+        if pn.op == "placeholder":
+            return True
+
+        # Recursively traverse upwards to check if `pn` is a true
+        # match for `gn`
+        match_found = True
+
+        def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool:
+            if len(args1) != len(args2):
+                return False
+
+            for a1, a2 in zip(args1, args2):
+                if isinstance(a1, Node) and isinstance(a2, Node):
+                    matched = self._match_nodes(a1, a2, match)
+                elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
+                    matched = _match_args(a1, a2)
+                else:
+                    matched = (
+                        self._match_literals(a1, a2, match) or self.ignore_literals
+                    )
+
+                if not matched:
+                    return False
+
+            return True
+
+        # Flatten all args/kwargs into 1 list of args
+        pn_args, gn_args = None, None
+        if (
+            (
+                len(pn.args) != len(gn.args)
+                or list(pn.kwargs.keys()) != list(gn.kwargs.keys())
+            )
+            and pn.op == "call_function"
+            and isinstance(pn.target, torch._ops.OpOverload)
+        ):
+            args_schema = pn.target._schema.arguments
+
+            def get_all_arguments(orig_args, orig_kwargs):
+                all_args = []
+                for i, schema in enumerate(args_schema):
+                    if schema.name in orig_kwargs:
+                        all_args.append(orig_kwargs[schema.name])
+                    elif not schema.kwarg_only and i < len(orig_args):
+                        all_args.append(orig_args[i])
+                    else:
+                        all_args.append(schema.default_value)
+                return all_args
+
+            pn_args = get_all_arguments(pn.args, pn.kwargs)
+            gn_args = get_all_arguments(gn.args, gn.kwargs)
+
+        elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(
+            gn.kwargs.keys()
+        ):
+            pn_args = list(pn.args)
+            gn_args = list(gn.args)
+            pn_args.extend(list(pn.kwargs.values()))
+            gn_args.extend(list(gn.kwargs.values()))
+        else:
+            match_found = False
+
+        match_found = (
+            match_found
+            and pn_args is not None
+            and gn_args is not None
+            and _match_args(pn_args, gn_args)
+        )
+
+        if not match_found:
+            # revert to saved_match before matching with current node
+            match = copy.copy(saved_match)
+            return False
+
+        return True
+
+    def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]:
+        """
+        Returns:
+            The matched subgraphs.
+            The returned subgraph would be fully self-contained, meaning the nodes (except placeholder
+            and nodes returned by output) can only be consumed by nodes within the matched subgraph.
+
+        Subgraph pattern matcher is implemented with the backtracking style in the following steps:
+
+        1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
+        are the "sinks" (nodes with no user other than the output node) of the pattern graph.
+        One pattern graph could have multiple anchors if it has multiple return values.
+
+        2. In the target graph, we identify the potential candidate nodes that can be matched
+        with each anchor. These anchor-candidate pairs are the starting points for
+        pairwise per-node matching.
+
+        3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
+        pattern and target graphs. For every pattern nodes along traversal path, we compare it
+        against the target nodes. In case any comparison failed, the match for this anchor-candidate
+        pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes`
+        for more details.
+
+        4. In the case of multiple anchors, every anchor will need to find a match using step 3.
+        In addition, the matches found between anchors need to have a common intersection node
+        in order for the match to be valid. This is implemented with backtracking. See `backtracking`
+        for more details.
+
+        Notice: graph traversal must be done in the reverser order because a tensor can have multiple
+        consumers, but can only have a single producer. Only with reverser order, we can we jointly
+        traverse the pattern and target graph in a deterministic path.
+
+        Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However,
+        in practice, it's unlikely to blow up.
+
+        """
+        from torch.fx.passes.utils.fuser_utils import validate_partition
+
+        # find candidate nodes to match with pattern anchors
+        match_candidates: dict[Node, list[Node]] = defaultdict(list)
+        for pattern_anchor in self.pattern_anchors:
+            for node in graph.nodes:
+                if self._nodes_are_equal(pattern_anchor, node, node_name_match):
+                    match_candidates[pattern_anchor].append(node)
+        match_candidates_list = list(match_candidates.items())
+
+        logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
+
+        matches: list[InternalMatch] = []
+
+        def backtracking(anchor_index, match):
+            if anchor_index == len(match_candidates_list):
+                match.placeholder_nodes = [
+                    match.nodes_map[pn] for pn in self.pattern_placeholder_nodes
+                ]
+                match.returning_nodes = [
+                    match.nodes_map[pn] for pn in self.pattern_returning_nodes
+                ]
+                matches.append(match)
+
+                logger.info("Found a match: %s\n", match)
+                return
+
+            pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
+            saved_match = copy.copy(match)
+
+            for node in candidate_nodes:
+                logger.info("Trying to match anchor %s to %s", pattern_anchor, node)
+
+                match_found = self._match_nodes(
+                    pattern_anchor, node, match, node_name_match
+                )
+                if match_found:
+                    # match next anchor
+                    backtracking(anchor_index + 1, match)
+                else:
+                    logger.info(
+                        "Failed to match anchor %s to %s\n", pattern_anchor, node
+                    )
+
+                # revert to saved_match before matching with current anchor
+                match = copy.copy(saved_match)
+
+        match = InternalMatch(anchors=self.pattern_anchors)
+        if match_candidates_list:
+            backtracking(0, match)
+
+        # filter out the matches where the subgraph is not fully_contained
+        before = len(matches)
+        matches = [match for match in matches if self._is_contained(match.nodes_map)]
+        after = len(matches)
+        if before != after:
+            logger.info(
+                "Filtered out %s matches because they are not fully contained",
+                before - after,
+            )
+
+        # filter out the matches that form a cycle if the subgraph is fused
+        valid_matches = []
+        for match in matches:
+            matched_compute_nodes = [
+                gn
+                for pn, gn in match.nodes_map.items()
+                if pn.op not in {"placeholder", "output"}
+            ]
+            if validate_partition(matched_compute_nodes):
+                valid_matches.append(match)
+        if len(valid_matches) != len(matches):
+            logger.info(
+                "Filtered out %s matches because \
+                          matched subgraph would form a cycle if fused",
+                len(matches) - len(valid_matches),
+            )
+
+        if self.remove_overlapping_matches:
+            before = len(valid_matches)
+            matches = self._remove_overlapping_matches(valid_matches)
+            after = len(matches)
+            if before != after:
+                logger.info(
+                    "Filtered out %s matches because matched subgraphs are overlapping",
+                    before - after,
+                )
+
+        logger.info("Matches returned: %s", matches)
+
+        return matches
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3114d55b635fcb5d02b8e57faade2474ec021e7f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py
@@ -0,0 +1,114 @@
+from torch.fx import Graph, GraphModule, Node
+from torch.fx._compatibility import compatibility
+
+from .matcher_utils import InternalMatch, SubgraphMatcher
+
+
+__all__ = ["SubgraphMatcherWithNameNodeMap"]
+
+
+def _split_to_graph_and_name_node_map(
+    gm: GraphModule,
+) -> tuple[GraphModule, dict[str, Node]]:
+    from torch.fx.graph import _PyTreeInfo
+    from torch.utils._pytree import tree_flatten, tree_unflatten
+
+    name_node_map = {}
+    for n in gm.graph.nodes:
+        if n.op == "output":
+            assert gm._out_spec is not None
+            output = tree_unflatten(n.args[0], gm._out_spec)
+            assert isinstance(output, tuple), (
+                "Expecting the pattern graph to return a tuple"
+            )
+            assert len(output) >= 2, (
+                "Expecting the pattern graph to have at least two outputs"
+            )
+            *out, name_node_map = output
+            flattened, out_spec = tree_flatten(out)
+            assert isinstance(name_node_map, dict), (
+                "Expecting the input graph to have a dict output as the last element"
+            )
+            n.args = (flattened,)
+            orig_pytree_info = gm._graph._codegen.pytree_info  # type: ignore[attr-defined]
+            gm._graph._codegen.pytree_info = _PyTreeInfo(  # type: ignore[attr-defined]
+                orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
+            )
+    gm.recompile()
+    return gm, name_node_map
+
+
+@compatibility(is_backward_compatible=False)
+class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
+    """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name,
+    this requires pattern to have specific format (returning and additional dictionary at the output,
+    that has node name as key, and the node in the pattern graph as value, see Example for more details)
+
+    Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during
+    initialization since we need to modify the graph (which requires `recompile` the GraphModule)
+
+    Example::
+        def pattern(x, weight):
+            conv = F.conv2d(x, weight)
+            relu = F.relu(conv)
+            return relu, {"conv": conv, "relu": relu}
+
+
+        def target_graph(x, weight):
+            conv = F.conv2d(x, weight)
+            relu = F.relu(conv)
+            relu *= 2
+            return relu
+
+
+        pattern_gm = export_for_training(pattern, example_inputs).module()
+        target_gm = export_for_training(target_graph, example_inputs).module()
+        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
+        matches = matcher.match(target_gm)
+        for match in matches:
+            match.name_node_map["conv"].meta["annotation"] = ...
+
+    """
+
+    def __init__(
+        self,
+        pattern_gm: GraphModule,
+        match_output: bool = False,
+        match_placeholder: bool = False,
+        remove_overlapping_matches: bool = True,
+        ignore_literals: bool = False,
+    ) -> None:
+        pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
+        self.name_node_map = name_node_map
+        super().__init__(
+            pattern_gm.graph,
+            match_output,
+            match_placeholder,
+            remove_overlapping_matches,
+            ignore_literals,
+        )
+
+    def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]:
+        """The returned InternalMatch will have name_node_map populated with a map
+        from node name (str) to the target node, e.g.
+        {"conv": target_conv_ndoe, "relu": target_relu_node}
+
+        this requires the pattern graph returns an additional
+        output of node name to node, e.g. instead of:
+        ```
+        def pattern(...):
+            ...
+            return relu
+        ```
+        we should do:
+        ```
+        def pattern(...):
+            ...
+            return relu, {"conv": conv, "relu": relu}
+        ``` instead
+        """
+        internal_matches = super().match(graph, node_name_match)
+        for internal_match in internal_matches:
+            for k, n in self.name_node_map.items():
+                internal_match.name_node_map[k] = internal_match.nodes_map[n]
+        return internal_matches
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..82259b8a36ab78ce67ab14411ca4522cc33cd83c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py
@@ -0,0 +1,163 @@
+import logging
+import os
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+
+
+__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"]
+
+
+# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
+def _init_logger() -> logging.Logger:
+    logger = logging.getLogger(__name__)
+
+    level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper()
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter("%(filename)s > %(message)s")
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+
+logger = _init_logger()
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class SourcePartition:
+    # Nodes in a particular partition
+    nodes: list[Node]
+
+    # The source these nodes decomposed from
+    source: Any
+
+    # Nodes in the graph that are needed as inputs to the partition
+    # These do not include the params of the partition
+    input_nodes: list[Node] = field(default_factory=list)
+
+    # Nodes in the partition that are being used by nodes outside of the
+    # partition
+    output_nodes: list[Node] = field(default_factory=list)
+
+    # Parameters that are being used
+    params: list[Node] = field(default_factory=list)
+
+
+@compatibility(is_backward_compatible=False)  # type: ignore[misc]
+def get_source_partitions(
+    graph: Graph,
+    wanted_sources: list[Any],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> dict[Any, list[SourcePartition]]:
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_sources: List of sources of nodes that were decomposed from this
+            source. This can be a function (ex. torch.nn.functional.linear) or a
+            leaf module type (ex. torch.nn.Linear).
+
+    Returns:
+        Dictionary mapping sources that were given to a list of SourcePartitions
+        that correspond to the list of nodes that were decomposed from the given
+        source.
+    """
+    modules: dict[type, dict[str, list[Node]]] = {}
+
+    for node in graph.nodes:
+        # The metadata source_fn should contain a tuple of a unique name for the
+        # source, and the source function if the node is decomposed from a
+        # function, or the type of module if the node is decomposed from a leaf
+        # module
+
+        # TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can
+        # be different from "source_fn_stack", for example for the add_ node
+        # decomposed from batch norm. We should remove the check on "source_fn_stack"
+        # after we fix "torch_fn". T199561090
+        if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
+            torch_fn := node.meta.get("torch_fn", None)
+        ) is not None:
+            node_fqn, source_fn = torch_fn
+            source_fn_name = source_fn.split(".")[1]
+            if source_fn_name in wanted_sources:
+                diff_modules = modules.setdefault(source_fn_name, {})
+                partition = diff_modules.setdefault(node_fqn, [])
+                partition.append(node)
+
+        if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
+            source_fn = source_fn_st[-1]
+            if source_fn[1] in wanted_sources:
+                diff_modules = modules.setdefault(source_fn[1], {})
+                partition = diff_modules.setdefault(source_fn[0], [])
+                partition.append(node)
+
+    def make_partition(nodes: list[Node], module_type: type) -> SourcePartition:
+        input_nodes = set()
+        output_nodes = set()
+        params = set()
+        for node in nodes:
+            for arg in node.args:
+                if isinstance(arg, Node) and arg not in nodes and arg.op != "get_attr":
+                    input_nodes.add(arg)
+
+            if node.op == "get_attr":
+                params.add(node)
+                # get_attr nodes won't be output nodes
+                continue
+
+            for user in node.users:
+                if user not in nodes:
+                    output_nodes.add(node)
+
+        return SourcePartition(
+            nodes,
+            module_type,
+            list(input_nodes),
+            list(output_nodes),
+            list(params),  # type: ignore[arg-type]
+        )
+
+    ret: dict[type[Any], list[SourcePartition]] = {}
+
+    if filter_fn:
+        # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the
+        # filter condition
+        filtered_modules = {}
+        for tp, name_to_partition in modules.items():
+            filtered_name_to_partition = {
+                name: partition
+                for name, partition in name_to_partition.items()
+                if all(map(filter_fn, partition))
+            }
+            filtered_modules[tp] = filtered_name_to_partition
+        modules = filtered_modules
+
+    for k, v in modules.items():
+        ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+    return ret
+
+
+@compatibility(is_backward_compatible=False)  # type: ignore[misc]
+def check_subgraphs_connected(
+    subgraph1: SourcePartition, subgraph2: SourcePartition
+) -> bool:
+    """
+    Given two subgraphs A and B (in the form of a list of nodes), checks if
+    A has nodes connecting to at least one node in B -- aka there exists a node
+    in B that uses a node in A (not the other way around).
+    """
+
+    for node in reversed(subgraph1.nodes):
+        for user in node.users:
+            if user in subgraph2.nodes:
+                return True
+    return False
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e796d0226becb6dbb44f84ca9b063a615f0c239b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fed008e7c9017efb1798aabc66b2d8c097dad44f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc879a639f9ecc22036eccd08e46a9638c57137a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42da443209383d4722509cd4645d3d50d76f89e9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0381aec51e23a1bba015663926cf67baa40f915
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7f0c1c1a5edeabd3bdb749d4b4125e82d4a8e57
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a82c0233c0e116e960ce2ce8d655cabbfaea99f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_decompositions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_decompositions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e869ffe76badaf586654bd16ae956e28a5df8ccf
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_decompositions.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2aa72d465b7daa8329b59111e328090c1d2fd2e7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8acb355e5e3332a8d1720d9b8ea0145e95ea877e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48c6d30b56f13c431bea1dbe3d194c5b6f4cfd0c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_logging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_logging.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c7c1079c37a84ef68916af335bf6291f38f6005
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_logging.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..284855c1d0590fd94ad61866ae7586f352f19b30
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_pickle.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_pickle.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4da0cc3a7e6b2d00c2fec66490749135d8d72c22
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_pickle.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63690c4a685ccf22c31ceb48d8df61d87001f53f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1d12718e5854665af46f1a3d0ab2ebcafd646b5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87a55b422ef3e1d9d629d100483b26ce5f26b256
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_shape_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_shape_functions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a53d24384eb83878c91e75959fa38e7a5a9c3551
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_shape_functions.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e771177590e30f87548139e04cab6b354bc4bae
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08bd00db3351fa1a5edd3174fcbc3979cf3d5129
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1afd3929a36b9854f4387d6276d41e1ff3c8e33
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cef0f3692236f3392c0619873b5fb4317039b63f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1df940e753cd16bb4aeb18501f467f63d212689
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/generate_bytecode.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/quantized.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/quantized.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fa7b72378fa05ab3c65f81914ffc6ed6273a63b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/quantized.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/supported_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/supported_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21a477d122afedd558c401cfc703574ae6d0ceec
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/supported_ops.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed9abc59c0e0621e209253f0f0f477ce3637d7db
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/__pycache__/unsupported_tensor_ops.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f67e3dd422461027c5e0dbef54af1d30f17fa63
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__pycache__/_property_propagation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__pycache__/_property_propagation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6e19576b32e46097cb8945f3de0e39ecd38c2cd
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/__pycache__/_property_propagation.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c410b8fbb7fd329442aa867c0e39c03cd4f15199
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py
@@ -0,0 +1,46 @@
+"""
+Tools to help with tensor property propagation.
+
+This is not intended to be imported directly; please use the exposed
+functionalities in `torch.jit`.
+"""
+
+from typing import Any
+
+import torch
+from torch import TensorType
+from torch._C import Graph
+
+
+def apply_input_props_using_example(graph: Graph, example_input: list[Any]) -> None:
+    """
+    Applies properties for each tensor in the graph inputs
+    using the example supplied.
+    """
+    graph_inputs = list(graph.inputs())
+    if len(graph_inputs) == 0:
+        return
+
+    # Strip self args off for methods
+    in_0 = graph_inputs[0]
+    if isinstance(in_0.type(), torch._C.ClassType) and in_0.debugName() == "self":
+        graph_inputs = graph_inputs[1:]
+
+    if not len(graph_inputs) == len(example_input):
+        raise RuntimeError(
+            "Number of inputs in graph does not match number of inputs in the example"
+        )
+
+    for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)):
+        if example_i is None:
+            continue  # Skip the type check
+
+        if isinstance(example_i, torch.Tensor) != isinstance(
+            graph_i.type(), TensorType
+        ):
+            raise RuntimeError(
+                f"Input {i} does not match type of example", graph_i, example_i
+            )
+
+        if isinstance(example_i, torch.Tensor):
+            graph_i.setType(TensorType.create_from_tensor(example_i))  # type: ignore[arg-type]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/mobile/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/mobile/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..608d1c2f7798d84498907c032a2a4acc6f65f7ef
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/mobile/__init__.py
@@ -0,0 +1,244 @@
+# mypy: allow-untyped-defs
+import os
+
+import torch
+from torch.jit._serialization import validate_map_location
+
+
+def _load_for_lite_interpreter(f, map_location=None):
+    r"""
+    Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.
+
+    Args:
+        f: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+        map_location: a string or torch.device used to dynamically remap
+            storages to an alternative set of devices.
+
+    Returns:
+        A :class:`LiteScriptModule` object.
+
+    Example:
+
+    .. testcode::
+
+        import torch
+        import io
+
+        # Load LiteScriptModule from saved file path
+        torch.jit._load_for_lite_interpreter('lite_script_module.pt')
+
+        # Load LiteScriptModule from io.BytesIO object
+        with open('lite_script_module.pt', 'rb') as f:
+            buffer = io.BytesIO(f.read())
+
+        # Load all tensors to the original device
+        torch.jit.mobile._load_for_lite_interpreter(buffer)
+    """
+    if isinstance(f, (str, os.PathLike)):
+        if not os.path.exists(f):
+            raise ValueError(f"The provided filename {f} does not exist")
+        if os.path.isdir(f):
+            raise ValueError(f"The provided filename {f} is a directory")
+
+    map_location = validate_map_location(map_location)
+
+    if isinstance(f, (str, os.PathLike)):
+        cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
+    else:
+        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
+            # pyrefly: ignore [missing-attribute]
+            f.read(),
+            map_location,
+        )
+
+    return LiteScriptModule(cpp_module)
+
+
+class LiteScriptModule:
+    def __init__(self, cpp_module) -> None:
+        self._c = cpp_module
+        super().__init__()
+
+    def __call__(self, *input):
+        return self._c.forward(input)
+
+    def find_method(self, method_name):
+        return self._c.find_method(method_name)
+
+    def forward(self, *input):
+        return self._c.forward(input)
+
+    def run_method(self, method_name, *input):
+        return self._c.run_method(method_name, input)
+
+
+def _export_operator_list(module: LiteScriptModule):
+    r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module."""
+    return torch._C._export_operator_list(module._c)
+
+
+def _get_model_bytecode_version(f_input) -> int:
+    r"""Take a file-like object to return an integer.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        version: An integer. If the integer is -1, the version is invalid. A warning
+            will show in the log.
+
+    Example:
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_bytecode_version
+
+        # Get bytecode version from a saved file path
+        version = _get_model_bytecode_version("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_model_bytecode_version(os.fspath(f_input))
+    else:
+        # pyrefly: ignore [missing-attribute]
+        return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
+
+
+def _get_mobile_model_contained_types(f_input) -> int:
+    r"""Take a file-like object and return a set of string, like ("int", "Optional").
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        type_list: A set of string, like ("int", "Optional"). These are types used in bytecode.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_mobile_model_contained_types
+
+        # Get type list from a saved file path
+        type_list = _get_mobile_model_contained_types("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
+    else:
+        # pyrefly: ignore [missing-attribute]
+        return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
+
+
+def _backport_for_mobile(f_input, f_output, to_version):
+    r"""Take a input string containing a file name (file-like object) and a new destination to return a boolean.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+        f_output: path to new model destination
+        to_version: the expected output model bytecode version
+    Returns:
+        success: A boolean. If backport success, return true, otherwise false
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if (isinstance(f_input, (str, os.PathLike))) and (
+        isinstance(f_output, (str, os.PathLike))
+    ):
+        return torch._C._backport_for_mobile(
+            os.fspath(f_input),
+            os.fspath(f_output),
+            to_version,
+        )
+    else:
+        return torch._C._backport_for_mobile_from_buffer(
+            # pyrefly: ignore [missing-attribute]
+            f_input.read(),
+            str(f_output),
+            to_version,
+        )
+
+
+def _backport_for_mobile_to_buffer(f_input, to_version):
+    r"""Take a string containing a file name (file-like object).
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
+    else:
+        return torch._C._backport_for_mobile_from_buffer_to_buffer(
+            # pyrefly: ignore [missing-attribute]
+            f_input.read(),
+            to_version,
+        )
+
+
+def _get_model_ops_and_info(f_input):
+    r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.
+
+    These root operators can call other operators within them (traced ops), and
+    a root op can call many different traced ops depending on internal code paths in the root op.
+    These traced ops are not returned by this function. Those operators are abstracted into the
+    runtime as an implementation detail (and the traced ops themselves can also call other operators)
+    making retrieving them difficult and their value from this api negligible since they will differ
+    between which runtime version the model is run on. Because of this, there is a false positive this
+    api can't prevent in a compatibility usecase. All the root ops of a model are present in a
+    target runtime, but not all the traced ops are which prevents a model from being able to run.
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
+        of the model to their OperatorInfo structs.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_ops_and_info
+
+        # Get bytecode version from a saved file path
+        ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_model_ops_and_info(os.fspath(f_input))
+    else:
+        # pyrefly: ignore [missing-attribute]
+        return torch._C._get_model_ops_and_info(f_input.read())
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/mobile/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/mobile/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20f33001556c8dddb94626888ffa6cafefce64ec
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/jit/mobile/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..e441ff5a28936d8ca999fcb61ddc8dbbb2c8c12b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include 
+
+struct AllocInfo {
+  pid_t pid;
+  char free;
+  char filename[60];
+};
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/err.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/err.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1e6aa4e277c3a94dd642ff2a27e6cd564322e46
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/err.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include 
+#include 
+
+// `errno` is only meaningful when it fails. E.g., a  successful `fork()` sets
+// `errno` to `EINVAL` in child process on some macos
+// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
+// be inspected if an error occurred.
+//
+// All functions used in `libshm` (so far) indicate error by returning `-1`. If
+// you want to use a function with a different error reporting mechanism, you
+// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`.
+#define SYSCHECK_ERR_RETURN_NEG1(expr)                          \
+  while (true) {                                                \
+    if ((expr) == -1) {                                         \
+      if (errno == EINTR) {                                     \
+        continue;                                               \
+      } else {                                                  \
+        throw std::system_error(errno, std::system_category()); \
+      }                                                         \
+    } else {                                                    \
+      break;                                                    \
+    }                                                           \
+  }
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/libshm.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/libshm.h
new file mode 100644
index 0000000000000000000000000000000000000000..d3f7c7061abc9e56b7147fad7e85d1bcdacc61c8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/libshm.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include 
+
+#ifdef __cplusplus
+
+void libshm_init(const char* manager_exec_path);
+
+// Superclass to run a constructor before at::RefcountedMapAllocator
+class THManagedMapAllocatorInit {
+ protected:
+  THManagedMapAllocatorInit(const char* manager_handle, const char* filename);
+  std::string manager_handle_;
+};
+
+// Like a at::RefcountedMapAllocator, but it also makes use of an external
+// shared memory manager process to ensure that shared memory regions actually
+// get freed in the end (even if processes lose the memory).
+class THManagedMapAllocator : private THManagedMapAllocatorInit,
+                              public at::RefcountedMapAllocator {
+ public:
+  THManagedMapAllocator(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+
+  void close() override;
+
+  ~THManagedMapAllocator() override {
+    close();
+  }
+
+  static at::DataPtr makeDataPtr(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+  static THManagedMapAllocator* fromDataPtr(const at::DataPtr& /*dptr*/);
+
+  const char* manager_handle() const {
+    return manager_handle_.c_str();
+  }
+};
+
+#endif
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/socket.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/socket.h
new file mode 100644
index 0000000000000000000000000000000000000000..e048098b94efac3360d4d72835d60b346fab4842
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm/socket.h
@@ -0,0 +1,164 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+class Socket {
+ public:
+  int socket_fd;
+  Socket(const Socket& other) = delete;
+
+ protected:
+  Socket() {
+    SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
+  }
+  Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
+    other.socket_fd = -1;
+  };
+  explicit Socket(int fd) : socket_fd(fd) {}
+
+  virtual ~Socket() {
+    if (socket_fd != -1)
+      close(socket_fd);
+  }
+
+  struct sockaddr_un prepare_address(const char* path) {
+    struct sockaddr_un address;
+    address.sun_family = AF_UNIX;
+    strcpy(address.sun_path, path);
+    return address;
+  }
+
+  // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
+  size_t address_length(struct sockaddr_un address) {
+    return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
+  }
+
+  void recv(void* _buffer, size_t num_bytes) {
+    char* buffer = (char*)_buffer;
+    size_t bytes_received = 0;
+    ssize_t step_received;
+    struct pollfd pfd = {};
+    pfd.fd = socket_fd;
+    pfd.events = POLLIN;
+    while (bytes_received < num_bytes) {
+      SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
+      if (pfd.revents & POLLIN) {
+        SYSCHECK_ERR_RETURN_NEG1(
+            step_received =
+                ::read(socket_fd, buffer, num_bytes - bytes_received));
+        TORCH_CHECK(step_received != 0, "Other end has closed the connection");
+        bytes_received += step_received;
+        buffer += step_received;
+      } else if (pfd.revents & (POLLERR | POLLHUP)) {
+        TORCH_CHECK(false, "An error occurred while waiting for the data");
+      } else {
+        TORCH_CHECK(false, "Shared memory manager connection has timed out");
+      }
+    }
+  }
+
+  void send(const void* _buffer, size_t num_bytes) {
+    const char* buffer = (const char*)_buffer;
+    size_t bytes_sent = 0;
+    ssize_t step_sent;
+    while (bytes_sent < num_bytes) {
+      SYSCHECK_ERR_RETURN_NEG1(
+          step_sent = ::write(socket_fd, buffer, num_bytes));
+      bytes_sent += step_sent;
+      buffer += step_sent;
+    }
+  }
+};
+
+class ManagerSocket : public Socket {
+ public:
+  explicit ManagerSocket(int fd) : Socket(fd) {}
+
+  AllocInfo receive() {
+    AllocInfo info;
+    recv(&info, sizeof(info));
+    return info;
+  }
+
+  void confirm() {
+    send("OK", 2);
+  }
+};
+
+class ManagerServerSocket : public Socket {
+ public:
+  explicit ManagerServerSocket(const std::string& path) {
+    socket_path = path;
+    try {
+      struct sockaddr_un address = prepare_address(path.c_str());
+      size_t len = address_length(address);
+      SYSCHECK_ERR_RETURN_NEG1(
+          bind(socket_fd, (struct sockaddr*)&address, len));
+      SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
+    } catch (std::exception&) {
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
+      throw;
+    }
+  }
+
+  void remove() {
+    struct stat file_stat;
+    if (fstat(socket_fd, &file_stat) == 0)
+      SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
+  }
+
+  ~ManagerServerSocket() override {
+    unlink(socket_path.c_str());
+  }
+
+  ManagerSocket accept() {
+    int client_fd;
+    struct sockaddr_un addr;
+    socklen_t addr_len = sizeof(addr);
+    SYSCHECK_ERR_RETURN_NEG1(
+        client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
+    return ManagerSocket(client_fd);
+  }
+
+  std::string socket_path;
+};
+
+class ClientSocket : public Socket {
+ public:
+  explicit ClientSocket(const std::string& path) {
+    try {
+      struct sockaddr_un address = prepare_address(path.c_str());
+      size_t len = address_length(address);
+      SYSCHECK_ERR_RETURN_NEG1(
+          connect(socket_fd, (struct sockaddr*)&address, len));
+    } catch (std::exception&) {
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
+      throw;
+    }
+  }
+
+  void register_allocation(AllocInfo& info) {
+    char buffer[3] = {0, 0, 0};
+    send(&info, sizeof(info));
+    recv(buffer, 2);
+    TORCH_CHECK(
+        strcmp(buffer, "OK") == 0,
+        "Shared memory manager didn't respond with an OK");
+  }
+
+  void register_deallocation(AllocInfo& info) {
+    send(&info, sizeof(info));
+  }
+};
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h
new file mode 100644
index 0000000000000000000000000000000000000000..4dd193df93d110e3a04d33a3f9d3e3ec24948277
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include 
+
+#ifdef __cplusplus
+
+#ifdef SHM_EXPORTS
+#define SHM_API __declspec(dllexport)
+#else
+#define SHM_API __declspec(dllimport)
+#endif
+
+SHM_API void libshm_init(const char* manager_exec_path);
+
+class SHM_API THManagedMapAllocator : public at::RefcountedMapAllocator {
+ public:
+  THManagedMapAllocator(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size)
+      : at::RefcountedMapAllocator(filename, flags, size) {}
+
+  static at::DataPtr makeDataPtr(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+  static THManagedMapAllocator* fromDataPtr(const at::DataPtr&);
+
+  const char* manager_handle() const {
+    return "no_manager";
+  }
+};
+
+#endif
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0fa91c924021fe377ed07cceff0dd9faad8e09a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51c9f2bb40b217bff60e620dada51d179071ec63
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5970cd4aaa6ead97467baadb64d0e6ecb35aeaa
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef878d3c4b20ef38c7dfd6e14631e99b2fddcc1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from .binary import _apply_native_binary, _is_native_binary
+from .core import is_masked_tensor, MaskedTensor
+from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
+from .reductions import _apply_reduction, _is_reduction
+from .unary import _apply_native_unary, _is_native_unary
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..125fb019ff766692329c8db3773f791fe1600981
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f6fbbc03402750480e49c1d319be136dcb51b65
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca9205a3850b9a77082872fd6ce7aadfeb0c2c1c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bef65298a5ff37414fd3b2e836f8e0dd4fb63f4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f4e54d1641ad969c6ad81e0464f84fc09331d79
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f114d8755f2786b0ddc9f1d20701ce723ddb1cb
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cedcb47a465d8ec62471fb1ef1af548250470e3
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6ea2846ab5544d29da3edb808b19244c519c094
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9add0a1dfbae1f8dee18fecdcfd3f60da5231d7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py
@@ -0,0 +1,547 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from collections.abc import Callable
+from functools import partial
+from typing import Any, TYPE_CHECKING
+
+import torch
+
+from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
+from .core import (
+    _get_data,
+    _masks_match,
+    _maybe_get_mask,
+    is_masked_tensor,
+    MaskedTensor,
+)
+from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
+from .reductions import (
+    _apply_reduction,
+    NATIVE_REDUCE_FNS,
+    TENSOR_REDUCE_FNS,
+    TORCH_REDUCE_FNS,
+)
+from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
+
+
+if TYPE_CHECKING:
+    from torch._ops import OpOverload
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _check_args_kwargs_length(
+    args, kwargs, error_prefix, len_args=None, len_kwargs=None
+):
+    if len_args is not None and len_args != len(args):
+        raise ValueError(
+            f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
+        )
+    if len_kwargs is not None and len_kwargs != len(kwargs):
+        raise ValueError(
+            f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
+        )
+
+
+class _MaskedContiguous(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
+
+        if input.is_contiguous():
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+
+        return MaskedTensor(data.contiguous(), mask.contiguous())
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class _MaskedToDense(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
+
+        if input.layout == torch.strided:
+            return input
+
+        ctx.layout = input.layout
+        data = input.get_data()
+        mask = input.get_mask()
+
+        return MaskedTensor(data.to_dense(), mask.to_dense())
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        layout = ctx.layout
+
+        if layout == torch.sparse_coo:
+            return grad_output.to_sparse_coo()
+        elif layout == torch.sparse_csr:
+            return grad_output.to_sparse_csr()
+        elif layout == torch.strided:
+            return grad_output.to_dense()
+        raise ValueError("to_dense: Unsupported input layout: ", layout)
+
+
+class _MaskedToSparse(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
+
+        # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo
+        if input.layout == torch.sparse_coo:
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+        sparse_mask = mask.to_sparse_coo().coalesce()
+        sparse_data = data.sparse_mask(sparse_mask)
+
+        return MaskedTensor(sparse_data, sparse_mask)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        return grad_output.to_dense()
+
+
+class _MaskedToSparseCsr(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
+
+        if input._masked_data.ndim != 2:
+            raise ValueError(
+                f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
+            )
+
+        if input.layout == torch.sparse_csr:
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+        sparse_mask = mask.to_sparse_csr()
+        sparse_data = data.sparse_mask(sparse_mask)
+
+        return MaskedTensor(sparse_data, sparse_mask)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        return grad_output.to_dense()
+
+
+class _MaskedWhere(torch.autograd.Function):
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def forward(ctx, cond, self, other):
+        ctx.mark_non_differentiable(cond)
+        ctx.save_for_backward(cond)
+        return torch.ops.aten.where(cond, self, other)
+
+    @staticmethod
+    # pyrefly: ignore [bad-override]
+    def backward(ctx, grad_output):
+        (cond,) = ctx.saved_tensors
+
+        def masked_out_like(mt):
+            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())
+
+        return (
+            None,
+            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
+            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
+        )
+
+
+_MASKEDTENSOR_FUNCTION_TABLE = {}
+
+_function_fn_apply_map = {
+    (
+        tuple(NATIVE_REDUCE_FNS),
+        tuple(TORCH_REDUCE_FNS),
+        tuple(TENSOR_REDUCE_FNS),
+    ): _apply_reduction,
+}
+
+for fn_map_list, apply_fn in _function_fn_apply_map.items():
+    for fn_map in fn_map_list:
+        for fn in fn_map:
+            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)
+
+
+def register_function_func(ops):
+    """
+    Used for registering a new __torch_function__ function to MaskedTensor
+    Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
+
+    The code to register a new function looks like:
+
+    @register_function_func(list_of_ops)
+    def foo(func, *args, **kwargs):
+        
+    """
+
+    def wrapper(func):
+        for op in ops:
+            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
+
+    return wrapper
+
+
+@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
+def _general_function_reductions(func, *args, **kwargs):
+    return _apply_reduction(func, *args, **kwargs)
+
+
+@register_function_func([torch.Tensor.where, torch.where])
+def _function_where(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
+    )
+    return _MaskedWhere.apply(*args)
+
+
+@register_function_func([torch.Tensor.contiguous])
+def _function_contiguous(func, *args, **kwargs):
+    return _MaskedContiguous.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_dense])
+def _function_to_dense(func, *args, **kwargs):
+    return _MaskedToDense.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_sparse])
+def _function_to_sparse(func, *args, **kwargs):
+    return _MaskedToSparse.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_sparse_csr])
+def _function_to_sparse_csr(func, *args, **kwargs):
+    return _MaskedToSparseCsr.apply(args[0])
+
+
+_MASKEDTENSOR_DISPATCH_TABLE: dict["OpOverload", Callable[..., Any]] = {}
+
+
+def register_dispatch_func(aten_ops):
+    """
+    Used for registering a new __torch_dispatch__ function to MaskedTensor
+    Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
+
+    The code to register a new function looks like:
+
+    @register_dispatch_func(list_of_ops)
+    def foo(func, *args, **kwargs):
+        
+    """
+
+    def wrapper(func):
+        for aten_op in aten_ops:
+            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
+
+    return wrapper
+
+
+@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
+def _general_reduction(func, *args, **kwargs):
+    return _apply_reduction(func, *args, **kwargs)
+
+
+@register_dispatch_func(PASSTHROUGH_FNS)
+def _general_passthrough(func, *args, **kwargs):
+    return _apply_pass_through_fn(func, *args, **kwargs)
+
+
+@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
+def _general_unary(func, *args, **kwargs):
+    return _apply_native_unary(func, *args, **kwargs)
+
+
+@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
+def _general_binary(func, *args, **kwargs):
+    return _apply_native_binary(func, *args, **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.stride])
+def stride(func, *args, **kwargs):
+    return None
+
+
+@register_dispatch_func([torch.ops.aten.sym_stride])
+def sym_stride(func, *args, **kwargs):
+    return None
+
+
+@register_dispatch_func([torch.ops.prim.layout])
+def layout(func, *args, **kwargs):
+    return _get_data(args[0]).layout
+
+
+@register_dispatch_func(
+    [torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous]
+)
+def is_contiguous(func, *args, **kwargs):
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.is_strides_like_format])
+def is_strides_like_format(func, *args, **kwargs):
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError(
+            "MaskedTensors with sparse data do not have is_strides_like_format"
+        )
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
+def is_non_overlapping_and_dense(func, *args, **kwargs):
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError(
+            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
+        )
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.contiguous])
+def contiguous(func, *args, **kwargs):
+    if _get_data(args[0]).is_sparse:
+        raise ValueError("MaskedTensors with sparse data do not have contiguous")
+    return _MaskedContiguous.apply(args[0])
+
+
+@register_dispatch_func([torch.ops.aten.new_empty_strided])
+def new_empty_strided(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    if tuple(args[1]) != tuple(data.size()):
+        raise ValueError(
+            f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
+        )
+    if tuple(args[2]) != tuple(data.stride()):
+        raise ValueError(
+            f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
+        )
+    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)
+
+
+@register_dispatch_func([torch.ops.aten._local_scalar_dense])
+def _local_scalar_dense(func, *args, **kwargs):
+    if not _maybe_get_mask(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
+    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
+def _apply_fn_on_data(func, *args, **kwargs):
+    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._to_copy])
+def _to_copy(func, *args, **kwargs):
+    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
+    cloned_kwargs = kwargs.copy()
+    cloned_kwargs["dtype"] = torch.bool
+    new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._softmax])
+def _softmax(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
+    )
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
+    return MaskedTensor(result_data, mask)
+
+
+@register_dispatch_func([torch.ops.aten.ones_like])
+def ones_like(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
+    result_data = func(_get_data(args[0]), **kwargs)
+    return MaskedTensor(result_data, _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._softmax_backward_data])
+def _softmax_backward_data(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
+    grad, output, dim, _input_dtype = args
+    if is_masked_tensor(grad) and is_masked_tensor(output):
+        if not _masks_match(grad, output):
+            raise ValueError(
+                f"__torch_dispatch__, {func}: expected the masks of grad and output to match"
+            )
+        grad_data = _get_data(grad)
+        new_grad_data = torch.ops.aten._masked_softmax_backward(
+            grad_data,
+            _get_data(output),
+            ~_maybe_get_mask(grad),
+            dim % grad_data.ndim,
+        )
+        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
+        return res
+    else:
+        raise ValueError(
+            f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
+        )
+
+
+@register_dispatch_func([torch.ops.aten.copy_])
+def copy_(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
+    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
+        raise ValueError("args[0] mask and args[1] mask must match but do not")
+    func(_get_data(args[0]), _get_data(args[1]))
+    return args[0]
+
+
+@register_dispatch_func([torch.ops.aten.where])
+def where(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mx = args[1]
+    my = args[2]
+    if not is_masked_tensor(mx):
+        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
+    if not is_masked_tensor(my):
+        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
+    new_data = func(args[0], mx.get_data(), my.get_data())
+    new_mask = func(args[0], mx.get_mask(), my.get_mask())
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_sparse])
+def _to_sparse(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
+    if mt.is_sparse_coo():
+        return mt
+    new_mask = func(_maybe_get_mask(args[0])).coalesce()
+    new_data = _get_data(args[0]).sparse_mask(new_mask)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_sparse_csr])
+def _to_sparse_csr(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
+    if mt.is_sparse_csr():
+        return mt
+    new_mask = func(_maybe_get_mask(args[0]))
+    new_data = _get_data(args[0]).sparse_mask(new_mask)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_dense])
+def _to_dense(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
+    new_data = func(_get_data(args[0]))
+    new_mask = func(_maybe_get_mask(args[0]))
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._indices])
+def _indices(func, *args, **kwargs):
+    # Assumes data is sparse
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0]).indices()
+    return MaskedTensor(data, torch.ones_like(data).bool())
+
+
+@register_dispatch_func([torch.ops.aten._values])
+def _values(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0]).values()
+    return MaskedTensor(data, torch.ones_like(data).bool())
+
+
+@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
+def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
+    new_args = list(args)
+    if is_masked_tensor(args[-1]):
+        new_args[-1] = args[-1].get_data()
+    if is_masked_tensor(args[-2]):
+        new_args[-2] = args[-2].get_data()
+
+    new_data = func(*new_args, **kwargs)
+    new_args[-1] = torch.ones_like(new_args[-1])
+    new_mask = func(*new_args, **kwargs).bool()
+
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten.is_same_size])
+def is_same_size(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
+    return _get_data(args[0]).is_same_size(_get_data(args[1]))
+
+
+@register_dispatch_func([torch.ops.aten._is_any_true])
+def _is_any_true(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    if mask is None:
+        raise ValueError(
+            f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor"
+        )
+    if data.dtype != torch.bool:
+        raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor")
+    if data.is_sparse:
+        raise ValueError(f"MaskedTensors with sparse data do not have {func}")
+
+    return MaskedTensor(func(data & mask), torch.tensor(True))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..8315ae11be7175c2b5aaef178a4bc4785dcbcb29
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py
@@ -0,0 +1,200 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+
+from .core import (
+    _map_mt_args_kwargs,
+    _masks_match,
+    _tensors_match,
+    _wrap_result,
+    is_masked_tensor,
+)
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+BINARY_NAMES = [
+    "add",
+    "atan2",
+    "arctan2",
+    "bitwise_and",
+    "bitwise_or",
+    "bitwise_xor",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "div",
+    "divide",
+    "floor_divide",
+    "fmod",
+    "logaddexp",
+    "logaddexp2",
+    "mul",
+    "multiply",
+    "nextafter",
+    "remainder",
+    "sub",
+    "subtract",
+    "true_divide",
+    "eq",
+    "ne",
+    "le",
+    "ge",
+    "greater",
+    "greater_equal",
+    "gt",
+    "less_equal",
+    "lt",
+    "less",
+    "maximum",
+    "minimum",
+    "fmax",
+    "fmin",
+    "not_equal",
+]
+
+INPLACE_BINARY_NAMES = [
+    n + "_"
+    for n in (
+        list(
+            set(BINARY_NAMES)
+            - {
+                "logaddexp",
+                "logaddexp2",
+                "equal",
+                "fmin",
+                "minimum",
+                "maximum",
+                "fmax",
+            }
+        )
+    )
+]
+
+
+def _get_at_least_one_mask(a, b):
+    if not is_masked_tensor(a) and not is_masked_tensor(b):
+        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
+    if not _masks_match(a, b):
+        raise ValueError("a and b must have matching masks")
+    if is_masked_tensor(a):
+        return a.get_mask()
+    return b.get_mask()
+
+
+def _binary_helper(fn, args, kwargs, inplace):
+    if len(kwargs) != 0:
+        raise ValueError("len(kwargs) must equal 0")
+    for a in args[2:]:
+        if torch.is_tensor(a):
+            raise TypeError(
+                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
+            )
+
+    if not _masks_match(*args[:2]):
+        raise ValueError(
+            "Input masks must match. If you need support for this, please open an issue on Github."
+        )
+
+    data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
+    mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
+
+    args0_layout = data_args[0].layout
+    same_layout = (
+        torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
+    ) and (args0_layout == data_args[1].layout)
+
+    if args0_layout == torch.sparse_coo:
+        if same_layout:
+            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
+                raise ValueError(
+                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
+                )
+            if data_args[0].size() != data_args[1].size():
+                raise ValueError(
+                    "input1 and input2 must have the same size for binary functions."
+                )
+
+            data_args[1] = data_args[1].values()
+
+        i = data_args[0].indices()
+        size = data_args[0].size()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_coo_tensor(i, v, size)
+
+    elif args0_layout == torch.sparse_csr:
+        if same_layout:
+            if not (
+                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
+                and _tensors_match(
+                    data_args[0].col_indices(), data_args[1].col_indices()
+                )
+            ):
+                raise ValueError(
+                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
+                )
+
+            data_args[1] = data_args[1].values()
+
+        crow = data_args[0].crow_indices()
+        col = data_args[0].col_indices()
+        size = data_args[0].size()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_csr_tensor(crow, col, v, size)
+
+    else:
+        result_data = fn(*data_args)
+
+    if inplace:
+        args[0]._set_data_mask(result_data, mask_args[0])
+        return args[0]
+    else:
+        result_mask = _get_at_least_one_mask(*args[:2])
+        # sparse tensors don't have strides so we can only expand if the layout is strided
+        if args0_layout == torch.strided:
+            result_mask = result_mask.expand_as(result_data)
+        return _wrap_result(result_data, result_mask)
+
+
+def _torch_binary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def binary_fn(*args, **kwargs):
+        return _binary_helper(fn, args, kwargs, inplace=False)
+
+    return binary_fn
+
+
+def _torch_inplace_binary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def binary_fn(*args, **kwargs):
+        return _binary_helper(fn, args, kwargs, inplace=True)
+
+    return binary_fn
+
+
+NATIVE_BINARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
+}
+NATIVE_INPLACE_BINARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_inplace_binary(name)
+    for name in INPLACE_BINARY_NAMES
+}
+
+NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
+NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())
+
+
+def _is_native_binary(fn):
+    return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS
+
+
+def _apply_native_binary(fn, *args, **kwargs):
+    if fn in NATIVE_BINARY_FNS:
+        return NATIVE_BINARY_MAP[fn](*args, **kwargs)
+    if fn in NATIVE_INPLACE_BINARY_FNS:
+        return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
+    return NotImplemented
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..cad5621b29bd663ef4462f1be6c8f8f2c4762c2d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py
@@ -0,0 +1,364 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import warnings
+from typing import Any
+from typing_extensions import TypeIs
+
+import torch
+from torch.overrides import get_default_nowrap_functions
+
+
+__all__ = [
+    "MaskedTensor",
+    "is_masked_tensor",
+]
+
+
+def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
+    r"""Returns True if the input is a MaskedTensor, else False
+
+    Args:
+        a: any input
+
+    Examples:
+
+        >>> # xdoctest: +SKIP
+        >>> from torch.masked import MaskedTensor
+        >>> data = torch.arange(6).reshape(2, 3)
+        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
+        >>> mt = MaskedTensor(data, mask)
+        >>> is_masked_tensor(mt)
+        True
+    """
+    return isinstance(obj, MaskedTensor)
+
+
+def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
+    if is_masked_tensor(a) or is_masked_tensor(b):
+        raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
+    if a.layout != b.layout:
+        raise ValueError(
+            f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
+        )
+
+    if a.dtype != b.dtype:
+        b = b.type(a.dtype)
+    if a.layout == b.layout == torch.sparse_coo:
+        return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
+            a.indices(), b.indices(), exact
+        )
+    elif a.layout == b.layout == torch.sparse_csr:
+        return (
+            _tensors_match(a.crow_indices(), b.crow_indices(), exact)
+            and _tensors_match(a.col_indices(), b.col_indices(), exact)
+            and _tensors_match(a.values(), b.values(), exact)
+        )
+    if exact:
+        return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
+    return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)
+
+
+def _masks_match(a, b):
+    if is_masked_tensor(a) and is_masked_tensor(b):
+        mask_a = a.get_mask()
+        mask_b = b.get_mask()
+        return _tensors_match(mask_a, mask_b, exact=True)
+    return True
+
+
+def _map_mt_args_kwargs(args, kwargs, map_fn):
+    def _helper(a, map_fn):
+        if is_masked_tensor(a):
+            return map_fn(a)
+        elif torch.is_tensor(a):
+            return a
+        elif isinstance(a, list):
+            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
+            return a_impl
+        elif isinstance(a, tuple):
+            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
+            return tuple(a_impl)
+        else:
+            return a
+
+    if kwargs is None:
+        kwargs = {}
+    impl_args = []
+    for a in args:
+        impl_args.append(_helper(a, map_fn))
+    impl_kwargs = {}
+    for k in kwargs:
+        impl_kwargs[k] = _helper(a, map_fn)
+    return impl_args, impl_kwargs
+
+
+def _wrap_result(result_data, result_mask):
+    if isinstance(result_data, list):
+        return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
+    if isinstance(result_data, tuple):
+        return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
+    if torch.is_tensor(result_data):
+        return MaskedTensor(result_data, result_mask)
+    # Expect result_data and result_mask to be Tensors only
+    return NotImplemented
+
+
+def _masked_tensor_str(data, mask, formatter):
+    if data.layout in {torch.sparse_coo, torch.sparse_csr}:
+        data = data.to_dense()
+        mask = mask.to_dense()
+    if data.dim() == 1:
+        formatted_elements = [
+            formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
+            for d in data
+        ]
+        max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
+        return (
+            "["
+            + ", ".join(
+                [
+                    "--".rjust(max_len) if m else e
+                    for (e, m) in zip(formatted_elements, ~mask)
+                ]
+            )
+            + "]"
+        )
+    sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
+    sub_strings = ["\n".join(["  " + si for si in s.split("\n")]) for s in sub_strings]
+    return "[\n" + ",\n".join(sub_strings) + "\n]"
+
+
+def _get_data(a):
+    if is_masked_tensor(a):
+        return a._masked_data
+    return a
+
+
+def _maybe_get_mask(a):
+    if is_masked_tensor(a):
+        return a.get_mask()
+    return None
+
+
+class MaskedTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, data, mask, requires_grad=False):
+        if is_masked_tensor(data) or not torch.is_tensor(data):
+            raise TypeError("data must be a Tensor")
+        if is_masked_tensor(mask) or not torch.is_tensor(mask):
+            raise TypeError("mask must be a Tensor")
+        # Use a Tensor that of the give size for the wrapper.
+        kwargs = {
+            "device": data.device,
+            "dtype": data.dtype,
+            "layout": data.layout,
+            "requires_grad": requires_grad,
+            "dispatch_sizes_strides_policy": "strides",
+            "dispatch_layout": True,
+        }
+        warnings.warn(
+            (
+                "The PyTorch API of MaskedTensors is in prototype stage "
+                "and will change in the near future. Please open a Github issue "
+                "for features requests and see our documentation on the torch.masked "
+                "module for further information about the project."
+            ),
+            UserWarning,
+            stacklevel=2,
+        )
+        if data.requires_grad:
+            warnings.warn(
+                "It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
+                "To avoid this, you can use data.detach().clone()",
+                UserWarning,
+                stacklevel=2,
+            )
+        # pyrefly: ignore [bad-argument-type]
+        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
+
+    def _preprocess_data(self, data, mask):
+        from .._ops import _sparse_coo_where, _sparse_csr_where
+
+        if data.layout != mask.layout:
+            raise TypeError("data and mask must have the same layout.")
+        if data.layout == torch.sparse_coo:
+            data = data.coalesce()
+            mask = mask.coalesce()
+            if data._nnz() != mask._nnz():
+                data = _sparse_coo_where(mask, data, torch.tensor(0))
+        elif data.layout == torch.sparse_csr:
+            if data._nnz() != mask._nnz():
+                data = _sparse_csr_where(mask, data, torch.tensor(0))
+
+        # Have to pick awkward names to not conflict with existing fields such as data
+        self._masked_data = data.clone()
+        self._masked_mask = mask.clone()
+
+    def _validate_members(self):
+        data = self._masked_data
+        mask = self.get_mask()
+        if type(data) is not type(mask):
+            raise TypeError(
+                f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
+            )
+        if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
+            raise TypeError(f"data layout of {data.layout} is not supported.")
+        if data.layout == torch.sparse_coo:
+            if not _tensors_match(data.indices(), mask.indices(), exact=True):
+                raise ValueError(
+                    "data and mask are both sparse COO tensors but do not have the same indices."
+                )
+        elif data.layout == torch.sparse_csr:
+            if not _tensors_match(
+                data.crow_indices(), mask.crow_indices(), exact=True
+            ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
+                raise ValueError(
+                    "data and mask are both sparse CSR tensors but do not share either crow or col indices."
+                )
+        if mask.dtype != torch.bool:
+            raise TypeError("mask must have dtype bool.")
+        if not (
+            data.dtype == torch.float16
+            or data.dtype == torch.float32
+            or data.dtype == torch.float64
+            or data.dtype == torch.bool
+            or data.dtype == torch.int8
+            or data.dtype == torch.int16
+            or data.dtype == torch.int32
+            or data.dtype == torch.int64
+        ):
+            raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
+        if data.dim() != mask.dim():
+            raise ValueError("data.dim() must equal mask.dim()")
+        if data.size() != mask.size():
+            raise ValueError("data.size() must equal mask.size()")
+
+    def __init__(self, data, mask, requires_grad=False):
+        self._preprocess_data(data, mask)
+        self._validate_members()
+
+    @staticmethod
+    def _from_values(data, mask):
+        """Differentiable constructor for MaskedTensor"""
+
+        class Constructor(torch.autograd.Function):
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def forward(ctx, data, mask):
+                return MaskedTensor(data, mask)
+
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def backward(ctx, grad_output):
+                return grad_output, None
+
+        result = Constructor.apply(data, mask)
+        return result
+
+    def _set_data_mask(self, data, mask):
+        self._masked_data = data
+        self._masked_mask = mask
+        self._validate_members()
+
+    def __repr__(self):  # type: ignore[override]
+        formatter = "{0:8.4f}"
+        if self.dim() == 0:
+            scalar_data = self.get_data().item()
+            data_formatted = (
+                formatter.format(scalar_data)
+                if isinstance(scalar_data, float)
+                else str(scalar_data)
+            )
+            if not self.get_mask().item():
+                data_formatted = "--"
+            return (
+                "MaskedTensor("
+                + data_formatted
+                + ", "
+                + str(self.get_mask().item())
+                + ")"
+            )
+        s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
+        s = "\n".join("  " + si for si in s.split("\n"))
+        return "MaskedTensor(\n" + s + "\n)"
+
+    # Seems like this needs to be defined before torch_dispatch to work
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
+
+        if func in _MASKEDTENSOR_FUNCTION_TABLE:
+            return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
+
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+        with torch._C.DisableTorchFunctionSubclass():
+            ret = func(*args, **kwargs)
+            if func in get_default_nowrap_functions():
+                return ret
+            else:
+                return torch._tensor._convert(ret, cls)
+
+    @classmethod
+    def unary(cls, fn, data, mask):
+        return MaskedTensor(fn(data), mask)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):  # type: ignore[override]
+        func = func.overloadpacket
+
+        from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
+
+        if func in _MASKEDTENSOR_DISPATCH_TABLE:
+            return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
+
+        msg = (
+            f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
+            "If you would like this operator to be supported, please file an issue for a feature request at "
+            "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
+            "In the case that the semantics for the operator are not trivial, it would be appreciated "
+            "to also include a proposal for the semantics."
+        )
+        warnings.warn(msg, stacklevel=2)
+        return NotImplemented
+
+    def __lt__(self, other):
+        if is_masked_tensor(other):
+            return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
+        return MaskedTensor(self.get_data() < other, self.get_mask())
+
+    def to_tensor(self, value):
+        return self.get_data().masked_fill(~self.get_mask(), value)
+
+    def get_data(self):
+        class GetData(torch.autograd.Function):
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def forward(ctx, self):
+                return self._masked_data.detach()
+
+            @staticmethod
+            # pyrefly: ignore [bad-override]
+            def backward(ctx, grad_output):
+                if is_masked_tensor(grad_output):
+                    return grad_output
+                return MaskedTensor(grad_output, self.get_mask())
+
+        return GetData.apply(self)
+
+    def get_mask(self):
+        return self._masked_mask
+
+    def is_sparse_coo(self):
+        return self.layout == torch.sparse_coo
+
+    def is_sparse_csr(self):  # type: ignore[override]
+        return self.layout == torch.sparse_csr
+
+    # Update later to support more sparse layouts
+    @property
+    def is_sparse(self):  # type: ignore[override]
+        return self.is_sparse_coo() or self.is_sparse_csr()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py
new file mode 100644
index 0000000000000000000000000000000000000000..35c8e3d2aa9438dbcfc7995a1cdcd3c5cc8dc1fc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from .core import MaskedTensor
+
+
+__all__ = [
+    "as_masked_tensor",
+    "masked_tensor",
+]
+
+
+# These two factory functions are intended to mirror
+#     torch.tensor - guaranteed to be a leaf node
+#     torch.as_tensor - differentiable constructor that preserves the autograd history
+
+
+def masked_tensor(
+    data: object, mask: object, requires_grad: bool = False
+) -> MaskedTensor:
+    return MaskedTensor(data, mask, requires_grad)
+
+
+def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
+    return MaskedTensor._from_values(data, mask)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba13f50c1fee9c9fc10563ffc9f4ff3211c0dca6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py
@@ -0,0 +1,50 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+"""
+These are functions that should simply be applied to both mask and data.
+Take select or stack as an example. This operation can be applied to
+both the mask and data of a MaskedTensor and the result wrapped into
+a new MaskedTensor as a result.
+"""
+
+import torch
+
+from .core import _map_mt_args_kwargs, _wrap_result
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+PASSTHROUGH_FNS = [
+    torch.ops.aten.select,
+    torch.ops.aten.transpose,
+    torch.ops.aten.split,
+    torch.ops.aten.t,
+    torch.ops.aten.slice,
+    torch.ops.aten.slice_backward,
+    torch.ops.aten.select_backward,
+    torch.ops.aten.index,
+    torch.ops.aten.expand,
+    torch.ops.aten.view,
+    torch.ops.aten._unsafe_view,
+    torch.ops.aten._reshape_alias,
+    torch.ops.aten.cat,
+    torch.ops.aten.unsqueeze,
+    torch.ops.aten.unfold,
+    torch.ops.aten.unfold_backward,
+    torch.ops.aten.im2col,
+    torch.ops.aten.col2im,
+    torch.ops.aten.stack,
+]
+
+
+def _is_pass_through_fn(fn):
+    return fn in PASSTHROUGH_FNS
+
+
+def _apply_pass_through_fn(fn, *args, **kwargs):
+    data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
+    result_data = fn(*data_args, **data_kwargs)
+    mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
+    result_mask = fn(*mask_args, **mask_kwargs)
+    return _wrap_result(result_data, result_mask)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py
new file mode 100644
index 0000000000000000000000000000000000000000..6acc8415267bb9fdd7fe6af707cfbbaa74869184
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py
@@ -0,0 +1,176 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import warnings
+
+import torch
+
+from .core import is_masked_tensor
+from .creation import as_masked_tensor, masked_tensor
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _masked_all_all(data, mask=None):
+    if mask is None:
+        return data.all()
+    return data.masked_fill(~mask, True).all()
+
+
+def _masked_all_dim(data, dim, keepdim=False, mask=None):
+    if mask is None:
+        return torch.all(data, dim=dim, keepdim=keepdim)
+    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)
+
+
+def _masked_all(*args, **kwargs):
+    if len(args) == 1 and len(kwargs) == 1:
+        return _masked_all_all(args[0], mask=kwargs["mask"])
+    return _masked_all_dim(*args, **kwargs)
+
+
+def _multidim_any(mask, dim, keepdim):
+    if isinstance(dim, int):
+        return _multidim_any(mask, [dim], keepdim)
+    for d in sorted(dim, reverse=True):
+        mask = torch.any(mask, dim=d, keepdim=keepdim)
+    return mask
+
+
+def _get_masked_fn(fn):
+    if fn == "all":
+        return _masked_all
+    return getattr(torch.masked, fn)
+
+
+def _torch_reduce_all(fn):
+    def reduce_all(self):
+        masked_fn = _get_masked_fn(fn)
+        data = self.get_data()
+        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
+        # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
+        # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
+        # Therefore, this implementation calculates it using the strides.
+        if fn == "all":
+            result_data = masked_fn(data, mask=mask)
+
+        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
+            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
+            indices = (
+                data.to_sparse_coo().indices()
+                if not self.is_sparse_coo()
+                else data.indices()
+            )
+            idx = indices.unbind(1)[sparse_idx]
+            stride = data.size().numel() / torch.tensor(
+                data.size(), device=data.device
+            ).cumprod(0)
+            result_data = torch.sum(idx * stride)
+
+        # we simply pass in the values for sparse COO/CSR tensors
+        elif self.is_sparse:
+            result_data = masked_fn(masked_tensor(data.values(), mask))
+
+        else:
+            result_data = masked_fn(self, mask=mask)
+
+        return as_masked_tensor(result_data, torch.any(mask))
+
+    return reduce_all
+
+
+def _torch_reduce_dim(fn):
+    def reduce_dim(self, dim, keepdim=False, dtype=None):
+        if self.is_sparse:
+            msg = (
+                f"The sparse version of {fn} is not implemented in reductions.\n"
+                "If you would like this operator to be supported, please file an issue for a feature request at "
+                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
+                "In the case that the semantics for the operator are not trivial, it would be appreciated "
+                "to also include a proposal for the semantics."
+            )
+            warnings.warn(msg, stacklevel=2)
+            return NotImplemented
+        if not is_masked_tensor(self):
+            raise TypeError("Input to reduce_dim must be a MaskedTensor")
+
+        masked_fn = _get_masked_fn(fn)
+        data = self.get_data()
+        mask = self.get_mask()
+        if fn == "all":
+            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
+        else:
+            result_data = masked_fn(
+                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
+            )
+        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))
+
+    return reduce_dim
+
+
+def _torch_reduce(fn):
+    def reduce_fn(*args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            return _torch_reduce_all(fn)(args[0])
+        return _torch_reduce_dim(fn)(*args, **kwargs)
+
+    return reduce_fn
+
+
+def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
+    return input, dim, keepdim, dtype
+
+
+def _torch_grad_reduce(fn):
+    def grad_reduce(*args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            return _torch_reduce_all(fn)(args[0])
+        # TODO: autograd.Function doesn't support kwarg
+        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
+        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)
+
+    return grad_reduce
+
+
+REDUCE_NAMES = [
+    "sum",
+    "mean",
+    "amin",
+    "amax",
+    "argmin",
+    "argmax",
+    "prod",
+    "all",
+    "norm",
+    "var",
+    "std",
+]
+
+NATIVE_REDUCE_MAP = {
+    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
+}
+TORCH_REDUCE_MAP = {
+    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
+}
+TENSOR_REDUCE_MAP = {
+    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
+}
+
+NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
+TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
+TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
+
+
+def _is_reduction(fn):
+    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
+
+
+def _apply_reduction(fn, *args, **kwargs):
+    if fn in NATIVE_REDUCE_MAP:
+        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
+    if fn in TORCH_REDUCE_MAP:
+        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
+    if fn in TENSOR_REDUCE_MAP:
+        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
+    return NotImplemented
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e04ee6e810a7418829b68323097612391017b14e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py
@@ -0,0 +1,194 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+
+from .core import _map_mt_args_kwargs, _wrap_result
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+UNARY_NAMES = [
+    "abs",
+    "absolute",
+    "acos",
+    "arccos",
+    "acosh",
+    "arccosh",
+    "angle",
+    "asin",
+    "arcsin",
+    "asinh",
+    "arcsinh",
+    "atan",
+    "arctan",
+    "atanh",
+    "arctanh",
+    "bitwise_not",
+    "ceil",
+    "clamp",
+    "clip",
+    "conj_physical",
+    "cos",
+    "cosh",
+    "deg2rad",
+    "digamma",
+    "erf",
+    "erfc",
+    "erfinv",
+    "exp",
+    "exp2",
+    "expm1",
+    "fix",
+    "floor",
+    "frac",
+    "lgamma",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "logit",
+    "i0",
+    "isnan",
+    "nan_to_num",
+    "neg",
+    "negative",
+    "positive",
+    "pow",
+    "rad2deg",
+    "reciprocal",
+    "round",
+    "rsqrt",
+    "sigmoid",
+    "sign",
+    "sgn",
+    "signbit",
+    "sin",
+    "sinc",
+    "sinh",
+    "sqrt",
+    "square",
+    "tan",
+    "tanh",
+    "trunc",
+]
+
+INPLACE_UNARY_NAMES = [
+    n + "_"
+    for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
+]
+
+# Explicitly tracking functions we know are currently not supported
+# This might be due to missing code gen or because of complex semantics
+UNARY_NAMES_UNSUPPORTED = [
+    "atan2",
+    "arctan2",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "copysign",
+    "float_power",
+    "fmod",
+    "frexp",
+    "gradient",
+    "imag",
+    "ldexp",
+    "lerp",
+    "logical_not",
+    "hypot",
+    "igamma",
+    "igammac",
+    "mvlgamma",
+    "nextafter",
+    "polygamma",
+    "real",
+    "remainder",
+    "true_divide",
+    "xlogy",
+]
+
+
+def _unary_helper(fn, args, kwargs, inplace):
+    if len(kwargs) != 0:
+        raise ValueError(
+            "MaskedTensor unary ops require that len(kwargs) == 0. "
+            "If you need support for this, please open an issue on Github."
+        )
+    for a in args[1:]:
+        if torch.is_tensor(a):
+            raise TypeError(
+                "MaskedTensor unary ops do not support additional Tensor arguments"
+            )
+
+    mask_args, _mask_kwargs = _map_mt_args_kwargs(
+        args, kwargs, lambda x: x._masked_mask
+    )
+    data_args, _data_kwargs = _map_mt_args_kwargs(
+        args, kwargs, lambda x: x._masked_data
+    )
+
+    if args[0].layout == torch.sparse_coo:
+        data_args[0] = data_args[0].coalesce()
+        s = data_args[0].size()
+        i = data_args[0].indices()
+        data_args[0] = data_args[0].coalesce().values()
+        v = fn(*data_args)
+        result_data = torch.sparse_coo_tensor(i, v, size=s)
+
+    elif args[0].layout == torch.sparse_csr:
+        crow = data_args[0].crow_indices()
+        col = data_args[0].col_indices()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_csr_tensor(crow, col, v)
+
+    else:
+        result_data = fn(*data_args)
+
+    if inplace:
+        args[0]._set_data_mask(result_data, mask_args[0])
+        return args[0]
+    else:
+        return _wrap_result(result_data, mask_args[0])
+
+
+def _torch_unary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def unary_fn(*args, **kwargs):
+        return _unary_helper(fn, args, kwargs, inplace=False)
+
+    return unary_fn
+
+
+def _torch_inplace_unary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def unary_fn(*args, **kwargs):
+        return _unary_helper(fn, args, kwargs, inplace=True)
+
+    return unary_fn
+
+
+NATIVE_UNARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
+}
+NATIVE_INPLACE_UNARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_inplace_unary(name)
+    for name in INPLACE_UNARY_NAMES
+}
+
+NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
+NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())
+
+
+def _is_native_unary(fn):
+    return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS
+
+
+def _apply_native_unary(fn, *args, **kwargs):
+    if fn in NATIVE_UNARY_FNS:
+        return NATIVE_UNARY_MAP[fn](*args, **kwargs)
+    if fn in NATIVE_INPLACE_UNARY_FNS:
+        return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs)
+    return NotImplemented
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..956bbb837ecd54ddefd8c6db90e8fa986ebb8009
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c85e0944543799631a76ea0f79f65d519c5ed7e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/event.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/event.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0046f2684c2f07489b7b52dcfc89cfad6ff93dff
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/event.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/profiler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d6ca196d1568e16b24d0827d7018fb204158756
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mps/__pycache__/profiler.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b90b3d6a8d97fac7c623cb0678ac7a159456f9ca
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c7d0cc0ada3a35f5be8eb2bb9aa98a9a61907c4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/memory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/memory.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7eff10bac06a3b745d7a6a80d54e9a43cdc9ae4e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/memory.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/mtia_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/mtia_graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c62b15d282b0451b3651efd5630c27a328f1dc1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/mtia/__pycache__/mtia_graph.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e8df994bbf5b6bd83934849d9a720e46f6140f7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..392350f95a8801d37ad075666c253c128c03ea23
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/pool.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/pool.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2677f3c7069caea62a5e58136a7c56994418d388
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/pool.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/queue.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/queue.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e02550c62265bb4aa1378a5ff299574a37f8250e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/queue.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d979f02c95ebda819c9ca1bce63f8db0be31e2d2
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ba023a07cd5eb6c65549961ea79105b31128db7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c768fe3c6428a57e017fe6614515c1e09c26460
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..991c0d867b42aa947b82ee8ca9a1a2e3395f94a4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lower_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lower_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba8cd8777ea672d2ae2c5f543db562cd8d55de48
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lower_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lowered_aoti_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lowered_aoti_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62472e24ecb3a66395ecd585b3605b5c193d30a5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/__pycache__/_lowered_aoti_module.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/_lower_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/_lower_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa97bc30b4a047e270dd812b5676de354bf675f3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/_lower_utils.py
@@ -0,0 +1,101 @@
+import types
+
+import torch
+import torch.utils._pytree as pytree
+from torch.export import ExportedProgram
+from torch.export.pt2_archive._package import AOTI_FILES, package_pt2
+from torch.types import FileLike
+
+from ._lowered_aoti_module import LoweredBackendModule
+
+
+def get_new_ep_with_flat_inputs_outputs(ep: ExportedProgram) -> ExportedProgram:
+    class FlattenedModule(torch.nn.Module):
+        def __init__(
+            self,
+            original_module: torch.fx.GraphModule,
+            in_spec: pytree.TreeSpec,
+            out_spec: pytree.TreeSpec,
+        ) -> None:
+            super().__init__()
+            self.original_module = original_module
+            self.in_spec = in_spec
+            self.out_spec = out_spec
+
+        def forward(self, *flat_inputs):  # type: ignore[no-untyped-def]
+            # Unflatten inputs to original structure
+            inputs = pytree.tree_unflatten(flat_inputs, self.in_spec)
+            args, kwargs = inputs
+            outputs = self.original_module(*args, **kwargs)
+            # Flatten outputs
+            flat_outputs, _ = pytree.tree_flatten(outputs)
+            return tuple(flat_outputs)
+
+    flattened_module = FlattenedModule(
+        ep.module(), ep.call_spec.in_spec, ep.call_spec.out_spec
+    )
+    args, kwargs = ep.example_inputs
+    flat_inputs, _ = pytree.tree_flatten((args, kwargs))
+    flat_ep = torch.export.export(flattened_module, tuple(flat_inputs))
+
+    return flat_ep
+
+
+def lower_exported_program(
+    exported_program: ExportedProgram, model_name: str, backend_id: str
+) -> tuple[ExportedProgram, AOTI_FILES]:
+    """
+    Lower an exported program to AOTInductor and return a delegate ExportedProgram
+    with the `executorch_call_delegate` HOP
+    """
+    args, kwargs = exported_program.example_inputs
+    out_spec = exported_program.call_spec.out_spec
+    flat_ep = get_new_ep_with_flat_inputs_outputs(exported_program)
+    flat_inputs, _ = pytree.tree_flatten((args, kwargs))
+
+    aoti_files = torch._inductor.aot_compile(
+        flat_ep.module(), tuple(flat_inputs), options={"aot_inductor.package": True}
+    )
+    assert isinstance(aoti_files, list)
+
+    lowered_aoti_module = LoweredBackendModule(
+        flat_ep, backend_id, module_name=model_name
+    )
+
+    def patched_forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        flat_inputs, _ = pytree.tree_flatten((args, kwargs))
+        flat_outputs = torch._higher_order_ops.executorch_call_delegate(
+            self, *flat_inputs
+        )
+        if out_spec is not None and flat_outputs is not None:
+            return pytree.tree_unflatten(flat_outputs, out_spec)
+        else:
+            return flat_outputs
+
+    lowered_aoti_module.forward = types.MethodType(patched_forward, lowered_aoti_module)  # type: ignore[method-assign]
+
+    aoti_delegate_ep = torch.export.export(lowered_aoti_module, args, kwargs)
+
+    return aoti_delegate_ep, aoti_files
+
+
+def package_nativert_with_aoti_delegate(
+    f: FileLike,
+    model_name: str,
+    backend_id: str,
+    original_ep: ExportedProgram,
+    delegate_ep: ExportedProgram,
+    delegate_files: AOTI_FILES,
+) -> None:
+    """
+    Package a pt2 archive file that can be consumed by NativeRT with AOTI Delegate
+    """
+    package_pt2(
+        f,
+        exported_programs={
+            model_name: original_ep,
+            f"{model_name}-{backend_id}": delegate_ep,
+        },
+        aoti_files={f"{model_name}-{backend_id}": delegate_files},  # type: ignore[dict-item]
+    )
+    return
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/_lowered_aoti_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/_lowered_aoti_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08e83211ef330be11788b5ca82a1dcc9a0c9f9d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nativert/backends/_lowered_aoti_module.py
@@ -0,0 +1,31 @@
+import torch
+from torch.export import ExportedProgram
+
+
+class LoweredBackendModule(torch.nn.Module):
+    def __init__(
+        self,
+        original_exported_program: ExportedProgram,
+        backend_id: str,
+        *,
+        module_name: str | None = None,
+    ) -> None:
+        super().__init__()
+        self._backend_id = backend_id
+        self._module_name = module_name
+        self._original_exported_program = original_exported_program
+
+    @property
+    def backend_id(self) -> str:
+        return self._backend_id
+
+    @property
+    def module_name(self) -> str | None:
+        return self._module_name
+
+    @property
+    def original_module(self) -> ExportedProgram:
+        return self._original_exported_program
+
+    def forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        return torch._higher_order_ops.executorch_call_delegate(self, *args, **kwargs)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9ead589613711d03f1e23b3418723f08d0cb14f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28de7444db876009fe4d8c7cdc93862b8d9c55a4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2c601477ebed438b9157b4ab1565e762b8aab3c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f72aa3a4654101a1005e86301fac7c7343a94b0a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..234dfcd7bfdaddf79ba44fac4b3f6752efebaae9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/sdpa.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py
new file mode 100644
index 0000000000000000000000000000000000000000..b347258b5f463789aa1425f9a8d61de1e306bee7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py
@@ -0,0 +1,116 @@
+from typing import *  # noqa: F403
+
+import torch
+from torch.fx.experimental._constant_symnode import ConstantIntNode
+
+
+__all__ = ["NestedIntNode"]
+
+
+# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp
+def _eq(lhs: Any, rhs: Any) -> bool:
+    return (
+        isinstance(lhs, NestedIntNode)
+        and isinstance(rhs, NestedIntNode)
+        and lhs.t_id == rhs.t_id
+        and lhs.coeff == rhs.coeff
+    )
+
+
+def _ge(lhs: Any, rhs: Any) -> bool:
+    if isinstance(rhs, NestedIntNode) and isinstance(lhs, NestedIntNode):
+        if lhs.t_id == rhs.t_id:
+            return lhs.coeff >= rhs.coeff
+        raise ValueError("ge: relation is indeterminate")
+    elif isinstance(lhs, NestedIntNode):
+        if rhs.is_constant() and rhs.constant_int() <= 2:
+            return True
+        raise ValueError("ge: relation is indeterminate")
+    elif isinstance(rhs, NestedIntNode):
+        if lhs.is_constant() and lhs.constant_int() < 2:
+            return False
+        raise ValueError("ge: relation is indeterminate")
+    else:
+        raise ValueError("inputs unsupported")
+
+
+class NestedIntNode:
+    def __init__(self, t_id: int, coeff: int) -> None:
+        self.t_id = t_id
+        self.coeff = coeff
+
+    def nested_int_coeff(self) -> int:
+        return self.coeff
+
+    def maybe_as_int(self) -> Optional[int]:
+        return None
+
+    def is_int(self) -> bool:
+        return True
+
+    def is_float(self) -> bool:
+        return False
+
+    def is_bool(self) -> bool:
+        return False
+
+    def is_nested_int(self) -> bool:
+        return True
+
+    def clone(self) -> "NestedIntNode":
+        return self
+
+    def _str(self) -> Any:
+        if self.coeff == 1:
+            return f"j{self.t_id}"
+        return f"{self.coeff}*j{self.t_id}"
+
+    def str(self) -> Any:
+        return self._str()
+
+    def __str__(self) -> Any:
+        return self._str()
+
+    def __repr__(self) -> Any:
+        return self._str()
+
+    def _graph_repr(self) -> Any:
+        return self._str()
+
+    def mul(self, other: Any) -> "NestedIntNode":
+        if other.is_constant():
+            other = other.constant_int()
+        else:
+            raise ValueError(f"unsupported: {type(other)}")
+        return NestedIntNode(self.t_id, self.coeff * other)
+
+    def eq(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(_eq(self, other))
+
+    def ne(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(not _eq(self, other))
+
+    def gt(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(not _ge(other, self))
+
+    def lt(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(not _ge(self, other))
+
+    def le(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(_ge(other, self))
+
+    def ge(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(_ge(self, other))
+
+    def is_symbolic(self) -> bool:
+        return False
+
+    def nested_int(self) -> int:
+        return self.t_id
+
+    def is_constant(self) -> bool:
+        return False
+
+    def wrap_int(self, num: int) -> ConstantIntNode:
+        assert type(num) is int
+        return ConstantIntNode(num)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4e3fecf4e6cc3947d07757896a5eb1e9d7935b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py
@@ -0,0 +1,676 @@
+# mypy: allow-untyped-defs
+from typing import *  # noqa: F403
+
+import torch
+from torch._C import DispatchKey, DispatchKeySet
+from torch._prims_common import is_expandable_to
+from torch.nested._internal.nested_int import NestedIntNode
+from torch.utils.weak import WeakTensorKeyDictionary
+
+
+_tensor_id_counter = 0
+_tensor_symint_registry = WeakTensorKeyDictionary()
+
+
+def get_tensor_symint(tensor, *, coeff=1):
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
+
+    # NB: Only FakeTensor is associated with a memo
+    tensor = mb_unwrap_functional_tensor(tensor)
+    if isinstance(tensor, FakeTensor):
+        return tensor.get_nested_int(coeff=coeff)
+
+    global _tensor_id_counter
+
+    tensor_symint = _tensor_symint_registry.get(tensor)
+    if tensor_symint is None:
+        tensor_symint = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff))
+        _tensor_id_counter += 1
+        _tensor_symint_registry[tensor] = tensor_symint
+    return tensor_symint
+
+
+# SDPA metadata; max / min seqlens are needed for e.g. flash
+def _get_sdpa_extreme_seqlen(func, tensor):
+    return int(func(tensor).item())
+
+
+def _store_val_in_tensor(val) -> torch.Tensor:
+    # hack to get dynamic shapes support: store in a (val, 0) shaped tensor
+    return torch.zeros(val, 0)
+
+
+def _load_val_from_tensor(t: torch.Tensor):
+    return t.shape[0]
+
+
+# serialization function must be defined at top level
+def _rebuild_njt(constructor_kwargs):
+    return NestedTensor(**constructor_kwargs)
+
+
+class NestedTensor(torch.Tensor):
+    _values: torch.Tensor  # type: ignore[assignment]
+    _offsets: torch.Tensor
+    _lengths: Optional[torch.Tensor]
+    # NOTE [ Nested ints for ragged sizes and strides ]
+    #
+    # Jagged layout tensors are tensors that represent a n-dim tensor with a
+    # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
+    # a jagged tensor with outer shape [B, x, D] is represented internally by a
+    # tensor with shape [sum(x), D] where we introduce what we call a nested int
+    # denoted as "x" here (but sometimes denoted with "*" to
+    # represent the ragged dimension, and sum(x) represents the dim of the inner
+    # tensor or equivalently the sum of all the sizes of the constituent
+    # tensors' varying lengths.
+    #
+    # We also use nested ints to represent the strides of this tensor.
+    # For example, a jagged tensor with shape [B, x, D] can be strided in two
+    # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
+    _size: tuple[int, ...]
+    _strides: tuple[int, ...]
+    # Indicates that the nth dimension is ragged
+    _ragged_idx: int
+    _metadata_cache: Dict[str, Any]
+
+    @staticmethod
+    def __new__(
+        cls,
+        values,
+        offsets,
+        *,
+        lengths=None,
+        **kwargs,
+    ):
+        ks = DispatchKeySet(DispatchKey.NestedTensor)
+        ks = ks.add(DispatchKey.AutogradNestedTensor)
+
+        # Only support jagged for now.
+        assert offsets is not None
+        assert offsets.ndim == 1
+        assert not isinstance(values, NestedTensor)
+        assert values.device == offsets.device
+
+        # Query cache for the symint associated with offsets or lengths
+        # (create a new one if needed).
+        ragged_source = offsets if lengths is None else lengths
+        ragged_size = get_tensor_symint(ragged_source, coeff=1)
+        _ragged_idx = kwargs.get("_ragged_idx", 1)
+        B = offsets.shape[0] - 1
+        if lengths is not None:
+            assert B == lengths.shape[0]
+
+        # subtract 1 to convert to values dim space
+        r = _ragged_idx - 1
+        _size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
+        stride = values.stride()
+        _strides = (ragged_size * stride[r], *stride)
+
+        r = torch.Tensor._make_wrapper_subclass(
+            cls,
+            _size,
+            _strides,
+            0,
+            torch.contiguous_format,
+            values.dtype,
+            torch.jagged,
+            values.device,
+            False,
+            kwargs.get("requires_grad", False),
+            "sizes",
+            False,
+            True,  # dispatch_layout
+            ks,
+            # don't try to calculate storage based on non-zero size
+            storage_size=values.untyped_storage().size(),
+        )
+        r._ragged_idx = _ragged_idx
+        r._size = _size
+        r._strides = _strides
+
+        return r
+
+    def __init__(self, values, offsets, *, lengths=None, **kwargs) -> None:
+        super().__init__()
+
+        self._values = values
+        self._offsets = offsets
+        self._lengths = lengths
+
+        # holds properties that are computed lazily
+        self._metadata_cache = kwargs.get("_metadata_cache") or {}
+
+        # collapsed ragged dim must always be dynamic
+        torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)
+
+        # min / max sequence length should be dynamic if present
+        max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
+        if max_seqlen_tensor is not None:
+            torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
+        min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
+        if min_seqlen_tensor is not None:
+            torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)
+
+    def values(self):
+        # dispatch to get proper view relationship
+        return torch._nested_get_values(self)  # type: ignore[attr-defined]
+
+    def offsets(self):
+        return self._offsets
+
+    def lengths(self):
+        return self._lengths
+
+    # Private accessor functions for min / max sequence length. They're
+    # purposefully not @properties because those don't work with PT2 (yet).
+    # These compute / cache if not present.
+    # TODO: Revisit this when @properties are better supported by PT2. I think the ideal
+    # state would be to have public @properties for min / max sequence length that compile
+    # (including setters).
+    def _get_max_seqlen(self):
+        max_seqlen_tensor = self._max_seqlen_tensor
+        if max_seqlen_tensor is None:
+            # compute & cache
+            max_val = _get_sdpa_extreme_seqlen(
+                torch.max,
+                self._offsets.diff() if self._lengths is None else self._lengths,
+            )
+            max_seqlen_tensor = _store_val_in_tensor(max_val)
+            self._metadata_cache["max_seqlen"] = max_seqlen_tensor
+        return _load_val_from_tensor(max_seqlen_tensor)
+
+    def _get_min_seqlen(self):
+        min_seqlen_tensor = self._min_seqlen_tensor
+        if min_seqlen_tensor is None:
+            # compute & cache
+            min_val = _get_sdpa_extreme_seqlen(
+                torch.min,
+                self._offsets.diff() if self._lengths is None else self._lengths,
+            )
+            min_seqlen_tensor = _store_val_in_tensor(min_val)
+            self._metadata_cache["min_seqlen"] = min_seqlen_tensor
+        return _load_val_from_tensor(min_seqlen_tensor)
+
+    # Private accessors used for treating min / max seqlen as inner tensors for
+    # flatten / unflatten. These must be properties to work with the traceable wrapper
+    # subclass logic. These do not compute / cache if not present.
+    @property
+    def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
+        return self._metadata_cache.get("max_seqlen", None)
+
+    @_max_seqlen_tensor.setter
+    def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+        self._metadata_cache["max_seqlen"] = val
+
+    @property
+    def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
+        return self._metadata_cache.get("min_seqlen", None)
+
+    @_min_seqlen_tensor.setter
+    def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+        self._metadata_cache["min_seqlen"] = val
+
+    # These are old private @property accessors that are kept around for internal BC
+    # reasons. TODO: Remove these!
+    @property
+    def _max_seqlen(self):
+        return self._get_max_seqlen()
+
+    @property
+    def _min_seqlen(self):
+        return self._get_min_seqlen()
+
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    def _is_contiguous_or_false(self):
+        if self.lengths() is not None:
+            return False
+        from torch._prims_common import is_contiguous_for_memory_format_or_false
+
+        return is_contiguous_for_memory_format_or_false(
+            self._values, memory_format=torch.contiguous_format
+        )
+
+    def __repr__(self) -> str:  # type: ignore[override]
+        # We should implement this in torch/_tensor_str.py instead
+        grad_fn_str = (
+            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
+        )
+
+        if self.grad_fn:
+            grad_fn_str = f", grad_fn={self.grad_fn}"
+
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
+
+    # TODO: Remove this in favor of the default tensor subclass serialization logic.
+    # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
+    def __reduce_ex__(self, proto):
+        state = torch._utils._get_obj_state(self)
+
+        # Cached PyCapsules for sizes / strides are not serializable.
+        # See Note [Tensor Subclass custom size/stride caching strategy]
+        self._clear_non_serializable_cached_data()
+        # SymNodes are not serializable
+        assert "_size" in state and "_strides" in state
+        state = dict(state)
+        del state["_size"]
+        del state["_strides"]
+
+        func = _rebuild_njt
+        constructor_kwargs = {
+            "values": self._values,
+            "offsets": self._offsets,
+            "lengths": self._lengths,
+            "_ragged_idx": self._ragged_idx,
+            "_metadata_cache": self._metadata_cache,
+            "requires_grad": self.requires_grad,
+        }
+        args = (constructor_kwargs,)
+        return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))
+
+    def __tensor_flatten__(self):
+        ctx = {
+            "requires_grad": self.requires_grad,
+            "ragged_idx": self._ragged_idx,
+        }
+        inner_tensors = ["_values", "_offsets"]
+        if self._lengths is not None:
+            inner_tensors.append("_lengths")
+        if self._min_seqlen_tensor is not None:
+            inner_tensors.append("_min_seqlen_tensor")
+        if self._max_seqlen_tensor is not None:
+            inner_tensors.append("_max_seqlen_tensor")
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
+        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
+        values = inner_tensors["_values"]
+        offsets = inner_tensors["_offsets"]
+        lengths = inner_tensors.get("_lengths", None)
+        min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
+        max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)
+
+        metadata_cache = {}
+        if min_seqlen_tensor is not None:
+            metadata_cache["min_seqlen"] = min_seqlen_tensor
+        if max_seqlen_tensor is not None:
+            metadata_cache["max_seqlen"] = max_seqlen_tensor
+        ragged_idx = meta["ragged_idx"]
+
+        # Alternatively, we could make it the caller's responsibility to
+        # cache it. But this heuristic seems simple enough.
+        ragged_source = offsets if lengths is None else lengths
+        if isinstance(ragged_source, FakeTensor):
+            ragged_size = outer_size[ragged_idx]
+            ragged_source.nested_int_memo = ragged_size
+
+        return NestedTensor(
+            values,
+            offsets=offsets,
+            lengths=lengths,
+            requires_grad=meta["requires_grad"],
+            _ragged_idx=ragged_idx,
+            _metadata_cache=metadata_cache,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):  # type: ignore[override]
+        # If you're wondering why there's a nested tensor with one of its
+        # size = -1, see note: [NJT outer_size in AOTDispatcher]
+        kwargs = {} if kwargs is None else kwargs
+
+        # Lazy import to avoid circular dependency
+        from .ops import lookup_jagged
+
+        fn = lookup_jagged(func, *args, **kwargs)
+        if fn is not None:
+            return fn(*args, **kwargs)
+
+        # Poor man's redispatch for composite ops. This becomes relevant under inference
+        # mode, where disabling autograd key dispatch prevents decomposition.
+        all_dks = (
+            # We want to handle both the cases where NestedTensor overrides the
+            # composite implicit autograd kernel, and the case where it doesn't.
+            # Prioritize calling into NestedTensor's kernel if it exists.
+            torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
+            torch._C.DispatchKey.CompositeImplicitAutograd,
+        )
+        for dk in all_dks:
+            if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
+                with torch.overrides.enable_reentrant_dispatch():
+                    return func._op_dk(dk, *args, **kwargs)
+
+        raise NotImplementedError(func)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
+
+        from .ops import jagged_torch_function
+
+        # This should be removed after
+        # https://github.com/pytorch/pytorch/pull/125941/ lands
+        with maybe_enable_thunkify():
+            try:
+                return jagged_torch_function(func, *args, **kwargs)
+            except NotImplementedError:
+                pass
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+
+
+# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
+# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
+# internal BC period has passed.
+
+
+# Not actually a view!
+class ViewBufferFromNested(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: NestedTensor):  # type: ignore[override]
+        ctx.save_for_backward(x.offsets())
+        ctx.metadata_cache = x._metadata_cache
+        ctx.ragged_idx = x._ragged_idx
+        return x._values
+
+    @staticmethod
+    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
+        (offsets,) = ctx.saved_tensors
+        return NestedTensor(
+            gO,
+            offsets=offsets,
+            _metadata_cache=ctx.metadata_cache,
+            _ragged_idx=ctx.ragged_idx,
+        )
+
+
+# Not actually a view!
+class ViewNestedFromBuffer(torch.autograd.Function):
+    @staticmethod
+    def forward(  # pyrefly: ignore  # bad-override
+        ctx,
+        values: torch.Tensor,
+        offsets: torch.Tensor,
+        metadata_cache: Optional[Dict[str, Any]] = None,
+    ):  # type: ignore[override]
+        # maintain BC with this usages of this where the seqlens are stuffed
+        # directly into the metadata cache as non-Tensors / ints
+        if metadata_cache is not None:
+            min_seqlen = metadata_cache.get("min_seqlen", None)
+            max_seqlen = metadata_cache.get("max_seqlen", None)
+            if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor):
+                metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen)
+            if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor):
+                metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen)
+        return NestedTensor(
+            values.detach(),
+            offsets=offsets,
+            _metadata_cache=metadata_cache,
+        )
+
+    @staticmethod
+    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
+        return gO._values, None, None
+
+
+def buffer_from_jagged(jagged):
+    return ViewBufferFromNested.apply(jagged)
+
+
+# Need to make it obvious that users should be passing in offsets
+def jagged_from_list(
+    tensors: List[torch.Tensor],
+    offsets: Optional[torch.Tensor],
+    dtype=None,
+    device=None,
+) -> tuple[NestedTensor, torch.Tensor]:
+    """Constructs a NestedTensor backed by jagged layout from a list of tensors"""
+
+    if len(tensors) == 0:
+        raise RuntimeError("Cannot construct a nested tensor from an empty tensor list")
+    if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must have the same dtype"
+        )
+    if not len(set(t.device for t in tensors)) == 1:  # noqa: C401
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must be on the same device"
+        )
+    if not len(set(t.dim() for t in tensors)) == 1:  # noqa: C401
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must have the same dim"
+        )
+    component_dim = tensors[0].dim()
+    if component_dim == 0:
+        raise RuntimeError(
+            "Cannot construct a nested tensor from a list of zero-dim tensors"
+        )
+
+    # Check that the NT is representable by the jagged layout, which
+    # allows for a single ragged dimension after the batch dim.
+    # e.g. (B, *, D_0, ..., D_N), (B, D_0, *, ..., D_N), etc.
+    sizes = [t.shape for t in tensors]
+    ragged_idx = None
+    for d in range(component_dim):
+        dim_is_ragged = any(size[d] != sizes[0][d] for size in sizes)
+        if dim_is_ragged:
+            if ragged_idx is None:
+                # add 1 to convert to outer NJT dim space
+                ragged_idx = d + 1
+            else:
+                raise RuntimeError(
+                    "Cannot represent given tensor list as a nested tensor with the jagged layout. "
+                    "Note that the jagged layout only allows for a single ragged dimension. "
+                    "For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim."
+                )
+
+    # allow for a rectangular NJT and default the ragged dim next to the batch dim
+    if ragged_idx is None:
+        ragged_idx = 1
+
+    # Set properties appropriately.
+    values = torch.cat(tensors, dim=(ragged_idx - 1))
+    to_kwargs = {}
+    if device is not None:
+        to_kwargs["device"] = device
+    if dtype is not None:
+        to_kwargs["dtype"] = dtype
+    values = values.to(**to_kwargs)
+
+    # Calculate jagged offsets if not provided.
+    if offsets is None:
+        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+        # TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
+        # an extra leaf tensor during the forward, potentially resolving compatibility issues.
+        offsets = torch.cat(
+            [
+                torch.zeros(1, dtype=torch.int64, device=values.device),
+                torch.tensor(
+                    [s[ragged_idx - 1] for s in sizes], device=values.device
+                ).cumsum(dim=0),
+            ]
+        )
+
+    # compute this now since it's easy
+    min_seqlen = min(t.shape[ragged_idx - 1] for t in tensors)
+    max_seqlen = max(t.shape[ragged_idx - 1] for t in tensors)
+    ret_nt = nested_view_from_values_offsets(
+        values,
+        offsets,
+        min_seqlen=min_seqlen,
+        max_seqlen=max_seqlen,
+        ragged_idx=ragged_idx,
+    )
+    return (ret_nt, offsets)  # type: ignore[return-value]
+
+
+def jagged_from_tensor_and_lengths(
+    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
+) -> tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
+    batch_size = tensor.shape[0]
+    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
+        lengths.shape, (batch_size,)
+    ):
+        start_list = starts.expand(batch_size)
+        length_list = lengths.expand(batch_size)
+    else:
+        raise RuntimeError(
+            "When constructing a jagged nested tensor using narrow(), "
+            "your start and length must be Tensors that broadcast to input.shape[0]"
+        )
+
+    # Calculate jagged offsets
+    assert len(tensor.shape) >= 2, (
+        "tensor must at least be 2D for the nested narrow op to work"
+    )
+    max_seq_len = tensor.shape[1]
+    offset_lengths = max_seq_len * torch.arange(
+        0, batch_size, dtype=torch.int64, device=tensor.device
+    )
+    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+    offsets = torch.cat(
+        [
+            start_list + offset_lengths,
+            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
+        ]
+    )
+
+    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
+    if len(tensor.shape) > 2:
+        values = tensor.view(-1, *tensor.shape[2:])
+    else:
+        values = tensor.view(-1)
+
+    # Check if offsets and lengths make it possibly contiguous and return a regular NT
+    is_contiguous = True
+    orig_dim = tensor.shape[1]
+    if torch.any(length_list[1:-1].ne(orig_dim)):
+        is_contiguous = False
+    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
+        is_contiguous = False
+    if offsets[0] + length_list[0] != orig_dim:
+        is_contiguous = False
+
+    actual_max_seqlen = int(torch.max(lengths).item())
+    min_seqlen = int(torch.min(lengths).item())
+
+    if is_contiguous:
+        ret_nt = nested_view_from_values_offsets(
+            values[offsets[0] : offsets[-1]],
+            offsets - offsets[0],
+            min_seqlen=min_seqlen,
+            max_seqlen=actual_max_seqlen,
+        )
+    else:
+        ret_nt = nested_view_from_values_offsets_lengths(
+            values,
+            offsets,
+            length_list,
+            min_seqlen=min_seqlen,
+            max_seqlen=actual_max_seqlen,
+        )
+
+    return (ret_nt, offsets, None if is_contiguous else length_list)
+
+
+# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
+# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
+# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
+# This arg is otherwise unused.
+_dummy_instance: Optional[torch.Tensor] = None
+
+
+def _nt_view_dummy() -> torch.Tensor:
+    global _dummy_instance
+    if _dummy_instance is None:
+        _dummy_instance = NestedTensor(
+            values=torch.zeros(3, 3, device="meta"),
+            offsets=torch.zeros(3, device="meta", dtype=torch.int64),
+        ).detach()
+    return _dummy_instance
+
+
+def nested_view_from_values_offsets(
+    values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
+):
+    min_seqlen_tensor = None
+    if min_seqlen is not None:
+        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
+
+    max_seqlen_tensor = None
+    if max_seqlen is not None:
+        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
+
+    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
+        values,
+        offsets,
+        _nt_view_dummy(),
+        None,
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+    )  # type: ignore[return-value]
+
+
+def nested_view_from_values_offsets_lengths(
+    values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
+):
+    min_seqlen_tensor = None
+    if min_seqlen is not None:
+        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
+
+    max_seqlen_tensor = None
+    if max_seqlen is not None:
+        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
+
+    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
+        values,
+        offsets,
+        _nt_view_dummy(),
+        lengths,
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+    )  # type: ignore[return-value]
+
+
+def nested_from_padded(
+    padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None
+):
+    min_seqlen_tensor = None
+    if min_seqlen is not None:
+        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
+
+    max_seqlen_tensor = None
+    if max_seqlen is not None:
+        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
+
+    return torch._nested_from_padded_tensor(
+        padded,
+        offsets,
+        _nt_view_dummy(),
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+        sum_S,
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..200ccd653f6c3b4e9eeca8c28468c362cae93e86
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/ops.py
@@ -0,0 +1,2748 @@
+# mypy: allow-untyped-defs
+import functools
+import math
+import operator
+from typing import *  # noqa: F403
+
+import torch
+import torch.nn.functional as F
+from torch.fx.operator_schemas import normalize_function
+from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
+
+from .nested_tensor import NestedTensor
+
+
+__all__: list[Any] = []
+
+JAGGED_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def _get_padding_value(dtype, padding_type):
+    if dtype.is_floating_point:
+        return (
+            torch.finfo(dtype).max if padding_type == "max" else torch.finfo(dtype).min
+        )
+    elif dtype == torch.int64:
+        # Largest int64 value exactly representable in float64 (IEEE 754 double precision).
+        # Avoids overflow when padding_value is passed as double to _jagged_to_padded_dense_forward.
+        int64_safe_max = (1 << 53) - 1
+        int64_safe_min = -int64_safe_max
+        return int64_safe_max if padding_type == "max" else int64_safe_min
+    else:
+        return (
+            torch.iinfo(dtype).max if padding_type == "max" else torch.iinfo(dtype).min
+        )
+
+
+def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
+    from torch._prims_common import canonicalize_dims
+
+    if isinstance(dim, (tuple, list)):
+        output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
+        # ensure no duplicates, which can result from both batch and ragged mapping to 0
+        return type(output)(dict.fromkeys(output))
+
+    if canonicalize:
+        dim = canonicalize_dims(ndim, dim)
+
+    assert dim >= 0 and dim < ndim  # pyrefly: ignore [unsupported-operation]
+
+    # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
+    # For other dims, subtract 1 to convert to inner space.
+    return (
+        # pyrefly: ignore [unsupported-operation]
+        ragged_dim - 1 if dim == 0 else dim - 1
+    )
+
+
+def _wrap_jagged_dim(
+    ndim,
+    dim,
+    ragged_dim,
+    op_name,
+    convert_to_inner_dim=True,
+    allow_ragged_dim=False,
+    allow_batch_dim=False,
+):
+    from torch._prims_common import canonicalize_dims
+
+    wrapped = canonicalize_dims(ndim, dim)
+    if wrapped == ragged_dim and not allow_ragged_dim:
+        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
+    elif wrapped == 0 and not allow_batch_dim:
+        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
+    ret = (
+        _outer_to_inner_dim(ndim, wrapped, ragged_dim)
+        if convert_to_inner_dim
+        else wrapped
+    )
+    if allow_batch_dim:
+        # Need to disambiguate whether we're operating on the batch dim or not.
+        # Operating on dim=1 -> dim=0 after the inner dim conversion.
+        operating_on_batch = wrapped == 0
+        return (ret, operating_on_batch)
+    return ret
+
+
+def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
+    """
+    For NestedTensor operators,
+    wraps dimensions to non-negative values,
+    and returns metadata related to reduction dimension(s).
+    """
+    from torch._prims_common import canonicalize_dims
+
+    assert isinstance(dims, (tuple, list)), (
+        f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
+    )
+
+    wrapped_dims = [
+        canonicalize_dims(ndim, d) for d in dims
+    ]  # convert all indices to non-negative values
+
+    operate_on_batch = 0 in wrapped_dims
+    operate_on_ragged = ragged_idx in wrapped_dims
+    operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
+
+    # ensure no duplicates, which can result from both batch and ragged mapping to 0
+    outer_to_inner_dim = tuple(
+        dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
+    )
+
+    return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
+
+
+def check_schema(schema_str: str, func, *args, **kwargs) -> None:
+    named_arg_types = schema_str.split(", ")
+    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
+    min_args = len(named_arg_types) - num_optional_args
+
+    # special case: ellipses allows for any number of unchecked args at the end
+    if named_arg_types[-1] == "...":
+        named_arg_types = named_arg_types[:-1]
+    else:
+        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
+            raise ValueError(
+                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
+                f"arguments and at most {len(named_arg_types)} arguments, but got: "
+                f"{len(args)} arguments"
+            )
+
+    arg_type_check_fns = {
+        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
+        "jt": lambda x: isinstance(x, NestedTensor)
+        and x._lengths is None
+        and x._ragged_idx == 1,  # ops with "jt" require contiguous JT only
+        "jt_all": lambda x: isinstance(
+            x, NestedTensor
+        ),  # ops with "jt_all" can accept all kinds of JT
+        "any": lambda x: True,
+    }
+    for i, named_arg_type in enumerate(named_arg_types):
+        name, arg_type = named_arg_type.split(": ")
+        is_optional = arg_type.endswith("?")
+        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
+        if normalized_arg_type not in arg_type_check_fns:
+            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
+
+        if i >= len(args):
+            if not is_optional:
+                raise ValueError(
+                    f"NestedTensor {func.__name__}({schema_str}) "
+                    f"missing required argument: {name}"
+                )
+            continue
+
+        _check_fn = arg_type_check_fns[normalized_arg_type]
+
+        def check_fn(x, is_optional=is_optional):
+            if is_optional:
+                return x is None or _check_fn(x)
+            else:
+                return _check_fn(x)
+
+        if not check_fn(args[i]):
+            type_to_desc = {
+                "t": "tensor",
+                "t?": "optional tensor",
+                "jt": "contiguous jagged layout NestedTensor",
+                "jt_all": "jagged layout NestedTensor",
+                "any": "",
+            }
+
+            raise ValueError(
+                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
+                f"{type_to_desc[arg_type]}"
+            )
+
+
+def check_ragged_dim_same(
+    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
+) -> None:
+    # Calling into .shape here
+    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
+        raise RuntimeError(
+            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
+            "same exact offsets tensor."
+        )
+
+
+# returns True if the raggedness-relevant portions of the NT shape
+# match those of the specified size
+def raggedness_matches(nt, size):
+    end = nt._ragged_idx + 1
+    nt_ragged = nt._size[:end]
+    size_ragged = size[:end]
+    return len(nt_ragged) == len(size_ragged) and (
+        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
+    )
+
+
+def squeeze_leading_ones(t):
+    # Note: [ Squeezing leading ones ]
+    #
+    # Squeeze leading ones from t.
+    #
+    # We want:
+    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
+    #
+    # 1) Squeeze extra ones and grab values from NT
+    #   (1, 1, ?, ?) -> (?, ?)   and   (sum(*), ?, ?) -> (B, j0, ?, ?)
+    # 2) Do dense broadcasting:
+    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
+    # 3) Construct nested tensor
+    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
+    #
+    # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
+    # at step (4) and we would need to update this function to record how
+    # many ones we unsqueezed.
+    while t.dim() > 0 and t.shape[0] == 1:
+        t = t.squeeze(0)
+    return t
+
+
+def register_func(tables, aten_ops, schema_str):
+    if not isinstance(aten_ops, list):
+        aten_ops = [aten_ops]
+    if not isinstance(tables, list):
+        tables = [tables]
+
+    def wrapper(func):
+        for aten_op in aten_ops:
+
+            def get_inner(aten_op):
+                def inner(*args, **kwargs):
+                    check_schema(schema_str, func, *args, **kwargs)
+                    return func(aten_op, *args, **kwargs)
+
+                return inner
+
+            for table in tables:
+                table[aten_op] = get_inner(aten_op)
+        return func
+
+    return wrapper
+
+
+register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
+
+
+def lookup_jagged(func, *args, **kwargs) -> Callable | None:
+    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
+    if dispatch_func is not None:
+        return dispatch_func
+
+    # Handle pointwise fallbacks
+    if torch.Tag.pointwise in func.tags:
+        from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+        # No pointwise ops legitimately accept nested int inputs. Without this check,
+        # they will be incorrectly interpreted as tensors.
+        # See https://github.com/pytorch/pytorch/issues/138496
+        for arg in args:
+            if is_nested_int(arg):
+                raise RuntimeError(
+                    f"NestedTensor {func.__name__}: invalid argument {arg}"
+                )
+
+        # Assume there aren't additional tensors that aren't the "unary/binary" args
+        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
+        if num_tensor_args == 1:
+            # Build up the check schema string. The first tensor arg is assumed to be
+            # an NJT and other args are sent through as-is.
+            schema_parts = []
+            for arg in func._schema.arguments:
+                if isinstance(arg.type, torch.TensorType):
+                    schema_parts.append(f"{arg.name}: jt_all")
+                    break
+                else:
+                    schema_parts.append(f"{arg.name}: any")
+            schema_parts.append("...")
+            check_schema_str = ", ".join(schema_parts)
+            check_schema(check_schema_str, func, *args, **kwargs)
+            return functools.partial(jagged_unary_pointwise, func)
+        elif num_tensor_args == 2:
+            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
+            return functools.partial(jagged_binary_pointwise, func)
+
+    return None
+
+
+def extract_kwargs(arg):
+    kwargs = {
+        "offsets": arg.offsets(),
+        "lengths": arg.lengths(),
+        "_metadata_cache": arg._metadata_cache,
+        "_ragged_idx": arg._ragged_idx,
+    }
+    return kwargs
+
+
+def jagged_unary_pointwise(func, *args, **kwargs):
+    # assume if we get here that there is a single NJT input in the args
+    njt = next(arg for arg in args if isinstance(arg, NestedTensor))
+    return NestedTensor(
+        func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
+        **extract_kwargs(njt),
+    )
+
+
+def jagged_binary_pointwise(func, *args, **kwargs):
+    a, b = args[0], args[1]
+    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)
+
+    mismatch_error_msg = (
+        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
+    )
+    # a is NT, b is NT
+    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
+        # ex: (B, j0, D) + (B, j0, D)
+        # ex: (B, j0, D) + (B, j0, 1)
+        if raggedness_matches(a, b._size):
+            return NestedTensor(
+                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
+            )
+        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
+    # either a is NT or b is NT at this point
+    a_is_nt = isinstance(a, NestedTensor)
+    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)
+
+    # === Handle broadcasting across the batch / ragged dims ===
+
+    # Easy case: take advantage of pre-existing broadcasting logic
+    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    nt, t = (a, b) if a_is_nt else (b, a)
+    # See Note: [ Squeezing leading ones ]
+    if t.dim() > nt.dim():
+        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
+    t_squeezed = squeeze_leading_ones(t)
+    if nt.dim() >= t_squeezed.dim() + 2:
+        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
+        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
+
+    # Harder case: do manual broadcasting when NT dim == non-NT dim
+    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
+    if a.dim() == b.dim():
+        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
+        # be (B, j0, D_0, D_1) but not yet supported
+        if a.shape[0] != b.shape[0]:
+            raise RuntimeError(
+                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
+            )
+
+        from .nested_tensor import nested_from_padded
+
+        # handle broadcasting via padded dense -> jagged conversion
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = nt._values.shape[nt._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        # convert dense tensor -> jagged
+        t = t.expand(
+            [x if i != nt._ragged_idx else padded_max_S for i, x in enumerate(t.shape)]
+        )
+        t_as_nt = nested_from_padded(
+            t,
+            offsets=nt._offsets,
+            ragged_idx=nt._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        # function call with two NJTs
+        lhs, rhs = (nt, t_as_nt) if a_is_nt else (t_as_nt, nt)
+        return func(lhs, rhs, *args[2:], **kwargs)
+
+    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant
+    # that ragged dim is wrt left-most batch dim
+    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
+
+
+def jagged_torch_function(func, *args, **kwargs):
+    # SDPA has special kernels that handle nested tensors.
+    # Dispatch to the correct implementation here
+    if func is torch._C._nn.scaled_dot_product_attention:
+        return jagged_scaled_dot_product_attention(*args, **kwargs)
+
+    if func.__name__ == "apply_":
+        func(args[0]._values, *args[1:], **kwargs)
+        return args[0]
+
+    # Handle flatten() here because it's CompositeImplicit.
+    if func.__name__ == "flatten":
+
+        def _flatten_sig(input, start_dim=0, end_dim=-1) -> None:
+            pass
+
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
+            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+        )
+
+        inp = new_kwargs.pop("input")
+
+        # NB: stay in outer dim space because we're going to redispatch on a NT input
+        start_dim = _wrap_jagged_dim(
+            inp.dim(),
+            new_kwargs["start_dim"],
+            inp._ragged_idx,
+            "flatten",
+            convert_to_inner_dim=False,
+        )
+        end_dim = _wrap_jagged_dim(
+            inp.dim(),
+            new_kwargs["end_dim"],
+            inp._ragged_idx,
+            "flatten",
+            convert_to_inner_dim=False,
+        )
+
+        if start_dim == end_dim:
+            return inp
+
+        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
+        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])
+
+        return inp.reshape(*new_shape)
+
+    # Handle NestedTensor share_memory_.
+    if func.__name__ == "share_memory_":
+        nt = args[0]
+
+        if nt.is_cuda:
+            return nt
+
+        names, _ = nt.__tensor_flatten__()
+        with torch._C.DisableTorchFunctionSubclass():
+            for name in names:
+                component = getattr(nt, name, None)
+                if component is not None:
+                    component.share_memory_()
+        return nt
+
+    # Handle NestedTensor is_shared.
+    if func.__name__ == "is_shared":
+        nt = args[0]
+
+        if nt.is_cuda:
+            return False
+
+        names, _ = nt.__tensor_flatten__()
+        if not names:
+            return False
+        return all(
+            getattr(nt, name) is not None and getattr(nt, name).is_shared()
+            for name in names
+        )
+
+    # Handle nested-specific input validation for CompositeImplicit rms_norm
+    if func.__name__ == "rms_norm":
+
+        def _rms_norm_sig(input, normalized_shape, weight=None, eps=None) -> None:
+            pass
+
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
+            _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+        )
+
+        inp = new_kwargs.pop("input")
+        normalized_shape = new_kwargs.pop("normalized_shape")
+
+        # can't normalize over the ragged dim (yet)
+        max_normalizable = inp.dim() - inp._ragged_idx - 1
+        if len(normalized_shape) > max_normalizable:
+            raise ValueError(
+                "rms_norm(): Normalization over the ragged dim not supported for nested tensors"
+            )
+
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
+    raise NotImplementedError(func)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.is_non_overlapping_and_dense.default,
+        torch.ops.aten.sym_size.default,
+        torch.ops.aten.dim.default,
+        torch.ops.aten.numel.default,
+        torch.ops.aten.sym_numel.default,
+        torch.ops.aten.sym_stride.default,
+        torch.ops.aten.sym_storage_offset.default,
+    ],
+    "self: jt_all",
+)
+def tensor_attr_supported_getter(func, *args, **kwargs):
+    if func is torch.ops.aten.is_non_overlapping_and_dense.default:
+        return False
+
+    if func is torch.ops.aten.sym_size.default:
+        return args[0]._size
+
+    if func is torch.ops.aten.dim.default:
+        return len(args[0]._size)
+
+    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
+        if args[0]._lengths is not None:
+            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
+        return args[0]._values.numel()
+
+    if func is torch.ops.aten.sym_stride.default:
+        return args[0]._strides
+
+    if func is torch.ops.aten.sym_storage_offset.default:
+        return args[0]._values.storage_offset()
+
+
+@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
+def prim_layout_default(func, *args, **kwargs):
+    return torch.jagged
+
+
+@register_jagged_func(
+    [torch.ops.aten.size.default],
+    "self: jt_all",
+)
+def tensor_attr_unsupported_getter(func, *args, **kwargs) -> None:
+    if func is torch.ops.aten.size.default:
+        raise RuntimeError(
+            "NestedTensor does not support directly calling torch.ops.aten.size; "
+            "please use `nested_tensor.size()` instead."
+        )
+
+
+@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
+def is_contiguous_general(func, *args, **kwargs):
+    from torch._prims_common import is_contiguous_for_memory_format
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+    return is_contiguous_for_memory_format(inp._values, **new_kwargs)
+
+
+register_jagged_func(
+    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
+)(is_contiguous_general)
+
+
+@register_jagged_func(
+    torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
+)
+def sym_is_contiguous_general(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+
+    return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
+)
+def clone_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_meta = extract_kwargs(inp)
+
+    if inp._lengths is not None:
+        if new_kwargs["memory_format"] == torch.contiguous_format:
+            # need to copy to remove "holes" non-contiguity / lengths metadata
+            # TODO: write a kernel for this
+            from .nested_tensor import jagged_from_list
+
+            # TODO: We probably want the output to have the same ragged structure / nested int.
+            assert inp._ragged_idx == 1, (
+                "NJT with ragged_idx != 1 not supported for contiguous clone"
+            )
+            contig, _ = jagged_from_list(inp.unbind(), offsets=None)
+            return contig
+
+    return NestedTensor(func(inp._values, **new_kwargs), **new_meta)
+
+
+@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
+def linear_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.linear_backward.default,
+    "self: jt, grad_output: jt, weight: t, output_mask: any",
+)
+def linear_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    grad_output = new_kwargs.pop("grad_output")
+    weight = new_kwargs.pop("weight")
+    output_mask = new_kwargs.pop("output_mask")
+
+    ds, dw, db = None, None, None
+    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
+    if output_mask[0]:
+        ds = NestedTensor(
+            torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output)
+        )
+    if output_mask[1]:
+        # NB: Fold dims of values for input and grad_output to treat them as 2D. This
+        # trick avoids materializing large intermediates and immediately reducing over
+        # them via sum(). This is equivalent to computing:
+        #     torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
+        # and then summing over the leading dimensions to get a 2D weight grad.
+        grad_2d = grad_output._values.reshape(-1, weight.size(0))
+        input_2d = inp._values.reshape(-1, weight.size(1))
+        dw = torch.matmul(grad_2d.t(), input_2d)
+    if output_mask[2]:
+        # Sum over all but the last dim to get a 1D bias grad. We cannot
+        # rely on the autograd engine to reduce for us, because returning a
+        # tensor aliasing the input would violate the aten signature annotation
+        reduce_dims = tuple(range(grad_output._values.ndim - 1))
+        if reduce_dims == ():
+            db = grad_output._values.clone()
+        else:
+            db = torch.sum(grad_output._values, reduce_dims, keepdim=False)
+    return (ds, dw, db)
+
+
+@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
+def to_dtype(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
+def to_copy_default(func, *args, **kwargs):
+    from .nested_tensor import _tensor_symint_registry
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    # don't change layout
+    new_kwargs.pop("layout")
+
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import (
+        FunctionalTensor,
+        mb_unwrap_functional_tensor,
+    )
+
+    ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+    new_thing = new_offsets if new_lengths is None else new_lengths
+    if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+        # Temporary hack until we have the union find
+        tgt = mb_unwrap_functional_tensor(new_thing)
+        src = mb_unwrap_functional_tensor(ragged_source)
+        tgt.nested_int_memo = src.nested_int_memo
+    else:
+        _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+    inp_kwargs = extract_kwargs(inp)
+    inp_kwargs["offsets"] = new_offsets
+    inp_kwargs["lengths"] = new_lengths
+
+    output = NestedTensor(new_values, **inp_kwargs)
+    return output
+
+
+@register_jagged_func(
+    torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
+)
+def copy_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    src = new_kwargs.pop("src")
+    if inp._size != src._size:
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
+
+    # AOTD allows mutations of inputs only, (not views of the inputs).
+    # NJT.values() returns _values.detach() to workaround some issues.
+    # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
+    # Here we directly mutate self._values to not emit .detach() in the graph, which would make it non-compilable.
+    inp._values.copy_(src._values)
+    return inp
+
+
+register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
+    jagged_unary_pointwise
+)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.empty_like.default,
+        torch.ops.aten.ones_like.default,
+        torch.ops.aten.zeros_like.default,
+        torch.ops.aten.rand_like.default,
+        torch.ops.aten.randn_like.default,
+    ],
+    "self: jt_all",
+)
+def like_factory_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # Default layout is technically torch.strided but only jagged is supported here.
+    # Rather than force users to specify the layout, assume jagged.
+    # This should be set to strided for redispatching on values.
+    new_kwargs["layout"] = torch.strided
+
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+    output_kwargs = extract_kwargs(inp)
+    if "offsets" in output_kwargs:
+        output_kwargs["offsets"] = new_offsets
+    if "lengths" in output_kwargs:
+        output_kwargs["lengths"] = new_lengths
+
+    if inp.device != new_values.device:
+        # Update the nested int registry to indicate that the ragged structure is the same
+        # between the two offsets / lengths on different devices.
+        from torch._subclasses.fake_tensor import FakeTensor
+        from torch._subclasses.functional_tensor import (
+            FunctionalTensor,
+            mb_unwrap_functional_tensor,
+        )
+
+        from .nested_tensor import _tensor_symint_registry
+
+        ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+        new_thing = new_offsets if new_lengths is None else new_lengths
+        if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+            # Temporary hack until we have the union find
+            tgt = mb_unwrap_functional_tensor(new_thing)
+            src = mb_unwrap_functional_tensor(ragged_source)
+            tgt.nested_int_memo = src.nested_int_memo
+        else:
+            _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+
+    return NestedTensor(new_values, **output_kwargs)
+
+
+register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
+    like_factory_default
+)
+
+register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
+    like_factory_default
+)
+
+register_jagged_func(
+    torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
+)(like_factory_default)
+
+
+@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
+def zero__default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    func(inp._values)
+    return inp
+
+
+@register_jagged_func(
+    torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
+)
+def _softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    if isinstance(new_kwargs["dim"], tuple):
+        raise RuntimeError(
+            "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
+        )
+
+    inp = new_kwargs.pop("input")
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        _reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        (new_kwargs["dim"],),
+        "softmax",
+        inp._ragged_idx,
+    )
+
+    if reduce_on_batch:
+        raise RuntimeError(
+            "softmax(): not supported when reducing across the batch dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged and inp._ragged_idx > 1:
+        raise RuntimeError(
+            "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
+        )
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            "softmax(): not supported where lengths is not None "
+            + "if reducing across the ragged dimension for NestedTensor"
+        )
+
+    new_kwargs["dim"] = new_kwargs["dim"][
+        0
+    ]  # torch.softmax takes in the reduction dimension as an integer
+
+    if reduce_on_ragged:
+        padded_softmax_values = torch.nn.functional.softmax(
+            torch.ops.aten._jagged_to_padded_dense_forward(
+                inp._values.reshape(
+                    inp._values.shape[0], -1
+                ),  # values are required to be 2D tensors for j2pd
+                [inp._offsets],
+                max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+                padding_value=float("-inf"),  # e^-inf = 0
+            ),
+            dim=inp._ragged_idx,
+        )
+
+        softmax_values = torch.ops.aten._padded_dense_to_jagged_forward(
+            padded_softmax_values,
+            [inp._offsets],
+            total_L=inp._values.shape[
+                0
+            ],  # providing this parameter helps avoid a GPU/CPU sync
+        ).reshape(
+            -1, *inp._values.shape[1:]
+        )  # expand softmax_values back to original shape (inp._values.shape)
+
+        return NestedTensor(softmax_values, **extract_kwargs(inp))
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any"
+)
+def _log_softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    if isinstance(new_kwargs["dim"], tuple):
+        raise RuntimeError(
+            "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
+        )
+
+    inp = new_kwargs.pop("input")
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        _reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(), (new_kwargs["dim"],), "log_softmax", inp._ragged_idx
+    )
+
+    if reduce_on_batch:
+        raise RuntimeError(
+            "log_softmax(): not supported when reducing across the batch dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged:
+        raise RuntimeError(
+            "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor"
+        )
+
+    # torch.log_softmax takes in the reduction dimension as an integer
+    new_kwargs["dim"] = new_kwargs["dim"][0]
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten._softmax_backward_data.default,
+    "grad_output: jt, output: jt, dim: any, input_dtype: any",
+)
+def _softmax_backward(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_out = new_kwargs.pop("grad_output")
+    output = new_kwargs.pop("output")
+    return NestedTensor(
+        func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
+)
+def native_dropout_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    out1, out2 = func(inp._values, **new_kwargs)
+    return (
+        NestedTensor(out1, **extract_kwargs(inp)),
+        NestedTensor(out2, **extract_kwargs(inp)),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.native_dropout_backward.default,
+    "grad_output: jt, mask: jt, scale: any",
+)
+def native_dropout_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_output = new_kwargs.pop("grad_output")
+    mask = new_kwargs.pop("mask")
+    return NestedTensor(
+        func(grad_output._values, mask._values, **new_kwargs),
+        **extract_kwargs(grad_output),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.prod.dim_int,
+    "self: jt_all, dim: any, keepdim: any?, dtype: any?",
+)
+def prod_dim_int(func, *args, **kwargs):
+    return _apply_reduction(func, "prod", 1, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.prod.default, "self: jt_all, dtype: any?")
+def prod_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any?"
+)
+def split_tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split"
+    )
+
+    return tuple(
+        NestedTensor(values=x, **extract_kwargs(inp))
+        for x in func(inp._values, **new_kwargs)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any?"
+)
+def split_with_sizes_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split_with_sizes"
+    )
+
+    return [
+        NestedTensor(values=x, **extract_kwargs(inp))
+        for x in func(inp._values, **new_kwargs)
+    ]
+
+
+@register_jagged_func(
+    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
+)
+def narrow(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
+    values = func(
+        inp._values,
+        dim=dim,
+        start=new_kwargs["start"],
+        length=new_kwargs["length"],
+    )
+    return NestedTensor(values, **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
+def chunk_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True
+    )
+
+    if operating_on_batch:
+        chunks = new_kwargs["chunks"]
+
+        # get _offsets of the chunks
+        lengths = inp._offsets.diff()
+        chunked_lengths = lengths.chunk(chunks)
+        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
+        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]  # type: ignore[arg-type]
+        nested_kwargs = [
+            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
+            for per_offsets in chunked_offsets
+        ]
+
+        # get _values of the chunks
+        split_sizes = [x.sum().item() for x in chunked_lengths]
+        chunk_values = inp._values.split(split_sizes)
+
+        # Note that the actual number of chunks returned is not necessarily the same as
+        # the input number; it can be counter-intuitive, but it matches dense behavior.
+        return [
+            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
+            for i in range(len(chunk_values))
+        ]
+    else:
+        return [
+            NestedTensor(values=x, **extract_kwargs(inp))
+            for x in func(inp._values, **new_kwargs)
+        ]
+
+
+@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
+def unbind_int(func, *args, **kwargs):
+    # Note that this specializes on the length of the offsets
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dim = new_kwargs["dim"]
+    if dim != 0:
+        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")
+
+    inp = new_kwargs.pop("input")
+    values = inp.values()
+    offsets = inp.offsets()
+    lengths = inp.lengths()
+    ragged_idx = inp._ragged_idx
+
+    def _torch_check(_lengths: list[int], _offsets: list[int] | None = None) -> None:
+        # This torch._check are needed for torch.compile
+        # symbolic shapes processing.
+        # offsets and lengths are symbolic variables during compilation,
+        # we guarantee the correct offsets/lengths correspondence:
+        # sum of lengths <= total ragged_dim_size
+        # every length and offset are size-like variable (allows sym shapes to reason it as [2, inf))
+        # offset[i] + length[i] <= ragged_dim_size, for unbind and split dim correctness
+        # offsets[i] <= ragged_dim_size
+
+        lengths_sum = 0
+        ragged_dim_size = values.shape[ragged_idx - 1]
+        for i in range(len(_lengths)):
+            torch._check(_lengths[i] >= 0)
+            torch._check(_lengths[i] <= ragged_dim_size)
+
+            lengths_sum += _lengths[i]
+            if _offsets is not None:
+                torch._check(
+                    _offsets[i] + _lengths[i] <= ragged_dim_size,
+                    lambda: "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension",
+                )
+        torch._check(lengths_sum <= ragged_dim_size)
+
+        if _offsets is not None:
+            for i in range(len(_offsets)):
+                torch._check(_offsets[i] >= 0)
+                torch._check(_offsets[i] <= ragged_dim_size)
+
+    if lengths is None:
+        lengths_scalars = offsets.diff().tolist()
+        _torch_check(lengths_scalars)
+
+        return torch.split(values, lengths_scalars, dim=(ragged_idx - 1))
+
+    if ragged_idx <= 0:
+        raise RuntimeError(
+            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
+        )
+
+    lengths_scalars = lengths.tolist()
+    offsets_scalars = offsets.tolist()
+
+    _torch_check(lengths_scalars, offsets_scalars)
+
+    return [
+        torch.narrow(
+            values,
+            dim=(ragged_idx - 1),
+            start=offsets_scalars[i],
+            length=lengths_scalars[i],
+        )
+        for i in range(lengths.shape[0])
+    ]
+
+
+@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
+def squeeze_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    values = inp._values
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze"
+    )
+    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt_all, dim: any")
+def unsqueeze_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    values = inp._values
+
+    # Account for collapsed jagged dim
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(inp._size) + 1, dim, inp._ragged_idx, "unsqueeze", allow_ragged_dim=True
+    )
+
+    # ragged_idx changes if a dimension is added before it
+    output_kwargs = extract_kwargs(inp)
+    if new_kwargs["dim"] <= inp._ragged_idx - 1:
+        output_kwargs["_ragged_idx"] += 1
+
+    return NestedTensor(func(values, **new_kwargs), **output_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
+def cat_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    tensors = new_kwargs.pop("tensors")
+
+    # Convert any non-nested to nested
+    nested = [t for t in tensors if t.is_nested]
+    assert len(nested) > 0
+    first = nested[0]
+    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]
+
+    # Account for collapsed jagged dim
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(first.shape), dim, first._ragged_idx, "cat"
+    )
+
+    return NestedTensor(
+        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
+    )
+
+
+@register_jagged_func(torch.ops.aten.matmul.default, "self: any, other: any")
+def matmul_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    def _unbind_impl(a, b):
+        return [
+            func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind())
+        ]
+
+    def _padded_impl(a, b):
+        if a.is_nested:
+            nt = a
+        else:
+            nt = b
+
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = nt._values.shape[nt._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *nt.shape[: nt._ragged_idx],
+            padded_max_S,
+            *nt.shape[nt._ragged_idx + 1 :],
+        )
+        padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape)
+        if a.is_nested:
+            padded_t = func(padded_nt, b)
+        else:
+            padded_t = func(a, padded_nt)
+        return nested_from_padded(
+            padded_t,
+            offsets=nt._offsets,
+            ragged_idx=nt._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+    # TODO: Back these with proper kernels (e.g. grouped GEMM)
+    # NJT x dense
+    if inp.is_nested and not other.is_nested:
+        # (B, j1, D) x (B, D, E) => (B, j1, E)
+        if (
+            inp.dim() >= 3
+            and inp.dim() == other.dim()
+            and inp._ragged_idx < inp.dim() - 1
+        ):
+            # convert to padded for this
+            return _padded_impl(inp, other)
+        # Support broadcasting the dense:
+        # (B, j1, D) x (D, E) => (B, j1, E)
+        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
+        # etc.
+        elif (
+            other.dim() == 2
+            and inp.dim() > other.dim()
+            and inp._ragged_idx < inp.dim() - 1
+        ):
+            return NestedTensor(
+                func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
+            )
+    # Dense x NJT
+    elif not inp.is_nested and other.is_nested:
+        # (B, D, E) x (B, E, j1) => (B, E, j1)
+        if other.dim() >= 3 and other.dim() == inp.dim() and other._ragged_idx >= 2:
+            # convert to padded for this
+            return _padded_impl(inp, other)
+        # Support broadcasting the dense:
+        # (D, E) x (B, E, j1) => (B, D, j1)
+        # (D, E) x (B, E, j1, F) => (B, D, j1, F)
+        # etc.
+        elif inp.dim() == 2 and other.dim() > inp.dim() and other._ragged_idx >= 2:
+            return NestedTensor(
+                func(inp, other._values, **new_kwargs), **extract_kwargs(other)
+            )
+
+    # NJT x NJT
+    elif inp.is_nested and other.is_nested:
+        # Support ragged batch dim:
+        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc.
+        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
+            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))
+        # Support reducing over ragged with dense output:
+        # (B, D, j1) x (B, j1, E) => (B, D, E)
+        elif (
+            inp.dim() == 3
+            and other.dim() == 3
+            and inp._ragged_idx == 2
+            and other._ragged_idx == 1
+            and inp.size(inp._ragged_idx) == other.size(other._ragged_idx)
+        ):
+            # do unbind for this; can't use padded conversion due to j1 in last dim
+            return torch.stack(_unbind_impl(inp, other))
+
+    raise RuntimeError(
+        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
+    )
+
+
+@register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any")
+def bmm_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("mat2")
+
+    if inp.dim() != 3:
+        raise ValueError("bmm(): input must be 3D")
+    if other.dim() != 3:
+        raise ValueError("bmm(): mat2 must be 3D")
+
+    return matmul_default(torch.ops.aten.matmul.default, inp, other)
+
+
+@register_jagged_func(
+    torch.ops.aten.expand.default, "self: jt_all, size: any, implicit: any?"
+)
+def expand_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs["size"]
+
+    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
+    if not raggedness_matches(inp, size):
+        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")
+
+    expand_arg = [-1 if d == inp._ragged_idx else size[d] for d in range(1, inp.dim())]
+    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
+def expand_as_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    return NestedTensor(func(inp, other._values), **extract_kwargs(other))
+
+
+@register_jagged_func(torch.ops.aten.broadcast_to.default, "self: jt_all, size: any")
+def broadcast_to(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs.pop("size")
+
+    if len(size) <= inp.dim():
+        return inp.expand([*(1 for _ in range(inp.dim() - len(size))), *size])
+
+    raise ValueError(
+        "broadcast_to(): broadcasting to a higher-dim shape is currently not supported "
+        "for nested tensors with the jagged layout"
+    )
+
+
+@register_jagged_func(torch.ops.aten.broadcast_tensors.default, "tensors: any")
+def broadcast_tensors(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    tensors = new_kwargs.pop("tensors")
+    if len(tensors) == 0:
+        raise ValueError("broadcast_tensors(): expected at least one tensor input")
+    if len(tensors) == 1:
+        return tensors[0]
+
+    outs = []
+    broadcast_shape = torch.broadcast_shapes(*(t.shape for t in tensors))
+    # Pull out the first NJT. If broadcast_shapes() worked, the nested ints are compatible.
+    njt = next(t for t in tensors if isinstance(t, NestedTensor))
+    for t in tensors:
+        if t.is_nested:
+            outs.append(t.broadcast_to(broadcast_shape))
+        elif t.dim() < len(broadcast_shape):
+            outs.append(
+                NestedTensor(t.broadcast_to(njt._values.shape), **extract_kwargs(njt))
+            )
+        else:
+            raise ValueError(
+                "broadcast_tensors(): broadcasting nested tensors with dense tensors of equal "
+                "or higher dim is not currently supported"
+            )
+
+    return tuple(outs)
+
+
+@register_jagged_func(
+    torch.ops.aten.where.self, "condition: jt_all, self: any, other: any"
+)
+def where_self(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    condition = new_kwargs.pop("condition")
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    # if the tensors aren't compatible, broadcast_tensors() will let us know
+    condition, inp, other = torch.broadcast_tensors(condition, inp, other)
+
+    return NestedTensor(
+        func(condition._values, inp._values, other._values, **new_kwargs),
+        **extract_kwargs(condition),
+    )
+
+
+@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
+def _pin_memory_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
+def is_pinned_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
+)
+def is_same_size_default(func, *args, **kwargs):
+    return args[0]._size == args[1]._size
+
+
+def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # some ops use dim=None to indicate a full reduction; some use an empty dim list
+    full_reduction = new_kwargs["dim"] is None or (
+        isinstance(new_kwargs["dim"], (tuple, list)) and len(new_kwargs["dim"]) == 0
+    )
+    if full_reduction:
+        out = func(inp._values, **new_kwargs)
+        if new_kwargs.get("keepdim", False):
+            if isinstance(out, (tuple, list)):
+                # some ops return multiple things; unsqueeze all of them
+                out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
+            else:
+                out = out.unsqueeze(inp._ragged_idx)
+        return out
+
+    # some ops support lists of dims; some don't
+    dim_to_convert = new_kwargs["dim"]
+    is_dimlist = isinstance(new_kwargs["dim"], (tuple, list))
+    if not is_dimlist:
+        dim_to_convert = [dim_to_convert]
+
+    (
+        converted_dim,
+        reduce_on_batch,
+        reduce_on_ragged,
+        reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        dim_to_convert,
+        f"{func_name}",
+        inp._ragged_idx,
+    )
+
+    if not is_dimlist:
+        # convert back from list
+        converted_dim = converted_dim[0]
+    new_kwargs["dim"] = converted_dim
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            f"{func_name}(): reducing across the ragged dimension is not supported "
+            "for non-contiguous nested tensors with holes"
+        )
+
+    from torch.utils._pytree import tree_map
+
+    # raggedness reduced away --> return dense tensor
+    if reduce_on_ragged:
+        # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
+        if reduce_on_batch:
+            # no need to read offsets --> apply sum directly on values
+            out = func(inp._values, **new_kwargs)
+            if new_kwargs.get("keepdim", False):
+                # some ops return multiple things; unsqueeze all of them
+                out = tree_map(lambda o: o.unsqueeze(0), out)
+            return out
+        else:
+            # invalid reduction cases: (ragged, non-batch), etc.
+            if reduce_on_non_batch:
+                raise RuntimeError(
+                    f"{func_name}(): reducing along a ragged and non-batch dimension "
+                    "is not supported for nested tensors"
+                )
+
+            # reduction cases: (ragged)
+            # convert to padded dense and reduce
+            new_kwargs.pop("dim")
+            dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
+            return func(
+                inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
+            )
+    # raggedness preserved --> return nested tensor
+    else:
+        # invalid reduction cases: (batch), (batch, non-batch), etc.
+        if reduce_on_batch:
+            raise RuntimeError(
+                f"{func_name}(): reducing along the batch dimension but not "
+                "the ragged dimension is not supported for nested tensors"
+            )
+
+        # reduction cases: (non-batch), (non-batch, non-batch), etc.
+        # apply sum directly on values
+        out = func(inp._values, **new_kwargs)
+        out_kwargs = extract_kwargs(inp)
+        if not new_kwargs.get("keepdim", False):
+            # dims are reduced away -> ragged_idx of output needs to be reevaluated
+            dimlist = (
+                new_kwargs["dim"]
+                if isinstance(new_kwargs["dim"], (tuple, list))
+                else [new_kwargs["dim"]]
+            )
+            for d in dimlist:
+                # adjust for all dims reduced before the ragged dim
+                if d < inp._ragged_idx - 1:
+                    out_kwargs["_ragged_idx"] -= 1
+
+        # some ops return multiple things; wrap each of them as an NJT
+        return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)
+
+
+@register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
+def sum_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.sum.dim_IntList,
+    "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
+)
+def sum_dim_IntList(func, *args, **kwargs):
+    return _apply_reduction(func, "sum", 0, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
+)
+def transpose_int(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    from torch._prims_common import canonicalize_dims
+
+    inp = new_kwargs.pop("input")
+    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))
+
+    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
+    # instead of 1, although the internal Flash and mem-effn implementations will
+    # use the inputs with raggedness in dim 1.
+    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
+        if dim0 == 0 or dim1 == 0:
+            raise ValueError(
+                "Transpose is not supported on the batch dimension for jagged NT"
+            )
+        if dim0 == inp._ragged_idx:
+            to_dim = dim1
+        else:
+            to_dim = dim0
+        inp_kwargs = extract_kwargs(inp)
+        inp_kwargs["_ragged_idx"] = to_dim
+        return NestedTensor(
+            inp.values().transpose(
+                _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
+                _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
+            ),
+            **inp_kwargs,
+        )
+
+    new_kwargs["dim0"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose"
+    )
+    new_kwargs["dim1"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose"
+    )
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
+def permute_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    dims = new_kwargs.pop("dims")
+    inp_kwargs = extract_kwargs(inp)
+    inp_dim = len(inp._size)
+
+    # The first two checks are the same as the checks in the normal permute implementation
+    if inp_dim != len(dims):
+        raise ValueError(
+            f"permute(): number of dimensions in the tensor input ({inp_dim}) "
+            + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
+        )
+
+    from torch._prims_common import canonicalize_dims
+
+    canonicalized_dims = canonicalize_dims(inp_dim, dims)
+
+    if len(canonicalized_dims) != len(set(canonicalized_dims)):
+        raise ValueError("permute(): duplicate dims are not allowed.")
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "permute(): not supported on jagged layout nested tensor with holes"
+        )
+    if canonicalized_dims[0] != 0:
+        raise ValueError(
+            "Permute is not supported on the batch dimension for jagged NT"
+        )
+    inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
+    inner_dims = [
+        _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
+        for dim in canonicalized_dims[1:]
+    ]
+    new_kwargs["dims"] = inner_dims
+    return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
+
+
+@register_jagged_func(
+    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
+    "self: jt_all, size: any",
+)
+def view_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs.pop("size")
+
+    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
+        raise RuntimeError(
+            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
+            f"inp._size is ({inp._size}) and size is ({size})."
+        )
+
+    # Ensure specified size still includes batch and ragged dims
+    if len(size) < 3 or not raggedness_matches(inp, size):
+        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
+
+    # outer size: the size of the NT, e.g. [3, j0, 10]
+    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
+    # this function gets inner_size[inner_idx] for a given inner_idx.
+    #
+    # example: for outer size [a, b, c, j0, d, e, f]
+    #                         assume that j0 is ragged, other are concrete integers
+    #                         and ragged_idx=3
+    # inner size will be      [b, c, inp._values.size(ragged_idx), d, e, f]
+    # therefore:
+    #    inner_size[0] = outer_size[1]
+    #    inner_size[1] = outer_size[2]
+    #    inner_size[0] = inp._values.size(ragged_idx - 1)
+    #    inner_size[3] = outer_size[4]
+    #    inner_size[4] = outer_size[5]
+    def get_inner_size(inner_idx):
+        nonlocal inp, size
+        if inner_idx == inp._ragged_idx - 1:
+            return inp._values.size(inner_idx)
+        else:
+            return size[inner_idx + 1]
+
+    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]
+
+    # Preserve inference-mode-ness of input.
+    # TODO: Do this for all other views!
+    with torch.inference_mode(inp.is_inference()):
+        return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.native_layer_norm.default,
+    "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
+)
+def native_layer_norm_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if inp.dim() <= 2:
+        raise RuntimeError(
+            "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions"
+        )
+
+    normalized_shape = new_kwargs["normalized_shape"]
+    ragged_size = inp.shape[inp._ragged_idx]
+
+    num_dims_not_normalized = inp.dim() - len(normalized_shape)
+
+    if (
+        num_dims_not_normalized == 0
+    ):  # error if trying to normalize over the batch dimension
+        raise RuntimeError(
+            "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor"
+        )
+
+    if ragged_size in normalized_shape and inp._lengths is not None:
+        raise RuntimeError(
+            "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor"
+        )
+
+    if (
+        ragged_size in normalized_shape
+    ):  # special handling for normalizing over the ragged dimension
+        padded_input = torch.ops.aten._jagged_to_padded_dense_forward(
+            inp._values.flatten(
+                start_dim=inp._ragged_idx
+            ),  # _jagged_to_padded_dense_forward requires values to be a 2D tensor
+            [inp._offsets],
+            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+        )
+
+        padded_mask = torch.ops.aten._jagged_to_padded_dense_forward(
+            torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype),
+            [inp._offsets],
+            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+        ).expand(
+            padded_input.shape
+        )  # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor)
+
+        ragged_lengths = (
+            inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2]
+        )  # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize)
+
+        mean = (
+            torch.sum(
+                padded_input,
+                dim=(1, 2),
+                keepdim=True,
+            )
+            / ragged_lengths
+        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
+
+        padded_normalized = (
+            (padded_input - mean) * padded_mask
+        )  # mask elements outside of the ragged dimension size for correct variance calculation
+
+        variance = (
+            torch.sum(
+                torch.square(padded_normalized),
+                dim=(1, 2),
+                keepdim=True,
+            )
+            / ragged_lengths
+        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
+
+        std = torch.sqrt(variance + new_kwargs["eps"])
+        padded_layer_norm = padded_normalized / std
+
+        jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward(
+            padded_layer_norm,
+            [inp._offsets],
+            total_L=inp._values.shape[
+                0
+            ],  # providing this parameter helps avoid a GPU/CPU sync
+        ).unflatten(
+            -1, inp.shape[inp._ragged_idx + 1 :]
+        )  # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H)
+
+        return (
+            NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)),
+            mean,
+            std,
+        )
+
+    output, mean, std = func(inp._values, **new_kwargs)
+    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)
+
+
+@register_jagged_func(
+    torch.ops.aten.native_layer_norm_backward.default,
+    "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
+)
+def native_layer_norm_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_out = new_kwargs.pop("grad_out")
+    inp = new_kwargs.pop("input")
+    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
+    if d_input is None:
+        return (None, d_gamma, d_beta)
+
+    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)
+
+
+@register_jagged_func(torch.ops.aten.select.int, "self: jt_all, dim: any, index: any")
+def select_int(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
+    )
+
+    # handle batch dim slicing via unbind() for now
+    # TODO: make this more efficient
+    if operating_on_batch:
+        return inp.unbind()[new_kwargs["index"]]
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "select(): not yet supported on dim != 0 for non-contiguous nested tensor with holes"
+        )
+
+    # if selecting before the ragged dim, adjust output ragged_idx
+    out_kwargs = extract_kwargs(inp)
+    if new_kwargs["dim"] < inp._ragged_idx - 1:
+        out_kwargs["_ragged_idx"] -= 1
+
+    return NestedTensor(func(inp._values, **new_kwargs), **out_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.slice.Tensor,
+    "self: jt, dim: any?, start: any?, end: any?, step: any?",
+)
+def slice_tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice"
+    )
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.index_put.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+@register_jagged_func(
+    torch.ops.aten.index_put_.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+def index_put_(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp: NestedTensor = new_kwargs.pop("input")
+
+    # For index_put_ to work, we add together the indices of the ragged dimension
+    # and the batch dimension, adding the offsets of each ragged dimension to its
+    # indices
+
+    indices = new_kwargs.pop("indices")
+
+    assert len(indices) <= inp.dim()
+
+    if len(indices) < inp._ragged_idx + 1:
+        if not inp.is_contiguous():
+            raise RuntimeError(
+                "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
+            )
+        # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = inp._maybe_min_seqlen
+        max_seqlen = inp._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = inp._values.shape[inp._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *inp.shape[: inp._ragged_idx],
+            padded_max_S,
+            *inp.shape[inp._ragged_idx + 1 :],
+        )
+        padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
+        new_njt = nested_from_padded(
+            func(padded_inp, indices, **new_kwargs),
+            offsets=inp._offsets,
+            ragged_idx=inp._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        if func is torch.ops.aten.index_put_.default:
+            inp._values.copy_(new_njt.values())
+            return inp
+        return new_njt
+
+    # We can run on the underlying values directly
+
+    # Validate indices
+    if inp.lengths() is None:
+        lengths = inp.offsets().diff()
+    else:
+        lengths = inp.lengths()
+    torch._assert_async(
+        # pyrefly: ignore [no-matching-overload]
+        torch.all(indices[inp._ragged_idx] < lengths),
+        "Some indices in the ragged dimension are out of bounds!",
+    )
+
+    # Recompute indices for _values
+    ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
+    func_indices = (
+        # before ragged dim
+        indices[1 : inp._ragged_idx]
+        # ragged dim (combined with batch)
+        + [ragged_indices]
+        # after ragged dim
+        + indices[inp._ragged_idx + 1 :]
+    )
+
+    if func is torch.ops.aten.index_put_.default:
+        inp._values = func(inp._values, func_indices, **new_kwargs)
+        return inp
+
+    return NestedTensor(
+        func(inp._values, func_indices, **new_kwargs),
+        **extract_kwargs(inp),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.convolution.default,
+    "input: jt, weight: t, bias: t?, stride: any, padding: any, "
+    "dilation: any, transposed: any, output_padding: any, groups: any",
+)
+def convolution_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?"
+)
+def mean_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs["input"]
+    (_, reduce_on_batch, reduce_on_ragged, reduce_on_non_batch) = _wrap_jagged_dims(
+        inp.dim(),
+        new_kwargs["dim"],
+        "mean",
+        inp._ragged_idx,
+    )
+
+    if reduce_on_ragged and not reduce_on_batch:
+        assert not reduce_on_non_batch
+        # calculate an intermediate sum and leave the dim in for normalization purposes
+        keepdim = new_kwargs["keepdim"]
+        new_kwargs["keepdim"] = True
+        intermediate_sum = _apply_reduction(
+            torch.ops.aten.sum.dim_IntList, "mean", 0, **new_kwargs
+        )
+
+        # normalize by sequence lengths
+        lengths = inp._lengths if inp._lengths is not None else inp._offsets.diff()
+        for _ in range(intermediate_sum.dim() - 1):
+            lengths = lengths.unsqueeze(-1)
+        out = intermediate_sum / lengths
+        if not keepdim:
+            out = out.squeeze(inp._ragged_idx)
+        return out
+
+    # at this point, we're just redispatching on the values buffer
+    # since we expect it to be unused, specify a weird intermediate value to
+    # hopefully make errors obvious
+    intermediate_value = 0.42
+    return _apply_reduction(func, "mean", intermediate_value, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.mean.default, "self: jt_all, dtype: any?")
+def mean_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.any.dims, "self: jt_all, dim: any?, keepdim: any?")
+def any_dims(func, *args, **kwargs):
+    return _apply_reduction(func, "any", False, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.any.dim, "self: jt_all, dim: any, keepdim: any?")
+def any_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # wrap dim in list to redispatch to dims overload
+    new_kwargs["dim"] = [new_kwargs["dim"]]
+    return any_dims(torch.ops.aten.any.dims, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.all.dims, "self: jt_all, dim: any?, keepdim: any?")
+def all_dims(func, *args, **kwargs):
+    return _apply_reduction(func, "all", True, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.all.dim, "self: jt_all, dim: any, keepdim: any?")
+def all_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # wrap dim in list to redispatch to dims overload
+    new_kwargs["dim"] = [new_kwargs["dim"]]
+    return all_dims(torch.ops.aten.all.dims, **new_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.all.default,
+        torch.ops.aten.any.default,
+        torch.ops.aten.max.default,
+        torch.ops.aten.min.default,
+    ],
+    "self: jt_all",
+)
+def all_any_max_min_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    [torch.ops.aten._is_all_true.default, torch.ops.aten._is_any_true.default],
+    "self: jt_all",
+)
+def _is_true_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return func(inp._values)
+
+
+@register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
+def min_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_max = _get_padding_value(dtype, "max")
+    return _apply_reduction(func, "min", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.max.dim, "self: jt_all, dim: any, keepdim: any?")
+def max_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_min = _get_padding_value(dtype, "min")
+    return _apply_reduction(func, "max", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.amin.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def amin_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_max = _get_padding_value(dtype, "max")
+    return _apply_reduction(func, "amin", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.amax.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def amax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_min = _get_padding_value(dtype, "min")
+    return _apply_reduction(func, "amax", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.argmin.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def argmin_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_max = _get_padding_value(dtype, "max")
+    return _apply_reduction(func, "argmin", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.argmax.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def argmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype = new_kwargs["input"].dtype
+    dtype_min = _get_padding_value(dtype, "min")
+    return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.value_selecting_reduction_backward.default,
+    "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any",
+)
+def value_selecting_reduction_backward_default(func, *args, **kwargs):
+    from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = new_kwargs.pop("grad")
+    new_kwargs["grad"] = grad._values
+    indices = new_kwargs.pop("indices")
+    new_kwargs["indices"] = indices._values
+    # should always succeed; sizes should contain a nested int
+    ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s))
+    # convert dim -> values-space dim
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(new_kwargs["sizes"]),
+        new_kwargs["dim"],
+        ragged_idx,
+        "value_selecting_reduction_backward",
+    )
+    # convert saved NJT sizes -> values-space sizes
+    sizes = new_kwargs.pop("sizes")
+    sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1)
+    sizes = sizes[1:]
+    new_kwargs["sizes"] = sizes
+
+    output_kwargs = extract_kwargs(indices)
+    output_kwargs["_ragged_idx"] = ragged_idx
+
+    return NestedTensor(func(**new_kwargs), **output_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
+def stack_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # guaranteed this is non-empty if we got here
+    tensors = new_kwargs.pop("tensors")
+    for t in tensors:
+        if not isinstance(t, NestedTensor):
+            raise RuntimeError("stack(): expected all nested tensors inputs")
+
+        if t.dim() != tensors[0].dim():
+            raise RuntimeError(
+                "stack(): expected all nested tensors to have the same dim"
+            )
+
+        if not raggedness_matches(t, tensors[0].shape):
+            raise RuntimeError(
+                "stack(): expected all nested tensors to have the same nested structure"
+            )
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack"
+    )
+
+    return NestedTensor(
+        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.embedding.default,
+    "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
+)
+def embedding_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # guaranteed this is non-empty if we got here
+    indices = new_kwargs.pop("indices")
+    weight = new_kwargs.pop("weight")
+
+    return NestedTensor(
+        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.embedding_dense_backward.default,
+    "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
+)
+def embedding_dense_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    indices = new_kwargs.pop("indices")
+    grad_output = new_kwargs.pop("grad_output")
+    return func(grad_output._values, indices._values, **new_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.values.default,
+        torch.ops.aten._nested_get_values.default,
+    ],
+    "self: jt_all",
+)
+def values_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # TODO: Handle inference mode properly.
+    # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
+    return inp._values.detach()
+
+
+@register_jagged_func(torch.ops.aten.all.default, "self: jt_all")
+def all_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values)
+
+
+@register_jagged_func(
+    torch.ops.aten.to_padded_tensor.default,
+    "self: jt_all, padding: any, output_size: any?",
+)
+def to_padded_tensor_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if inp._lengths is not None:
+        raise RuntimeError(
+            "to_padded_tensor(): not supported for nested tensors with holes"
+        )
+
+    # TODO: Handle the rest of output_size
+    output_size = new_kwargs["output_size"]
+    if output_size is not None:
+        max_seq_len = output_size[inp._ragged_idx]
+    else:
+        max_seq_len = (
+            inp._max_seqlen
+            if inp._max_seqlen_tensor is not None
+            else inp._values.size(0)
+        )
+
+    # only 2D values with ragged packed dim=0 is supported by the underlying FBGEMM
+    # kernel so do shape gymnastics if needed
+    values = inp.values()
+    if inp._ragged_idx > 1:
+        values = values.transpose(inp._ragged_idx - 1, 0)
+    values_shape = values.shape
+    if values.dim() > 2:
+        values = values.flatten(start_dim=1)
+    elif values.dim() == 1:
+        values = values.unsqueeze(-1)
+
+    # NB: The CUDA kernel for jagged -> padded dense conversion does not support
+    # integer / bool types; work around this by casting to half.
+    is_bool = values.dtype is torch.bool
+    if is_bool and values.is_cuda:
+        values = values.to(torch.half)
+    padded_out = torch.ops.aten._jagged_to_padded_dense_forward(
+        values,
+        [inp._offsets],
+        [max_seq_len],
+        new_kwargs["padding"],
+    )
+    if is_bool and padded_out.is_cuda:
+        padded_out = padded_out.to(torch.bool)
+
+    # shape gymnastics part 2
+    if len(values_shape) > 2:
+        padded_out = padded_out.unflatten(-1, values_shape[1:])
+    elif len(values_shape) == 1:
+        padded_out = padded_out.squeeze(-1)
+    if inp._ragged_idx > 1:
+        padded_out = padded_out.transpose(inp._ragged_idx, 1)
+
+    return padded_out
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_from_padded_tensor.default,
+    "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?",
+)
+def _nested_from_padded_tensor_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    padded, offsets = new_kwargs["padded"], new_kwargs["offsets"]
+    ragged_idx = new_kwargs.get("ragged_idx", 1)
+
+    # only 3D padded with ragged packed dim=0 is supported by the underlying FBGEMM
+    # kernel so do shape gymnastics
+    if ragged_idx > 1:
+        padded = padded.transpose(ragged_idx, 1)
+    padded_ragged_dim1_shape = padded.shape
+    if padded.dim() > 3:
+        padded = padded.flatten(start_dim=2)
+    elif padded.dim() < 3:
+        padded = padded.unsqueeze(-1)
+
+    # NB: The CUDA kernel for padded dense -> jagged conversion does not support
+    # integer / bool types; work around this by casting to half.
+    is_bool = padded.dtype is torch.bool
+    if is_bool and padded.is_cuda:
+        padded = padded.to(torch.half)
+    values = torch.ops.aten._padded_dense_to_jagged_forward(
+        padded, [offsets], new_kwargs["sum_S"]
+    )
+    if is_bool and values.is_cuda:
+        values = values.to(torch.bool)
+
+    # shape gymnastics part 2
+    if len(padded_ragged_dim1_shape) > 3:
+        values = values.unflatten(-1, padded_ragged_dim1_shape[2:])
+    elif len(padded_ragged_dim1_shape) < 3:
+        values = values.squeeze(-1)
+    if ragged_idx > 1:
+        values = values.transpose(ragged_idx - 1, 0)
+
+    min_seqlen = new_kwargs["min_seqlen"]
+    max_seqlen = new_kwargs["max_seqlen"]
+    metadata_cache = {}
+    if min_seqlen is not None:
+        metadata_cache["min_seqlen"] = min_seqlen
+    if max_seqlen is not None:
+        metadata_cache["max_seqlen"] = max_seqlen
+
+    return NestedTensor(
+        values,
+        offsets,
+        _ragged_idx=ragged_idx,
+        _metadata_cache=metadata_cache,
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_view_from_jagged.default,
+    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
+)
+def _nested_view_from_jagged_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    values, offsets, lengths = (
+        new_kwargs["input"],
+        new_kwargs["offsets"],
+        new_kwargs["lengths"],
+    )
+    ragged_idx = new_kwargs["ragged_idx"]
+    min_seqlen = new_kwargs["min_seqlen"]
+    max_seqlen = new_kwargs["max_seqlen"]
+    metadata_cache = {}
+    if min_seqlen is not None:
+        metadata_cache["min_seqlen"] = min_seqlen
+    if max_seqlen is not None:
+        metadata_cache["max_seqlen"] = max_seqlen
+
+    return NestedTensor(
+        values,
+        offsets,
+        lengths=lengths,
+        _ragged_idx=ragged_idx,
+        _metadata_cache=metadata_cache,
+    )
+
+
+@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
+def _nested_get_offsets(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._offsets
+
+
+@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
+def _nested_get_lengths(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._lengths
+
+
+@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
+def _nested_get_ragged_idx(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._ragged_idx
+
+
+@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
+def _nested_get_min_seqlen(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._metadata_cache.get("min_seqlen", None)
+
+
+@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
+def _nested_get_max_seqlen(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._metadata_cache.get("max_seqlen", None)
+
+
+# If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0
+@register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
+def masked_select_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    mask = new_kwargs.pop("mask")
+
+    if inp.ndim > 2:
+        raise RuntimeError("masked_select only support 2-D selections currently")
+    elif inp.shape != mask.shape:
+        raise RuntimeError(
+            f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}"
+        )
+    res_values = inp._values.masked_select(mask.values())
+    mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0))  # type: ignore[arg-type]
+
+    args = extract_kwargs(inp)
+    args["offsets"] = mask_cumsum[inp._offsets]
+    return NestedTensor(
+        values=res_values,
+        **args,
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_select_backward.default,
+    "grad_output: t, self: jt_all, dim: any, index: any",
+)
+def _nested_select_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    grad_output = new_kwargs.pop("grad_output")
+
+    grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
+    grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
+
+    return grad_input
+
+
+@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
+def record_stream_default(func, *args, **kwargs) -> None:
+    inp = args[0]
+    stream = args[1]
+    # ensure all components live until stream computation completes
+    func(inp._values, stream)
+    func(inp._offsets, stream)
+    if inp._lengths is not None:
+        func(inp._lengths, stream)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.new_empty.default,
+        torch.ops.aten.new_zeros.default,
+        torch.ops.aten.new_ones.default,
+    ],
+    "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
+)
+def new_empty_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if len(new_kwargs["size"]) == 0:
+        return func(inp._values, **new_kwargs)
+
+    raise RuntimeError("new_empty() not supported for NJT with shape != ()")
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.elu_backward.default,
+        torch.ops.aten.hardshrink_backward.default,
+        torch.ops.aten.hardsigmoid_backward.default,
+        torch.ops.aten.hardtanh_backward.default,
+        torch.ops.aten.softplus_backward.default,
+        torch.ops.aten.softshrink_backward.default,
+    ],
+    "self: jt_all, ...",
+)
+def activation_backward(func, *args, **kwargs):
+    # first NJT arg is expected to be grad_output
+    grad_output = next(arg for arg in args if isinstance(arg, NestedTensor))
+    return NestedTensor(
+        func(
+            *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args),
+            **kwargs,
+        ),
+        **extract_kwargs(grad_output),
+    )
+
+
+@register_jagged_func(torch.ops.aten.fill.Scalar, "self: jt_all, value: any")
+def fill_Scalar(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
+def fill__Scalar(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    func(inp._values, **new_kwargs)
+    return inp
+
+
+@register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
+def frexp_Tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    output_kwargs = extract_kwargs(inp)
+
+    mantissa, exponent = func(inp._values)
+    return NestedTensor(mantissa, **output_kwargs), NestedTensor(
+        exponent, **output_kwargs
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.matmul_backward.default,
+    "grad: any, self: any, other: any, mask: any",
+)
+def matmul_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = new_kwargs.pop("grad")
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+    grad_input_mask = new_kwargs.pop("mask")
+
+    if grad is None:
+        return (None, None)
+
+    grad_self = None
+    if grad_input_mask[0]:
+        grad_self = torch.matmul(grad, other.transpose(-1, -2))
+
+    grad_other = None
+    if grad_input_mask[1]:
+        grad_other = torch.matmul(inp.transpose(-1, -2), grad)
+
+    return (grad_self, grad_other)
+
+
+# Make the dummy available on the C++ side.
+@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
+def _nested_get_jagged_dummy(func, *args, **kwargs):
+    from torch.nested._internal.nested_tensor import _nt_view_dummy
+
+    return _nt_view_dummy()
+
+
+with torch.library._scoped_library("aten", "IMPL") as aten:
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py
new file mode 100644
index 0000000000000000000000000000000000000000..328702ede37462cf880503e575b6d722b7ba4a40
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py
@@ -0,0 +1,933 @@
+# mypy: allow-untyped-defs
+import logging
+
+import torch
+import torch.nn
+import torch.nn.functional as F
+from torch.backends.cuda import (
+    can_use_cudnn_attention,
+    can_use_efficient_attention,
+    can_use_flash_attention,
+    cudnn_sdp_enabled,
+    flash_sdp_enabled,
+    math_sdp_enabled,
+    mem_efficient_sdp_enabled,
+    SDPAParams,
+)
+from torch.nn.attention import SDPBackend
+
+from .nested_tensor import NestedTensor
+
+
+log = logging.getLogger(__name__)
+
+
+def _validate_sdpa_input(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+) -> None:
+    if (
+        not isinstance(query, NestedTensor)
+        or not isinstance(key, NestedTensor)
+        or not isinstance(value, NestedTensor)
+    ):
+        raise ValueError(
+            f"Expected query, key, and value to be nested tensors, "
+            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
+            f"and value.is_nested: {value.is_nested} instead."
+        )
+    if query.dtype != key.dtype or query.dtype != value.dtype:
+        raise ValueError(
+            f"Expected query, key, and value to have the same dtype, "
+            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
+            f"and value.dtype: {value.dtype} instead."
+        )
+    if query.device != key.device or query.device != value.device:
+        raise ValueError(
+            f"Expected query, key, and value to have the same device type, "
+            f"but got query.device: {query.device}, key.device: {key.device}, "
+            f"and value.device: {value.device} instead."
+        )
+    if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
+        raise ValueError(
+            f"Expected query, key, and value to all be  at least 3 dimensional, but got query.dim: "
+            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
+        )
+    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
+        raise ValueError(
+            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
+            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
+        )
+    if attn_mask is not None:
+        # TODO: Figure out whether masks are actually supported for this layout or not
+        raise ValueError("Masks are not yet supported!")
+        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
+            raise ValueError(
+                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
+                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
+            )
+
+
+def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
+    # This is expected to be called after check_tensor_shapes ensuring that the
+    # size() calls won't error since the inputs are all 4 dimensional
+    q_batch_size = params.query.size(0)
+    k_batch_size = params.key.size(0)
+    v_batch_size = params.value.size(0)
+
+    # num_heads logic for nested input is checked in
+    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
+    # num_heads is not ragged
+    return q_batch_size == k_batch_size and q_batch_size == v_batch_size
+
+
+def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
+    max_size = 256
+    query_size_last = params.query.size(-1)
+    key_size_last = params.key.size(-1)
+    value_size_last = params.value.size(-1)
+    same_head_dim_size = (
+        query_size_last == key_size_last and query_size_last == value_size_last
+    )
+    if not (
+        same_head_dim_size
+        and (query_size_last % 8 == 0)
+        and (query_size_last <= max_size)
+    ):
+        if debug:
+            log.warning(
+                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
+                "last dimension and to be a multiple of 8 and less than or equal to 256. "
+                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
+                query_size_last,
+                key_size_last,
+                value_size_last,
+            )
+        return False
+    return True
+
+
+def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool:
+    max_size = 128
+    query_size_last = params.query.size(-1)
+    key_size_last = params.key.size(-1)
+    value_size_last = params.value.size(-1)
+    same_head_dim_size = (
+        query_size_last == key_size_last and query_size_last == value_size_last
+    )
+    if not (
+        same_head_dim_size
+        and (query_size_last % 8 == 0)
+        and (query_size_last <= max_size)
+    ):
+        if debug:
+            log.warning(
+                "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same "
+                "last dimension and to be a multiple of 8 and less than or equal to 128. "
+                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
+                query_size_last,
+                key_size_last,
+                value_size_last,
+            )
+        return False
+    return True
+
+
+def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+    param: torch.Tensor, param_name: str, debug=False
+) -> bool:
+    assert isinstance(param, NestedTensor), "param should be a jagged NT"
+
+    if param._ragged_idx == 1:
+        # num_head_dims is ragged
+        if debug:
+            log.warning(
+                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
+                param_name,
+            )
+        return False
+
+    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
+    if param._get_min_seqlen() == 0:
+        if debug:
+            log.warning(
+                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
+                param_name,
+            )
+        return False
+
+    return True
+
+
+def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
+    max_size = max(q_size, k_size, v_size)
+    if (
+        (q_size != max_size and q_size != 1)
+        or (k_size != max_size and k_size != 1)
+        or (v_size != max_size and v_size != 1)
+    ):
+        if debug:
+            log.warning(
+                "Both fused kernels require query, key and value to have broadcastable %s, "
+                "got Query %s %d, Key %s %d, Value %s %d instead.",
+                param_name,
+                param_name,
+                q_size,
+                param_name,
+                k_size,
+                param_name,
+                v_size,
+            )
+        return False
+    return True
+
+
+def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
+    # When this function is called we are assured that the nt is dim==4
+    q_is_safe = (
+        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+            params.query, "query", debug
+        )
+        if params.query.is_nested
+        else True
+    )
+    # short circuit if any is unsafe
+    if not q_is_safe:
+        return False
+
+    k_is_safe = (
+        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+            params.key, "key", debug
+        )
+        if params.key.is_nested
+        else True
+    )
+    # short circuit if any is unsafe
+    if not k_is_safe:
+        return False
+
+    v_is_safe = (
+        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+            params.value, "value", debug
+        )
+        if params.value.is_nested
+        else True
+    )
+    # short circuit if any is unsafe
+    if not v_is_safe:
+        return False
+
+    # We now know none of the inputs have ragged num_heads, so we can safely
+    # access .size(1)
+    q_num_heads = params.query.size(1)
+    k_num_heads = params.key.size(1)
+    v_num_heads = params.value.size(1)
+    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
+
+    if not same_num_heads:
+        if (
+            params.query.requires_grad
+            or params.key.requires_grad
+            or params.value.requires_grad
+        ):
+            if debug:
+                log.warning(
+                    "Both fused kernels do not support training with broadcasted NT inputs."
+                )
+            return False
+        return _try_broadcast_param_size(
+            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
+        )
+    return True
+
+
+def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    constraints = (
+        _check_batch_size_nested,
+        _check_head_dim_size_flash_nested,
+        _check_for_seq_len_0_nested,
+    )
+    for constraint in constraints:
+        if not constraint(params, debug):
+            return False
+    return True
+
+
+def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    constraints = (
+        _check_batch_size_nested,
+        _check_for_seq_len_0_nested,
+    )
+    for constraint in constraints:
+        if not constraint(params, debug):
+            return False
+    return True
+
+
+def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    if (
+        not params.query.transpose(1, 2).is_contiguous()
+        or not params.key.transpose(1, 2).is_contiguous()
+        or not params.value.transpose(1, 2).is_contiguous()
+    ):
+        if debug:
+            log.warning(
+                "If inputs are nested tensors they must be contiguous after transposing."
+            )
+        return False
+    if params.is_causal:
+        if debug:
+            log.warning(
+                "Nested tensors for query / key are not supported when is_causal=True."
+            )
+        return False
+    return True
+
+
+def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
+    if (
+        not flash_sdp_enabled()
+        and not mem_efficient_sdp_enabled()
+        and not math_sdp_enabled()
+        and not cudnn_sdp_enabled()
+    ):
+        return SDPBackend.ERROR
+
+    ordering = (
+        SDPBackend.FLASH_ATTENTION,
+        SDPBackend.EFFICIENT_ATTENTION,
+        SDPBackend.MATH,
+        SDPBackend.CUDNN_ATTENTION,
+    )
+
+    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
+
+    for backend in ordering:
+        if backend == SDPBackend.CUDNN_ATTENTION:
+            if can_use_cudnn_attention(params):
+                return SDPBackend.CUDNN_ATTENTION
+        if backend == SDPBackend.FLASH_ATTENTION:
+            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
+                return SDPBackend.FLASH_ATTENTION
+        if backend == SDPBackend.EFFICIENT_ATTENTION:
+            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
+                params
+            ):
+                return SDPBackend.EFFICIENT_ATTENTION
+        if backend == SDPBackend.MATH:
+            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
+                return SDPBackend.MATH
+
+    log.warning("Memory efficient kernel not used because:")
+    can_use_efficient_attention(params, debug=True)
+    _can_use_efficient_sdpa_jagged(params, debug=True)
+    log.warning("Flash attention kernel not used because:")
+    can_use_flash_attention(params, debug=True)
+    _can_use_flash_sdpa_jagged(params, debug=True)
+    log.warning("Math attention kernel not used because:")
+    _can_use_math_sdpa_jagged(params, debug=True)
+    log.warning("cuDNN attention kernel not used because:")
+    can_use_cudnn_attention(params, debug=True)
+    return SDPBackend.ERROR
+
+
+def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
+    # This function is used to calculate two pieces of metadata that are needed
+    # for use with flash-attention and efficient_attention kernels. They are the
+    # cumulative sequence_length over a batch of sequences and the maximum
+    # sequence length.
+
+    # It returns a tuple of cumulative sequence lengths and the maximum sequence
+    # length, and the last element in the cumulative_sequence_lengths
+    if not isinstance(qkv, NestedTensor):
+        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
+
+    if qkv.lengths() is None:
+        # TODO: Explore performance impact of copying
+        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
+        max_seqlen = qkv._get_max_seqlen()
+        n_elem = qkv.values().shape[0]
+    else:
+        # TODO: Explore performance impact of copying
+        cumulative_seqlen = (
+            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
+        )
+        max_seqlen = qkv._get_max_seqlen()
+        # TODO: Explore performance impact when compiling
+        n_elem = int(cumulative_seqlen[-1].item())
+    return cumulative_seqlen, max_seqlen, n_elem
+
+
+def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor) -> bool:
+    # This function checks if a nested tensor is valid for
+    # use with the flash-attention and efficient_attention kernels without
+    # needing to call contiguous on the nested tensor input.
+    # It checks that the storage offsets' adjacent_differences are a constant
+    # multiple of the previous tensor in the nested tensor and that the strides
+    # are monitonically decreasing. This check is done after calling transpose on
+    # the nested tensor resulting in a Nt of shape [bsz, {seq_len}, num_heads, dim]
+
+    # Returns a boolean indicating if contiguous needs to be called for input
+    assert isinstance(tensor, NestedTensor)
+    offsets = tensor.offsets()
+    strides = tensor._strides
+
+    n_tensors = offsets.size(0) - 1
+    if n_tensors <= 1:
+        return True
+
+    # Check initially that the tensor strides are in strictly descending order
+    prev_stride = strides[1]
+    for stride in strides[2:]:
+        if prev_stride <= stride:
+            # This would mean that the last stride is greater than the seq_len
+            # stride
+            return False
+        prev_stride = stride
+
+    # Congrats you made it!
+    return True
+
+
+def _view_as_dense(
+    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
+) -> torch.Tensor:
+    if tensor.is_nested:
+        return tensor.values()
+    return tensor.view(Nnz, num_heads, head_dim)
+
+
+# TODO: Next iteration should add test cases and check it works
+# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
+#     # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
+#     # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+#     q_batch_size = query.size(0)
+#     k_batch_size = key.size(0)
+#     v_batch_size = value.size(0)
+
+#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
+
+#     q_num_heads = query.size(1)
+#     k_num_heads = key.size(1)
+#     v_num_heads = value.size(1)
+
+#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
+
+#     head_dim_qk = query.size(3)
+#     head_dim_v = value.size(3)
+
+#     q_t = query.transpose(1, 2)
+#     k_t = key.transpose(1, 2)
+#     v_t = value.transpose(1, 2)
+
+#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
+#     # output_batch_size/num_heads then they are 1
+#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
+#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
+#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size
+
+#     # If {*}_batch_size_needs_broadcast, then
+#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
+#     #     this is because needs_broadcast indicates that the batch_size is 1
+#     #     and hence there is only 1 value for seq_len
+#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
+#     # ..., output_batch_size * {*}_t.size(1)]
+#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
+
+#     if q_batch_size_needs_broadcast or not q_t.is_nested:
+#         max_seqlen_batch_q = q_t.size(1)
+#         cumulative_sequence_length_q = torch.arange(
+#             0,
+#             (output_batch_size + 1) * max_seqlen_batch_q,
+#             max_seqlen_batch_q,
+#             device=q_t.device,
+#             dtype=torch.int32,
+#         )
+#         Nnz_q = output_batch_size * max_seqlen_batch_q
+#     else:
+#         (
+#             cumulative_sequence_length_q,
+#             max_seqlen_batch_q,
+#             Nnz_q,
+#         ) = _cumulative_and_max_seq_len_nnz(q_t)
+
+#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
+#         assert k_t.size(1) == v_t.size(1)
+#         max_seqlen_batch_kv = k_t.size(1)
+#         cumulative_sequence_length_kv = torch.arange(
+#             0,
+#             (output_batch_size + 1) * max_seqlen_batch_kv,
+#             max_seqlen_batch_kv,
+#             device=k_t.device,
+#             dtype=torch.int32,
+#         )
+#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
+#     else:
+#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
+#             _cumulative_and_max_seq_len_nnz(v_t)
+#             if k_batch_size_needs_broadcast
+#             else _cumulative_and_max_seq_len_nnz(k_t)
+#         )
+
+#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
+#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
+#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads
+
+#     if not q_t.is_nested:
+#         query_buffer_reshaped = q_t.expand(
+#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
+#         )
+#         query_buffer_reshaped = query_buffer_reshaped.reshape(
+#             Nnz_q, output_num_heads, head_dim_qk
+#         )
+#     else:
+#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
+#             q_t = q_t.contiguous()
+#         # If we are broadcasting then Nnz_q will be the output_batch_size since
+#         # seq_len is 1
+#         effective_batch_size_q = (
+#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
+#         )
+#         query_buffer_reshaped = _view_as_dense(
+#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
+#         )
+
+#     # If the physical layout of the NestedTensor's storage
+#     # is not: batch, {seq_len}, num_heads, head_dim then we need
+#     # to call contiguous
+#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
+#         k_t = k_t.contiguous()
+#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
+#         v_t = v_t.contiguous()
+
+#     effective_batch_size_k = (
+#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
+#     )
+#     key_buffer_reshaped = _view_as_dense(
+#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
+#     )
+
+#     effective_batch_size_v = (
+#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
+#     )
+#     value_buffer_reshaped = _view_as_dense(
+#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
+#     )
+
+#     if not q_batch_size_needs_broadcast:
+#         output_shape = q_t._size
+#         if head_dim_v != head_dim_qk:
+#             output_shape[-1] = head_dim_v
+#         if q_num_heads_needs_broadcast:
+#             output_shape[1] = output_num_heads
+#     else:
+#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
+#         output_shape[0] = q_t.size(1)
+#         output_shape[1] = output_num_heads
+#         output_shape[2] = head_dim_v
+
+#     return (
+#         query_buffer_reshaped,
+#         key_buffer_reshaped,
+#         value_buffer_reshaped,
+#         cumulative_sequence_length_q,
+#         cumulative_sequence_length_kv,
+#         max_seqlen_batch_q,
+#         max_seqlen_batch_kv,
+#         output_shape,
+#     )
+
+
+def _sdpa_nested_preprocessing(query, key, value):
+    # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
+    # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+    q_batch_size = query.size(0)
+    k_batch_size = key.size(0)
+    v_batch_size = value.size(0)
+
+    q_num_heads = query.size(1)
+    k_num_heads = key.size(1)
+    v_num_heads = value.size(1)
+
+    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
+        q_num_heads == k_num_heads and k_num_heads == v_num_heads
+    ):
+        raise RuntimeError(
+            "This path is currently not implemented for jagged layout NT."
+        )
+        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
+
+    num_heads = query.size(1)
+    head_dim_qk = query.size(3)
+    head_dim_v = value.size(3)
+    q_t = query.transpose(1, 2)
+    k_t = key.transpose(1, 2)
+    v_t = value.transpose(1, 2)
+
+    (
+        cumulative_sequence_length_q,
+        max_seqlen_batch_q,
+        Nnz_q,
+    ) = _cumulative_and_max_seq_len_nnz(q_t)
+    (
+        cumulative_sequence_length_kv,
+        max_seqlen_batch_kv,
+        Nnz_kv,
+    ) = _cumulative_and_max_seq_len_nnz(k_t)
+
+    # [TODO] K and V have to have the same Nnz, should probably torch_check
+    # assume in order to not iterate over v
+
+    # If the physical layout of the NestedTensor's storage
+    # is not: batch, {seq_len}, num_heads, head_dim then we need
+    # to call contiguous
+    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
+        q_t = q_t.contiguous()
+    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
+        k_t = k_t.contiguous()
+    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
+        v_t = v_t.contiguous()
+
+    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
+    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
+    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
+
+    output_nt_info = {
+        "offsets": q_t.offsets(),
+        "lengths": q_t.lengths(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
+    }
+
+    return (
+        query_buffer_reshaped,
+        key_buffer_reshaped,
+        value_buffer_reshaped,
+        cumulative_sequence_length_q,
+        cumulative_sequence_length_kv,
+        max_seqlen_batch_q,
+        max_seqlen_batch_kv,
+        output_nt_info,
+    )
+
+
+def _pad_last_dim(
+    tensor: torch.Tensor, alignment_size: int, slice: bool
+) -> torch.Tensor:
+    # FlashAttentionV2 requires that head dimension be a multiple of 8
+    # This was previously done within the kernel, however
+    # This causes the kernel to maybe alias query, key, value
+    # So instead we pad the head_dimensions to be a multiple of 8
+    # in the composite region
+    last_dim_size = tensor.size(-1)
+    if last_dim_size % alignment_size == 0:
+        return tensor
+    pad_count = alignment_size - (last_dim_size % alignment_size)
+    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
+    if slice:
+        return tensor[..., 0:last_dim_size]
+    return tensor
+
+
+# TODO: coalesce with torch/nn/utils/attention.py
+def _calculate_scale(query, scale):
+    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
+    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
+    return softmax_scale
+
+
+def _post_process_flash_output(out: torch.Tensor, og_size):
+    if not out.is_nested and out.size(-1) != og_size:
+        out = out[..., 0:og_size]
+    return out
+
+
+def _is_computing_meta_flops(x):
+    # Note: there's a use case of using meta tensors & the dispatch-based flop counter.
+    # We can use this function to check for this scenario in order to handle it specially.
+    if not torch.jit.is_scripting() and x.device.type == "meta":
+        torch_dispatch_mode_stack = (
+            torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+        )
+        return any(
+            type(x) is torch.utils.flop_counter._FlopCounterMode
+            for x in torch_dispatch_mode_stack
+        )
+    return False
+
+
+def _autocast(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
+    """
+    [Autocasting SDPA for NJT]
+
+    Normal autocasting doesn't work for NJT+SDPA right now:
+    * NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens
+      before we get to any aten ops or dispatcher logic; then the torch_function logic calls into
+      efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention
+      op won't work because we never see that aten op.
+    * If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but
+      the kernel selection logic in torch_function handling (ie. jagged_scaled_dot_product_attention)
+      won't work correctly: the kernel selection logic will run before autocasting, and choose
+      a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual
+      attention computation will happen in a different dtype.
+
+    An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware
+    and rely on autocasting to do the actual conversions for flash attention / efficient attention.
+    However, by manually doing the actual autocast before the backend selection, we ensure that the
+    autocast handling for backend selection doesn't diverge from the autocast handling for the
+    actual dtype conversions.
+    """
+    device_type = query.device.type
+    # meta device is not supported by autocast, so break early for it
+    if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
+        return query, key, value, attn_mask
+
+    def cvt(x):
+        if x is None:
+            return x
+        target_dtype = torch.get_autocast_dtype(device_type)
+        if (
+            (not x.dtype.is_floating_point)
+            or x.dtype == target_dtype
+            or x.dtype == torch.float64
+        ):
+            return x
+        return x.to(target_dtype)
+
+    return cvt(query), cvt(key), cvt(value), cvt(attn_mask)
+
+
+def jagged_scaled_dot_product_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+):
+    query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
+    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
+    # for mypy, ugh
+    assert (
+        isinstance(query, NestedTensor)
+        and isinstance(key, NestedTensor)
+        and isinstance(value, NestedTensor)
+    )
+    from torch.nested._internal.nested_tensor import (
+        nested_view_from_values_offsets_lengths,
+    )
+
+    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
+    # second batch dim instead). For this case, we can just send the dense buffers through
+    # vanilla SDPA.
+    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
+        output = F.scaled_dot_product_attention(
+            query.values(),
+            key.values(),
+            value.values(),
+            attn_mask=(
+                attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
+            ),
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+        )
+        return nested_view_from_values_offsets_lengths(
+            output,
+            query.offsets(),
+            query.lengths(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
+
+    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
+
+    backend_choice = _select_sdp_backend(
+        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
+    )
+
+    if _is_computing_meta_flops(query):
+        # Backend choice will probably not be correct if we have a meta device,
+        # because backend choice is device-aware. In this case, we mostly just
+        # want to avoid using math backend (which does a .item() call).
+        # Arbitrarily choose flash attention.
+        backend_choice = SDPBackend.FLASH_ATTENTION
+
+    if backend_choice == SDPBackend.FLASH_ATTENTION:
+        og_size = query.size(-1)
+        query_padded = _pad_last_dim(query, 8, False)
+        key_padded = _pad_last_dim(key, 8, False)
+        value_padded = _pad_last_dim(value, 8, False)
+        # We need to calculate the scale based off the OG head dim size
+        og_scale = _calculate_scale(query, scale)
+        (
+            query_buffer_reshaped,
+            key_buffer_reshaped,
+            value_buffer_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
+        (
+            attention,
+            _logsumexp,
+            _philox_seed,
+            _philox_offset,
+            _debug_attn_mask,
+        ) = torch.ops.aten._flash_attention_forward(
+            query_buffer_reshaped,
+            key_buffer_reshaped,
+            value_buffer_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            dropout_p,
+            is_causal,
+            False,
+            scale=og_scale,
+        )
+        # Reshape output to convert nnz to batch_size and seq_len
+        attention = nested_view_from_values_offsets_lengths(
+            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
+            **output_nt_info,
+        ).transpose(1, 2)
+        return _post_process_flash_output(attention, og_size)
+    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
+        (
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query, key, value)
+        (
+            attention,
+            log_sumexp,
+            seed,
+            offset,
+            max_seqlen_q,
+            max_seqlen_batch_kv,
+        ) = torch.ops.aten._efficient_attention_forward(
+            query_reshaped.unsqueeze(0),
+            key_reshaped.unsqueeze(0),
+            value_reshaped.unsqueeze(0),
+            None,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            dropout_p,
+            int(is_causal),
+            compute_logsumexp,
+            scale=scale,
+        )
+        # Reshape output to convert nnz to batch_size and seq_len
+        return nested_view_from_values_offsets_lengths(
+            attention.squeeze(0),
+            **output_nt_info,
+        ).transpose(1, 2)
+    elif backend_choice == SDPBackend.CUDNN_ATTENTION:
+        (
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query, key, value)
+        (
+            attention,
+            logsumexp,
+            cum_seqlen_q,
+            cum_seqlen_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+            seed,
+            offset,
+            _,
+        ) = torch.ops.aten._cudnn_attention_forward(
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            attn_mask,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            compute_logsumexp,
+            dropout_p,
+            is_causal,
+            False,
+            scale=scale,
+        )
+        return nested_view_from_values_offsets_lengths(
+            attention,
+            **output_nt_info,
+        ).transpose(1, 2)
+    elif backend_choice == SDPBackend.MATH:
+        # save the offsets and shape of the inputs, so we can reshape the final output
+        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]
+        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
+        offsets = query.offsets()
+        q_lengths = query.lengths()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
+        d1 = query._size[1]
+        d2 = value._size[-1]
+
+        # convert jagged layout Nested Tensor to strided layout Nested Tensor
+        # which support the math implementation of SDPA
+        def get_strided_layout_nested_tensor(jagged_layout_nt):
+            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
+            transpose = torch.transpose(jagged_layout_nt, 1, 2)
+            tensor_list = transpose.values().split(list(lengths), dim=0)
+            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
+            strided_nt = strided_nt.transpose(1, 2).contiguous()
+            return strided_nt
+
+        query = get_strided_layout_nested_tensor(query)
+        key = get_strided_layout_nested_tensor(key)
+        value = get_strided_layout_nested_tensor(value)
+
+        attn_out = torch._scaled_dot_product_attention_math(
+            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
+        )[0]
+
+        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
+        attn_out = attn_out.transpose(1, 2).contiguous().values()
+        attn_out = attn_out.view(-1, d1, d2)
+        attn_out = nested_view_from_values_offsets_lengths(
+            attn_out,
+            offsets,
+            lengths=q_lengths,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        ).transpose(1, 2)
+
+        return attn_out
+    else:
+        raise RuntimeError(
+            "No viable backend for scaled_dot_product_attention was found."
+        )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/numa/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/numa/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4801518acd873810e57c6c6794d0efdc0f91f06
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/numa/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/numa/__pycache__/binding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/numa/__pycache__/binding.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2098b8d985f4d6b56ac1e398703d26595aeb8ba3
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/numa/__pycache__/binding.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/ATen/ATenConfig.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/ATen/ATenConfig.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..0ce7803dbf78897298d81c2679f2cdb3c872bc15
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/ATen/ATenConfig.cmake
@@ -0,0 +1,9 @@
+# Find the TH includes and library
+#
+# ATEN_INCLUDE_DIR -- where to find the includes
+# ATEN_LIBRARIES -- list of libraries to link against
+# ATEN_FOUND -- set to 1 if found
+
+set(ATEN_FOUND 1)
+set(ATEN_INCLUDE_DIR "/pytorch/torch/include")
+set(ATEN_LIBRARIES "")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..2457dff032a8b824d173fe1cb2d4e787a7b9839c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake
@@ -0,0 +1,140 @@
+# - Config file for the Caffe2 package
+# It defines the following variable(s)
+#   CAFFE2_INCLUDE_DIRS     - include directories for FooBar
+# as well as Caffe2 targets for other cmake libraries to use.
+
+# library version information
+
+# Utils functions.
+include("${CMAKE_CURRENT_LIST_DIR}/public/utils.cmake")
+
+# Depending on whether Caffe2 uses gflags during compile time or
+# not, invoke gflags.
+if(OFF)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/gflags.cmake")
+  if(NOT TARGET gflags)
+    message(FATAL_ERROR
+        "Your installed Caffe2 version uses gflags but the gflags library "
+        "cannot be found. Did you accidentally remove it, or have you set "
+        "the right CMAKE_PREFIX_PATH and/or GFLAGS_ROOT_DIR? If you do not "
+        "have gflags, you will need to install gflags and set the library "
+        "path accordingly.")
+  endif()
+endif()
+
+# Depending on whether Caffe2 uses glog during compile time or
+# not, invoke glog.
+if(OFF)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/glog.cmake")
+  if(NOT TARGET glog::glog)
+    message(FATAL_ERROR
+        "Your installed Caffe2 version uses glog but the glog library "
+        "cannot be found. Did you accidentally remove it, or have you set "
+        "the right CMAKE_PREFIX_PATH and/or GFLAGS_ROOT_DIR? If you do not "
+        "have glog, you will need to install glog and set the library "
+        "path accordingly.")
+  endif()
+endif()
+
+# Protobuf
+if(ON)
+  if(NOT TARGET protobuf::libprotobuf)
+    # Define protobuf::libprotobuf as a dummy target to resolve references to
+    # protobuf::libprotobuf in Caffe2Targets.cmake.
+    add_library(dummy INTERFACE)
+    add_library(protobuf::libprotobuf ALIAS dummy)
+  endif()
+else()
+  include("${CMAKE_CURRENT_LIST_DIR}/public/protobuf.cmake")
+  if(NOT TARGET protobuf::libprotobuf)
+    message(FATAL_ERROR
+        "Your installed Caffe2 version uses protobuf but the protobuf library "
+        "cannot be found. Did you accidentally remove it, or have you set "
+        "the right CMAKE_PREFIX_PATH? If you do not have protobuf, you will "
+        "need to install protobuf and set the library path accordingly.")
+  endif()
+  message(STATUS "Caffe2: Protobuf version " ${Protobuf_VERSION})
+  # If during build time we know the protobuf version, we will also do a sanity
+  # check to ensure that the protobuf library that Caffe2 found is consistent
+  # with the compiled version.
+  if(FALSE)
+    if(NOT (${Protobuf_VERSION} VERSION_EQUAL Protobuf_VERSION_NOTFOUND))
+      message(FATAL_ERROR
+          "Your installed Caffe2 is built with protobuf "
+          "Protobuf_VERSION_NOTFOUND"
+          ", while your current cmake setting discovers protobuf version "
+          ${Protobuf_VERSION}
+          ". Please specify a protobuf version that is the same as the built "
+          "version.")
+    endif()
+  endif()
+endif()
+
+if (OFF)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake")
+endif()
+
+if(ON)
+  # The file public/cuda.cmake exclusively uses CAFFE2_USE_*.
+  # If Caffe2 was compiled with the libraries below, they must
+  # be found again when including the Caffe2 target.
+  set(CAFFE2_USE_CUDA ON)
+
+  # Add current directory to module path so we pick up FindCUDAToolkit.cmake
+  set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}")
+  list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}")
+  include("${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake")
+  set(CMAKE_MODULE_PATH "${old_CMAKE_MODULE_PATH}")
+
+  if(ON AND NOT CAFFE2_USE_CUDA)
+    message(FATAL_ERROR
+      "Your installed Caffe2 version uses CUDA but I cannot find the CUDA "
+      "libraries. Please set the proper CUDA prefixes and / or install "
+      "CUDA.")
+  endif()
+endif()
+
+if(OFF)
+  # Add current directory to module path so we pick up FindSYCLToolkit.cmake
+  set(old_CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}")
+  list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}")
+  include("${CMAKE_CURRENT_LIST_DIR}/public/xpu.cmake")
+  set(CMAKE_MODULE_PATH "${old_CMAKE_MODULE_PATH}")
+
+  if(OFF AND NOT PYTORCH_FOUND_XPU)
+    message(FATAL_ERROR
+      "Your installed Caffe2 version uses XPU but I cannot find the XPU runtime"
+      "libraries. Please set the proper oneAPI paths and / or install "
+      "oneAPI.")
+  endif()
+endif()
+
+if(ON)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/mkl.cmake")
+endif()
+
+if(ON)
+  include("${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake")
+endif()
+
+# import targets
+include ("${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets.cmake")
+
+# Interface libraries, that allows one to build proper link flags.
+# We will also define a helper variable, Caffe2_MAIN_LIBS, that resolves to
+# the main caffe2 libraries in cases of cuda presence / absence.
+set(Caffe2_MAIN_LIBS torch_library)
+
+# include directory.
+#
+# Newer versions of CMake set the INTERFACE_INCLUDE_DIRECTORIES property
+# of the imported targets. It is hence not necessary to add this path
+# manually to the include search path for targets which link to gflags.
+# The following lines are here for backward compatibility, in case one
+# would like to use the old-style include path.
+get_filename_component(
+    CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+# Note: the current list dir is _INSTALL_PREFIX/share/cmake/Gloo.
+get_filename_component(
+    _INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
+set(CAFFE2_INCLUDE_DIRS "${_INSTALL_PREFIX}/include")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..721afaa1b956f721ecd584a69ae59de56f5e5064
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets-release.cmake
@@ -0,0 +1,71 @@
+#----------------------------------------------------------------
+# Generated CMake target import file for configuration "Release".
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Import target "c10_cuda" for configuration "Release"
+set_property(TARGET c10_cuda APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(c10_cuda PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libc10_cuda.so"
+  IMPORTED_SONAME_RELEASE "libc10_cuda.so"
+  )
+
+list(APPEND _cmake_import_check_targets c10_cuda )
+list(APPEND _cmake_import_check_files_for_c10_cuda "${_IMPORT_PREFIX}/lib/libc10_cuda.so" )
+
+# Import target "c10" for configuration "Release"
+set_property(TARGET c10 APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(c10 PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libc10.so"
+  IMPORTED_SONAME_RELEASE "libc10.so"
+  )
+
+list(APPEND _cmake_import_check_targets c10 )
+list(APPEND _cmake_import_check_files_for_c10 "${_IMPORT_PREFIX}/lib/libc10.so" )
+
+# Import target "torch_nvshmem" for configuration "Release"
+set_property(TARGET torch_nvshmem APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch_nvshmem PROPERTIES
+  IMPORTED_LINK_DEPENDENT_LIBRARIES_RELEASE "torch_cpu"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch_nvshmem.so"
+  IMPORTED_SONAME_RELEASE "libtorch_nvshmem.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch_nvshmem )
+list(APPEND _cmake_import_check_files_for_torch_nvshmem "${_IMPORT_PREFIX}/lib/libtorch_nvshmem.so" )
+
+# Import target "torch_cpu" for configuration "Release"
+set_property(TARGET torch_cpu APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch_cpu PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch_cpu.so"
+  IMPORTED_SONAME_RELEASE "libtorch_cpu.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch_cpu )
+list(APPEND _cmake_import_check_files_for_torch_cpu "${_IMPORT_PREFIX}/lib/libtorch_cpu.so" )
+
+# Import target "torch_cuda" for configuration "Release"
+set_property(TARGET torch_cuda APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch_cuda PROPERTIES
+  IMPORTED_LINK_DEPENDENT_LIBRARIES_RELEASE "torch_nvshmem"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch_cuda.so"
+  IMPORTED_SONAME_RELEASE "libtorch_cuda.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch_cuda )
+list(APPEND _cmake_import_check_files_for_torch_cuda "${_IMPORT_PREFIX}/lib/libtorch_cuda.so" )
+
+# Import target "torch" for configuration "Release"
+set_property(TARGET torch APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(torch PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libtorch.so"
+  IMPORTED_SONAME_RELEASE "libtorch.so"
+  )
+
+list(APPEND _cmake_import_check_targets torch )
+list(APPEND _cmake_import_check_files_for_torch "${_IMPORT_PREFIX}/lib/libtorch.so" )
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..086cc1e2547c8f2ba2536d918a6676f65f38f56a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Caffe2Targets.cmake
@@ -0,0 +1,200 @@
+# Generated by CMake
+
+if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
+   message(FATAL_ERROR "CMake >= 3.0.0 required")
+endif()
+if(CMAKE_VERSION VERSION_LESS "3.0.0")
+   message(FATAL_ERROR "CMake >= 3.0.0 required")
+endif()
+cmake_policy(PUSH)
+cmake_policy(VERSION 3.0.0...4.0)
+#----------------------------------------------------------------
+# Generated CMake target import file.
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Protect against multiple inclusion, which would fail when already imported targets are added once more.
+set(_cmake_targets_defined "")
+set(_cmake_targets_not_defined "")
+set(_cmake_expected_targets "")
+foreach(_cmake_expected_target IN ITEMS headeronly c10_cuda c10 torch_nvshmem torch_cpu torch_cpu_library torch_cuda torch_cuda_library torch torch_library)
+  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
+  if(TARGET "${_cmake_expected_target}")
+    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
+  else()
+    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
+  endif()
+endforeach()
+unset(_cmake_expected_target)
+if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
+  unset(_cmake_targets_defined)
+  unset(_cmake_targets_not_defined)
+  unset(_cmake_expected_targets)
+  unset(CMAKE_IMPORT_FILE_VERSION)
+  cmake_policy(POP)
+  return()
+endif()
+if(NOT _cmake_targets_defined STREQUAL "")
+  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
+  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
+  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
+endif()
+unset(_cmake_targets_defined)
+unset(_cmake_targets_not_defined)
+unset(_cmake_expected_targets)
+
+
+# Compute the installation prefix relative to this file.
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+if(_IMPORT_PREFIX STREQUAL "/")
+  set(_IMPORT_PREFIX "")
+endif()
+
+# Create imported target headeronly
+add_library(headeronly INTERFACE IMPORTED)
+
+# Create imported target c10_cuda
+add_library(c10_cuda SHARED IMPORTED)
+
+set_target_properties(c10_cuda PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "c10;torch::cudart"
+)
+
+# Create imported target c10
+add_library(c10 SHARED IMPORTED)
+
+set_target_properties(c10 PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "headeronly"
+)
+
+# Create imported target torch_nvshmem
+add_library(torch_nvshmem SHARED IMPORTED)
+
+set_target_properties(torch_nvshmem PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "USE_NVSHMEM"
+)
+
+# Create imported target torch_cpu
+add_library(torch_cpu SHARED IMPORTED)
+
+set_target_properties(torch_cpu PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "USE_DISTRIBUTED;USE_C10D_GLOO;USE_RPC;USE_TENSORPIPE"
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "protobuf::libprotobuf;c10;caffe2::mkl"
+)
+
+# Create imported target torch_cpu_library
+add_library(torch_cpu_library INTERFACE IMPORTED)
+
+set_target_properties(torch_cpu_library PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "\$"
+  INTERFACE_COMPILE_OPTIONS "\$"
+  INTERFACE_INCLUDE_DIRECTORIES "\$"
+  INTERFACE_LINK_LIBRARIES "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed;\$"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$"
+)
+
+# Create imported target torch_cuda
+add_library(torch_cuda SHARED IMPORTED)
+
+set_target_properties(torch_cuda PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "USE_NVSHMEM;USE_C10D_NCCL"
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "torch::cudart;c10_cuda;torch_cpu_library"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "include"
+)
+
+# Create imported target torch_cuda_library
+add_library(torch_cuda_library INTERFACE IMPORTED)
+
+set_target_properties(torch_cuda_library PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "\$"
+  INTERFACE_COMPILE_OPTIONS "\$"
+  INTERFACE_INCLUDE_DIRECTORIES "\$"
+  INTERFACE_LINK_LIBRARIES "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed;\$"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$"
+)
+
+# Create imported target torch
+add_library(torch SHARED IMPORTED)
+
+set_target_properties(torch PROPERTIES
+  INTERFACE_LINK_LIBRARIES "torch_cpu_library;torch_cuda_library"
+)
+
+# Create imported target torch_library
+add_library(torch_library INTERFACE IMPORTED)
+
+set_target_properties(torch_library PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "\$"
+  INTERFACE_COMPILE_OPTIONS "\$"
+  INTERFACE_INCLUDE_DIRECTORIES "\$"
+  INTERFACE_LINK_LIBRARIES "-Wl,--no-as-needed,\"\$\" -Wl,--as-needed;\$"
+  INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "\$"
+)
+
+# Load information for each installed configuration.
+file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/Caffe2Targets-*.cmake")
+foreach(_cmake_config_file IN LISTS _cmake_config_files)
+  include("${_cmake_config_file}")
+endforeach()
+unset(_cmake_config_file)
+unset(_cmake_config_files)
+
+# Cleanup temporary variables.
+set(_IMPORT_PREFIX)
+
+# Loop over all imported files and verify that they actually exist
+foreach(_cmake_target IN LISTS _cmake_import_check_targets)
+  if(CMAKE_VERSION VERSION_LESS "3.28"
+      OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
+      OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
+    foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
+      if(NOT EXISTS "${_cmake_file}")
+        message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
+   \"${_cmake_file}\"
+but this file does not exist.  Possible reasons include:
+* The file was deleted, renamed, or moved to another location.
+* An install or uninstall procedure did not complete successfully.
+* The installation package was faulty and contained
+   \"${CMAKE_CURRENT_LIST_FILE}\"
+but not all the files it references.
+")
+      endif()
+    endforeach()
+  endif()
+  unset(_cmake_file)
+  unset("_cmake_import_check_files_for_${_cmake_target}")
+endforeach()
+unset(_cmake_target)
+unset(_cmake_import_check_targets)
+
+# Make sure the targets which have been exported in some other
+# export set exist.
+unset(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets)
+foreach(_target "protobuf::libprotobuf" )
+  if(NOT TARGET "${_target}" )
+    set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets "${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets} ${_target}")
+  endif()
+endforeach()
+
+if(DEFINED ${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets)
+  if(CMAKE_FIND_PACKAGE_NAME)
+    set( ${CMAKE_FIND_PACKAGE_NAME}_FOUND FALSE)
+    set( ${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE "The following imported targets are referenced, but are missing: ${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets}")
+  else()
+    message(FATAL_ERROR "The following imported targets are referenced, but are missing: ${${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets}")
+  endif()
+endif()
+unset(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE_targets)
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
+cmake_policy(POP)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..ec9ae530aa6b2bdceb87f966e706fb5c2a36349a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDAToolkit.cmake
@@ -0,0 +1,1081 @@
+
+# This module is back-ported from CMake 3.17 and above to work with CMake 3.10
+
+# Distributed under the OSI-approved BSD 3-Clause License.  See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+#[=======================================================================[.rst:
+FindCUDAToolkit
+---------------
+
+.. versionadded:: 3.17
+
+This script locates the NVIDIA CUDA toolkit and the associated libraries, but
+does not require the ``CUDA`` language be enabled for a given project. This
+module does not search for the NVIDIA CUDA Samples.
+
+.. versionadded:: 3.19
+  QNX support.
+
+Search Behavior
+^^^^^^^^^^^^^^^
+
+The CUDA Toolkit search behavior uses the following order:
+
+1. If the ``CUDA`` language has been enabled we will use the directory
+   containing the compiler as the first search location for ``nvcc``.
+
+2. If the ``CUDAToolkit_ROOT`` cmake configuration variable (e.g.,
+   ``-DCUDAToolkit_ROOT=/some/path``) *or* environment variable is defined, it
+   will be searched.  If both an environment variable **and** a
+   configuration variable are specified, the *configuration* variable takes
+   precedence.
+
+   The directory specified here must be such that the executable ``nvcc`` or
+   the appropriate ``version.txt`` file can be found underneath the specified
+   directory.
+
+3. If the CUDA_PATH environment variable is defined, it will be searched
+   for ``nvcc``.
+
+4. The user's path is searched for ``nvcc`` using :command:`find_program`.  If
+   this is found, no subsequent search attempts are performed.  Users are
+   responsible for ensuring that the first ``nvcc`` to show up in the path is
+   the desired path in the event that multiple CUDA Toolkits are installed.
+
+5. On Unix systems, if the symbolic link ``/usr/local/cuda`` exists, this is
+   used.  No subsequent search attempts are performed.  No default symbolic link
+   location exists for the Windows platform.
+
+6. The platform specific default install locations are searched.  If exactly one
+   candidate is found, this is used.  The default CUDA Toolkit install locations
+   searched are:
+
+   +-------------+-------------------------------------------------------------+
+   | Platform    | Search Pattern                                              |
+   +=============+=============================================================+
+   | macOS       | ``/Developer/NVIDIA/CUDA-X.Y``                              |
+   +-------------+-------------------------------------------------------------+
+   | Other Unix  | ``/usr/local/cuda-X.Y``                                     |
+   +-------------+-------------------------------------------------------------+
+   | Windows     | ``C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y`` |
+   +-------------+-------------------------------------------------------------+
+
+   Where ``X.Y`` would be a specific version of the CUDA Toolkit, such as
+   ``/usr/local/cuda-9.0`` or
+   ``C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0``
+
+   .. note::
+
+       When multiple CUDA Toolkits are installed in the default location of a
+       system(e.g., both ``/usr/local/cuda-9.0`` and ``/usr/local/cuda-10.0``
+       exist but the ``/usr/local/cuda`` symbolic link does **not** exist), this
+       package is marked as **not** found.
+
+       There are too many factors involved in making an automatic decision in
+       the presence of multiple CUDA Toolkits being installed.  In this
+       situation, users are encouraged to either (1) set ``CUDAToolkit_ROOT`` or
+       (2) ensure that the correct ``nvcc`` executable shows up in ``$PATH`` for
+       :command:`find_program` to find.
+
+Arguments
+^^^^^^^^^
+
+``[]``
+    The ``[]`` argument requests a version with which the package found
+    should be compatible. See :ref:`find_package version format `
+    for more details.
+
+Options
+^^^^^^^
+
+``REQUIRED``
+    If specified, configuration will error if a suitable CUDA Toolkit is not
+    found.
+
+``QUIET``
+    If specified, the search for a suitable CUDA Toolkit will not produce any
+    messages.
+
+``EXACT``
+    If specified, the CUDA Toolkit is considered found only if the exact
+    ``VERSION`` specified is recovered.
+
+Imported targets
+^^^^^^^^^^^^^^^^
+
+An :ref:`imported target ` named ``CUDA::toolkit`` is provided.
+
+This module defines :prop_tgt:`IMPORTED` targets for each
+of the following libraries that are part of the CUDAToolkit:
+
+- :ref:`CUDA Runtime Library`
+- :ref:`CUDA Driver Library`
+- :ref:`cuBLAS`
+- :ref:`cuFFT`
+- :ref:`cuRAND`
+- :ref:`cuSOLVER`
+- :ref:`cuSPARSE`
+- :ref:`cuPTI`
+- :ref:`NPP`
+- :ref:`nvBLAS`
+- :ref:`nvGRAPH`
+- :ref:`nvJPEG`
+- :ref:`nvidia-ML`
+- :ref:`nvRTC`
+- :ref:`nvToolsExt`
+- :ref:`OpenCL`
+- :ref:`cuLIBOS`
+
+.. _`cuda_toolkit_rt_lib`:
+
+CUDA Runtime Library
+""""""""""""""""""""
+
+The CUDA Runtime library (cudart) are what most applications will typically
+need to link against to make any calls such as `cudaMalloc`, and `cudaFree`.
+
+Targets Created:
+
+- ``CUDA::cudart``
+- ``CUDA::cudart_static``
+
+.. _`cuda_toolkit_driver_lib`:
+
+CUDA Driver Library
+""""""""""""""""""""
+
+The CUDA Driver library (cuda) are used by applications that use calls
+such as `cuMemAlloc`, and `cuMemFree`.
+
+Targets Created:
+
+- ``CUDA::cuda_driver``
+
+.. _`cuda_toolkit_cuBLAS`:
+
+cuBLAS
+""""""
+
+The `cuBLAS `_ library.
+
+Targets Created:
+
+- ``CUDA::cublas``
+- ``CUDA::cublas_static``
+- ``CUDA::cublasLt`` starting in CUDA 10.1
+- ``CUDA::cublasLt_static`` starting in CUDA 10.1
+
+.. _`cuda_toolkit_cuFFT`:
+
+cuFFT
+"""""
+
+The `cuFFT `_ library.
+
+Targets Created:
+
+- ``CUDA::cufft``
+- ``CUDA::cufftw``
+- ``CUDA::cufft_static``
+- ``CUDA::cufft_static_nocallback`` starting in CUDA 9.2, requires CMake 3.23+
+- ``CUDA::cufftw_static``
+
+cuRAND
+""""""
+
+The `cuRAND `_ library.
+
+Targets Created:
+
+- ``CUDA::curand``
+- ``CUDA::curand_static``
+
+.. _`cuda_toolkit_cuSOLVER`:
+
+cuSOLVER
+""""""""
+
+The `cuSOLVER `_ library.
+
+Targets Created:
+
+- ``CUDA::cusolver``
+- ``CUDA::cusolver_static``
+
+.. _`cuda_toolkit_cuSPARSE`:
+
+cuSPARSE
+""""""""
+
+The `cuSPARSE `_ library.
+
+Targets Created:
+
+- ``CUDA::cusparse``
+- ``CUDA::cusparse_static``
+
+.. _`cuda_toolkit_cupti`:
+
+cupti
+"""""
+
+The `NVIDIA CUDA Profiling Tools Interface `_.
+
+Targets Created:
+
+- ``CUDA::cupti``
+- ``CUDA::cupti_static``
+
+.. _`cuda_toolkit_NPP`:
+
+NPP
+"""
+
+The `NPP `_ libraries.
+
+Targets Created:
+
+- `nppc`:
+
+  - ``CUDA::nppc``
+  - ``CUDA::nppc_static``
+
+- `nppial`: Arithmetic and logical operation functions in `nppi_arithmetic_and_logical_operations.h`
+
+  - ``CUDA::nppial``
+  - ``CUDA::nppial_static``
+
+- `nppicc`: Color conversion and sampling functions in `nppi_color_conversion.h`
+
+  - ``CUDA::nppicc``
+  - ``CUDA::nppicc_static``
+
+- `nppicom`: JPEG compression and decompression functions in `nppi_compression_functions.h`
+  Removed starting in CUDA 11.0, use :ref:`nvJPEG` instead.
+
+  - ``CUDA::nppicom``
+  - ``CUDA::nppicom_static``
+
+- `nppidei`: Data exchange and initialization functions in `nppi_data_exchange_and_initialization.h`
+
+  - ``CUDA::nppidei``
+  - ``CUDA::nppidei_static``
+
+- `nppif`: Filtering and computer vision functions in `nppi_filter_functions.h`
+
+  - ``CUDA::nppif``
+  - ``CUDA::nppif_static``
+
+- `nppig`: Geometry transformation functions found in `nppi_geometry_transforms.h`
+
+  - ``CUDA::nppig``
+  - ``CUDA::nppig_static``
+
+- `nppim`: Morphological operation functions found in `nppi_morphological_operations.h`
+
+  - ``CUDA::nppim``
+  - ``CUDA::nppim_static``
+
+- `nppist`: Statistics and linear transform in `nppi_statistics_functions.h` and `nppi_linear_transforms.h`
+
+  - ``CUDA::nppist``
+  - ``CUDA::nppist_static``
+
+- `nppisu`: Memory support functions in `nppi_support_functions.h`
+
+  - ``CUDA::nppisu``
+  - ``CUDA::nppisu_static``
+
+- `nppitc`: Threshold and compare operation functions in `nppi_threshold_and_compare_operations.h`
+
+  - ``CUDA::nppitc``
+  - ``CUDA::nppitc_static``
+
+- `npps`:
+
+  - ``CUDA::npps``
+  - ``CUDA::npps_static``
+
+.. _`cuda_toolkit_nvBLAS`:
+
+nvBLAS
+""""""
+
+The `nvBLAS `_ libraries.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvblas``
+
+.. _`cuda_toolkit_nvGRAPH`:
+
+nvGRAPH
+"""""""
+
+The `nvGRAPH `_ library.
+Removed starting in CUDA 11.0
+
+Targets Created:
+
+- ``CUDA::nvgraph``
+- ``CUDA::nvgraph_static``
+
+
+.. _`cuda_toolkit_nvJPEG`:
+
+nvJPEG
+""""""
+
+The `nvJPEG `_ library.
+Introduced in CUDA 10.
+
+Targets Created:
+
+- ``CUDA::nvjpeg``
+- ``CUDA::nvjpeg_static``
+
+.. _`cuda_toolkit_nvRTC`:
+
+nvRTC
+"""""
+
+The `nvRTC `_ (Runtime Compilation) library.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvrtc``
+
+.. _`cuda_toolkit_nvml`:
+
+nvidia-ML
+"""""""""
+
+The `NVIDIA Management Library `_.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvml``
+
+.. _`cuda_toolkit_nvToolsExt`:
+
+nvToolsExt
+""""""""""
+
+The `NVIDIA Tools Extension `_.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::nvToolsExt``
+
+.. _`cuda_toolkit_opencl`:
+
+OpenCL
+""""""
+
+The `NVIDIA OpenCL Library `_.
+This is a shared library only.
+
+Targets Created:
+
+- ``CUDA::OpenCL``
+
+.. _`cuda_toolkit_cuLIBOS`:
+
+cuLIBOS
+"""""""
+
+The cuLIBOS library is a backend thread abstraction layer library which is
+static only.  The ``CUDA::cublas_static``, ``CUDA::cusparse_static``,
+``CUDA::cufft_static``, ``CUDA::curand_static``, and (when implemented) NPP
+libraries all automatically have this dependency linked.
+
+Target Created:
+
+- ``CUDA::culibos``
+
+**Note**: direct usage of this target by consumers should not be necessary.
+
+.. _`cuda_toolkit_cuRAND`:
+
+
+
+Result variables
+^^^^^^^^^^^^^^^^
+
+``CUDAToolkit_FOUND``
+    A boolean specifying whether or not the CUDA Toolkit was found.
+
+``CUDAToolkit_VERSION``
+    The exact version of the CUDA Toolkit found (as reported by
+    ``nvcc --version`` or ``version.txt``).
+
+``CUDAToolkit_VERSION_MAJOR``
+    The major version of the CUDA Toolkit.
+
+``CUDAToolkit_VERSION_MINOR``
+    The minor version of the CUDA Toolkit.
+
+``CUDAToolkit_VERSION_PATCH``
+    The patch version of the CUDA Toolkit.
+
+``CUDAToolkit_BIN_DIR``
+    The path to the CUDA Toolkit library directory that contains the CUDA
+    executable ``nvcc``.
+
+``CUDAToolkit_INCLUDE_DIRS``
+    The path to the CUDA Toolkit ``include`` folder containing the header files
+    required to compile a project linking against CUDA.
+
+``CUDAToolkit_LIBRARY_DIR``
+    The path to the CUDA Toolkit library directory that contains the CUDA
+    Runtime library ``cudart``.
+
+``CUDAToolkit_LIBRARY_ROOT``
+    .. versionadded:: 3.18
+
+    The path to the CUDA Toolkit directory containing the nvvm directory and
+    version.txt.
+
+``CUDAToolkit_TARGET_DIR``
+    The path to the CUDA Toolkit directory including the target architecture
+    when cross-compiling. When not cross-compiling this will be equivalent to
+    the parent directory of ``CUDAToolkit_BIN_DIR``.
+
+``CUDAToolkit_NVCC_EXECUTABLE``
+    The path to the NVIDIA CUDA compiler ``nvcc``.  Note that this path may
+    **not** be the same as
+    :variable:`CMAKE_CUDA_COMPILER _COMPILER>`.  ``nvcc`` must be
+    found to determine the CUDA Toolkit version as well as determining other
+    features of the Toolkit.  This variable is set for the convenience of
+    modules that depend on this one.
+
+
+#]=======================================================================]
+
+# NOTE: much of this was simply extracted from FindCUDA.cmake.
+
+#   James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#   Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#   Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#   Copyright (c) 2007-2009
+#   Scientific Computing and Imaging Institute, University of Utah
+#
+#   This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#   for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+###############################################################################
+
+# The toolkit is located during compiler detection for CUDA and stored in CMakeCUDACompiler.cmake as
+# CMAKE_CUDA_COMPILER_TOOLKIT_ROOT and CMAKE_CUDA_COMPILER_LIBRARY_ROOT.
+# We compute the rest based on those here to avoid re-searching and to avoid finding a possibly
+# different installation.
+if(CMAKE_CUDA_COMPILER_TOOLKIT_ROOT)
+  set(CUDAToolkit_ROOT_DIR "${CMAKE_CUDA_COMPILER_TOOLKIT_ROOT}")
+  set(CUDAToolkit_LIBRARY_ROOT "${CMAKE_CUDA_COMPILER_LIBRARY_ROOT}")
+  set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_TOOLKIT_VERSION}")
+
+  if(CUDAToolkit_VERSION MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+    set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+    set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+    set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+  endif()
+else()
+  function(_CUDAToolkit_find_root_dir )
+    cmake_parse_arguments(arg "" "" "SEARCH_PATHS;FIND_FLAGS" ${ARGN})
+
+    if(NOT CUDAToolkit_BIN_DIR)
+      if(NOT CUDAToolkit_SENTINEL_FILE)
+        find_program(CUDAToolkit_NVCC_EXECUTABLE
+          NAMES nvcc nvcc.exe
+          PATHS ${arg_SEARCH_PATHS}
+          ${arg_FIND_FLAGS}
+        )
+      endif()
+
+      if(NOT CUDAToolkit_NVCC_EXECUTABLE)
+        find_file(CUDAToolkit_SENTINEL_FILE
+          NAMES version.txt
+          PATHS ${arg_SEARCH_PATHS}
+          NO_DEFAULT_PATH
+        )
+      endif()
+
+      if(EXISTS "${CUDAToolkit_NVCC_EXECUTABLE}")
+        # If NVCC exists  then invoke it to find the toolkit location.
+        # This allows us to support wrapper scripts (e.g. ccache or colornvcc), CUDA Toolkit,
+        # NVIDIA HPC SDK, and distro's splayed layouts
+        execute_process(COMMAND ${CUDAToolkit_NVCC_EXECUTABLE} "-v" "__cmake_determine_cuda"
+          OUTPUT_VARIABLE _CUDA_NVCC_OUT ERROR_VARIABLE _CUDA_NVCC_OUT)
+        if(_CUDA_NVCC_OUT MATCHES "\\#\\$ TOP=([^\r\n]*)")
+          get_filename_component(CUDAToolkit_BIN_DIR "${CMAKE_MATCH_1}/bin" ABSOLUTE)
+        else()
+          get_filename_component(CUDAToolkit_BIN_DIR "${CUDAToolkit_NVCC_EXECUTABLE}" DIRECTORY)
+        endif()
+        unset(_CUDA_NVCC_OUT)
+
+        mark_as_advanced(CUDAToolkit_BIN_DIR)
+        set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "" FORCE)
+      endif()
+
+      if(CUDAToolkit_SENTINEL_FILE)
+        get_filename_component(CUDAToolkit_BIN_DIR ${CUDAToolkit_SENTINEL_FILE} DIRECTORY ABSOLUTE)
+        set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}/bin")
+
+        set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "" FORCE)
+        mark_as_advanced(CUDAToolkit_BIN_DIR)
+      endif()
+    endif()
+
+    if(CUDAToolkit_BIN_DIR)
+      get_filename_component(CUDAToolkit_ROOT_DIR ${CUDAToolkit_BIN_DIR} DIRECTORY ABSOLUTE)
+      set(CUDAToolkit_ROOT_DIR "${CUDAToolkit_ROOT_DIR}" PARENT_SCOPE)
+    endif()
+
+  endfunction()
+
+  # For NVCC we can easily deduce the SDK binary directory from the compiler path.
+  if(CMAKE_CUDA_COMPILER_LOADED AND NOT CUDAToolkit_BIN_DIR AND CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA")
+    get_filename_component(CUDAToolkit_BIN_DIR "${CMAKE_CUDA_COMPILER}" DIRECTORY)
+    set(CUDAToolkit_BIN_DIR "${CUDAToolkit_BIN_DIR}" CACHE PATH "")
+    # Try language provided path first.
+    _CUDAToolkit_find_root_dir(SEARCH_PATHS "${CUDAToolkit_BIN_DIR}" FIND_FLAGS NO_DEFAULT_PATH)
+    mark_as_advanced(CUDAToolkit_BIN_DIR)
+  endif()
+
+  # Try user provided path
+  if(NOT CUDAToolkit_ROOT_DIR AND CUDAToolkit_ROOT)
+    _CUDAToolkit_find_root_dir(SEARCH_PATHS "${CUDAToolkit_ROOT}" FIND_FLAGS PATH_SUFFIXES bin NO_DEFAULT_PATH)
+  endif()
+  if(NOT CUDAToolkit_ROOT_DIR)
+    _CUDAToolkit_find_root_dir(FIND_FLAGS PATHS ENV CUDA_PATH PATH_SUFFIXES bin)
+  endif()
+
+  # If the user specified CUDAToolkit_ROOT but the toolkit could not be found, this is an error.
+  if(NOT CUDAToolkit_ROOT_DIR AND (DEFINED CUDAToolkit_ROOT OR DEFINED ENV{CUDAToolkit_ROOT}))
+    # Declare error messages now, print later depending on find_package args.
+    set(fail_base "Could not find nvcc executable in path specified by")
+    set(cuda_root_fail "${fail_base} CUDAToolkit_ROOT=${CUDAToolkit_ROOT}")
+    set(env_cuda_root_fail "${fail_base} environment variable CUDAToolkit_ROOT=$ENV{CUDAToolkit_ROOT}")
+
+    if(CUDAToolkit_FIND_REQUIRED)
+      if(DEFINED CUDAToolkit_ROOT)
+        message(FATAL_ERROR ${cuda_root_fail})
+      elseif(DEFINED ENV{CUDAToolkit_ROOT})
+        message(FATAL_ERROR ${env_cuda_root_fail})
+      endif()
+    else()
+      if(NOT CUDAToolkit_FIND_QUIETLY)
+        if(DEFINED CUDAToolkit_ROOT)
+          message(STATUS ${cuda_root_fail})
+        elseif(DEFINED ENV{CUDAToolkit_ROOT})
+          message(STATUS ${env_cuda_root_fail})
+        endif()
+      endif()
+      set(CUDAToolkit_FOUND FALSE)
+      unset(fail_base)
+      unset(cuda_root_fail)
+      unset(env_cuda_root_fail)
+      return()
+    endif()
+  endif()
+
+  # CUDAToolkit_ROOT cmake / env variable not specified, try platform defaults.
+  #
+  # - Linux: /usr/local/cuda-X.Y
+  # - macOS: /Developer/NVIDIA/CUDA-X.Y
+  # - Windows: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y
+  #
+  # We will also search the default symlink location /usr/local/cuda first since
+  # if CUDAToolkit_ROOT is not specified, it is assumed that the symlinked
+  # directory is the desired location.
+  if(NOT CUDAToolkit_ROOT_DIR)
+    if(UNIX)
+      if(NOT APPLE)
+        set(platform_base "/usr/local/cuda-")
+      else()
+        set(platform_base "/Developer/NVIDIA/CUDA-")
+      endif()
+    else()
+      set(platform_base "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v")
+    endif()
+
+    # Build out a descending list of possible cuda installations, e.g.
+    file(GLOB possible_paths "${platform_base}*")
+    # Iterate the glob results and create a descending list.
+    set(versions)
+    foreach(p ${possible_paths})
+      # Extract version number from end of string
+      string(REGEX MATCH "[0-9][0-9]?\\.[0-9]$" p_version ${p})
+      if(IS_DIRECTORY ${p} AND p_version)
+        list(APPEND versions ${p_version})
+      endif()
+    endforeach()
+
+    # Sort numerically in descending order, so we try the newest versions first.
+    if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
+      list(SORT versions COMPARE NATURAL ORDER DESCENDING)
+    elseif(versions)
+      # Alphabetical sort here is not ideal but better than nothing
+      list(SORT versions)
+      list(REVERSE versions)
+    endif()
+
+    # With a descending list of versions, populate possible paths to search.
+    set(search_paths)
+    foreach(v ${versions})
+      list(APPEND search_paths "${platform_base}${v}")
+    endforeach()
+
+    # Force the global default /usr/local/cuda to the front on Unix.
+    if(UNIX)
+      list(INSERT search_paths 0 "/usr/local/cuda")
+    endif()
+
+    # Now search for the toolkit again using the platform default search paths.
+    _CUDAToolkit_find_root_dir(SEARCH_PATHS "${search_paths}" FIND_FLAGS PATH_SUFFIXES bin)
+
+    # We are done with these variables now, cleanup for caller.
+    unset(platform_base)
+    unset(possible_paths)
+    unset(versions)
+    unset(search_paths)
+
+    if(NOT CUDAToolkit_ROOT_DIR)
+      if(CUDAToolkit_FIND_REQUIRED)
+        message(FATAL_ERROR "Could not find nvcc, please set CUDAToolkit_ROOT.")
+      elseif(NOT CUDAToolkit_FIND_QUIETLY)
+        message(STATUS "Could not find nvcc, please set CUDAToolkit_ROOT.")
+      endif()
+
+      set(CUDAToolkit_FOUND FALSE)
+      return()
+    endif()
+  endif()
+endif()
+
+if(NOT CUDAToolkit_BIN_DIR)
+  set(CUDAToolkit_BIN_DIR "${CUDAToolkit_ROOT_DIR}/bin")
+endif()
+
+if(NOT CUDAToolkit_NVCC_EXECUTABLE)
+  set(CUDAToolkit_NVCC_EXECUTABLE "${CUDAToolkit_BIN_DIR}/nvcc${CMAKE_EXECUTABLE_SUFFIX}")
+endif()
+
+if(CMAKE_CUDA_COMPILER_TOOLKIT_VERSION)
+  set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_TOOLKIT_VERSION}")
+else()
+  function(_CUDAToolkit_find_version_file result_variable)
+    # We first check for a non-scattered installation to prefer it over a scattered installation.
+    if(CUDAToolkit_ROOT AND EXISTS "${CUDAToolkit_ROOT}/version.txt")
+      set(${result_variable} "${CUDAToolkit_ROOT}/version.txt" PARENT_SCOPE)
+    elseif(CUDAToolkit_ROOT_DIR AND EXISTS "${CUDAToolkit_ROOT_DIR}/version.txt")
+      set(${result_variable} "${CUDAToolkit_ROOT_DIR}/version.txt" PARENT_SCOPE)
+    elseif(CMAKE_SYSROOT_LINK AND EXISTS "${CMAKE_SYSROOT_LINK}/usr/lib/cuda/version.txt")
+      set(${result_variable} "${CMAKE_SYSROOT_LINK}/usr/lib/cuda/version.txt" PARENT_SCOPE)
+    elseif(EXISTS "${CMAKE_SYSROOT}/usr/lib/cuda/version.txt")
+      set(${result_variable} "${CMAKE_SYSROOT}/usr/lib/cuda/version.txt" PARENT_SCOPE)
+    endif()
+  endfunction()
+
+  _CUDAToolkit_find_version_file( _CUDAToolkit_version_file )
+  if(_CUDAToolkit_version_file)
+    # CUDAToolkit_LIBRARY_ROOT contains the device library and version file.
+    get_filename_component(CUDAToolkit_LIBRARY_ROOT "${_CUDAToolkit_version_file}" DIRECTORY ABSOLUTE)
+  endif()
+  unset(_CUDAToolkit_version_file)
+
+  if(CUDAToolkit_NVCC_EXECUTABLE AND
+     CMAKE_CUDA_COMPILER_VERSION AND
+     CUDAToolkit_NVCC_EXECUTABLE STREQUAL CMAKE_CUDA_COMPILER)
+    # Need to set these based off the already computed CMAKE_CUDA_COMPILER_VERSION value
+    # This if statement will always match, but is used to provide variables for MATCH 1,2,3...
+    if(CMAKE_CUDA_COMPILER_VERSION MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+      set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+      set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+      set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+      set(CUDAToolkit_VERSION "${CMAKE_CUDA_COMPILER_VERSION}")
+    endif()
+  elseif(CUDAToolkit_NVCC_EXECUTABLE)
+    # Compute the version by invoking nvcc
+    execute_process(COMMAND ${CUDAToolkit_NVCC_EXECUTABLE} "--version" OUTPUT_VARIABLE NVCC_OUT)
+    if(NVCC_OUT MATCHES [=[ V([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+      set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+      set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+      set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+      set(CUDAToolkit_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}")
+    endif()
+    unset(NVCC_OUT)
+  else()
+    _CUDAToolkit_find_version_file(version_file)
+    if(version_file)
+      file(READ "${version_file}" VERSION_INFO)
+      if(VERSION_INFO MATCHES [=[CUDA Version ([0-9]+)\.([0-9]+)\.([0-9]+)]=])
+        set(CUDAToolkit_VERSION_MAJOR "${CMAKE_MATCH_1}")
+        set(CUDAToolkit_VERSION_MINOR "${CMAKE_MATCH_2}")
+        set(CUDAToolkit_VERSION_PATCH "${CMAKE_MATCH_3}")
+        set(CUDAToolkit_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}")
+      endif()
+    endif()
+  endif()
+endif()
+
+# Find target directory when crosscompiling.
+if(CMAKE_CROSSCOMPILING)
+  if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7-a")
+    # Support for NVPACK
+    set(CUDAToolkit_TARGET_NAME "armv7-linux-androideabi")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm")
+    set(CUDAToolkit_TARGET_NAME "armv7-linux-gnueabihf")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+    if(ANDROID_ARCH_NAME STREQUAL "arm64")
+      set(CUDAToolkit_TARGET_NAME "aarch64-linux-androideabi")
+    elseif(CMAKE_SYSTEM_NAME STREQUAL "QNX")
+      set(CUDAToolkit_TARGET_NAME "aarch64-qnx")
+    else()
+      set(CUDAToolkit_TARGET_NAME "aarch64-linux")
+    endif(ANDROID_ARCH_NAME STREQUAL "arm64")
+  elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
+    set(CUDAToolkit_TARGET_NAME "x86_64-linux")
+  endif()
+
+  if(EXISTS "${CUDAToolkit_ROOT_DIR}/targets/${CUDAToolkit_TARGET_NAME}")
+    set(CUDAToolkit_TARGET_DIR "${CUDAToolkit_ROOT_DIR}/targets/${CUDAToolkit_TARGET_NAME}")
+    # add known CUDA target root path to the set of directories we search for programs, libraries and headers
+    list(PREPEND CMAKE_FIND_ROOT_PATH "${CUDAToolkit_TARGET_DIR}")
+
+    # Mark that we need to pop the root search path changes after we have
+    # found all cuda libraries so that searches for our cross-compilation
+    # libraries work when another cuda sdk is in CMAKE_PREFIX_PATH or
+    # PATh
+    set(_CUDAToolkit_Pop_ROOT_PATH True)
+  endif()
+endif()
+
+# If not already set we can simply use the toolkit root or it's a scattered installation.
+if(NOT CUDAToolkit_TARGET_DIR)
+  # Not cross compiling
+  set(CUDAToolkit_TARGET_DIR "${CUDAToolkit_ROOT_DIR}")
+  # Now that we have the real ROOT_DIR, find components inside it.
+  list(APPEND CMAKE_PREFIX_PATH ${CUDAToolkit_ROOT_DIR})
+
+  # Mark that we need to pop the prefix path changes after we have
+  # found the cudart library.
+  set(_CUDAToolkit_Pop_Prefix True)
+endif()
+
+# CUDAToolkit_TARGET_DIR always points to the directory containing the include directory.
+# On a scattered installation /usr, on a non-scattered something like /usr/local/cuda or /usr/local/cuda-10.2/targets/aarch64-linux.
+if(EXISTS "${CUDAToolkit_TARGET_DIR}/include/cuda_runtime.h")
+  set(CUDAToolkit_INCLUDE_DIR "${CUDAToolkit_TARGET_DIR}/include")
+elseif(NOT CUDAToolkit_FIND_QUIETLY)
+  message(STATUS "Unable to find cuda_runtime.h in \"${CUDAToolkit_TARGET_DIR}/include\" for CUDAToolkit_INCLUDE_DIR.")
+endif()
+
+# The NVHPC layout moves math library headers and libraries to a sibling directory.
+# Create a separate variable so this directory can be selectively added to math targets.
+if(NOT EXISTS "${CUDAToolkit_INCLUDE_DIR}/cublas_v2.h")
+  set(CUDAToolkit_MATH_INCLUDE_DIR "${CUDAToolkit_TARGET_DIR}/../../math_libs/include")
+  get_filename_component(CUDAToolkit_MATH_INCLUDE_DIR "${CUDAToolkit_MATH_INCLUDE_DIR}" ABSOLUTE)
+  if(NOT EXISTS "${CUDAToolkit_MATH_INCLUDE_DIR}/cublas_v2.h")
+    if(NOT CUDAToolkit_FIND_QUIETLY)
+      message(STATUS "Unable to find cublas_v2.h in either \"${CUDAToolkit_INCLUDE_DIR}\" or \"${CUDAToolkit_MATH_INCLUDE_DIR}\"")
+    endif()
+    unset(CUDAToolkit_MATH_INCLUDE_DIR)
+  endif()
+endif()
+
+# Find the CUDA Runtime Library libcudart
+find_library(CUDA_CUDART
+  NAMES cudart
+  PATH_SUFFIXES lib64 lib/x64
+)
+find_library(CUDA_CUDART
+  NAMES cudart
+  PATH_SUFFIXES lib64/stubs lib/x64/stubs
+)
+
+if(NOT CUDA_CUDART AND NOT CUDAToolkit_FIND_QUIETLY)
+  message(STATUS "Unable to find cudart library.")
+endif()
+
+if(_CUDAToolkit_Pop_Prefix)
+  list(REMOVE_AT CMAKE_PREFIX_PATH -1)
+  unset(_CUDAToolkit_Pop_Prefix)
+endif()
+
+#-----------------------------------------------------------------------------
+# Perform version comparison and validate all required variables are set.
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(CUDAToolkit
+  REQUIRED_VARS
+    CUDAToolkit_INCLUDE_DIR
+    CUDAToolkit_VERSION
+    CUDA_CUDART
+    CUDAToolkit_BIN_DIR
+  VERSION_VAR
+    CUDAToolkit_VERSION
+)
+
+mark_as_advanced(CUDA_CUDART
+                 CUDAToolkit_INCLUDE_DIR
+                 CUDAToolkit_NVCC_EXECUTABLE
+                 CUDAToolkit_SENTINEL_FILE
+                 )
+
+#-----------------------------------------------------------------------------
+# Construct result variables
+if(CUDAToolkit_FOUND)
+  set(CUDAToolkit_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIR})
+  get_filename_component(CUDAToolkit_LIBRARY_DIR ${CUDA_CUDART} DIRECTORY ABSOLUTE)
+endif()
+
+#-----------------------------------------------------------------------------
+# Construct import targets
+if(CUDAToolkit_FOUND)
+
+  function(_CUDAToolkit_find_and_add_import_lib lib_name)
+    cmake_parse_arguments(arg "" "" "ALT;DEPS;EXTRA_HINTS;EXTRA_PATH_SUFFIXES;EXTRA_INCLUDE_DIRS" ${ARGN})
+
+    set(search_names ${lib_name} ${arg_ALT})
+
+    find_library(CUDA_${lib_name}_LIBRARY
+      NAMES ${search_names}
+      HINTS ${CUDAToolkit_LIBRARY_DIR}
+            ENV CUDA_PATH
+            ${arg_EXTRA_HINTS}
+      PATH_SUFFIXES nvidia/current lib64 lib/x64 lib
+                    ${arg_EXTRA_PATH_SUFFIXES}
+    )
+    # Don't try any stub directories until we have exhausted all other
+    # search locations.
+    find_library(CUDA_${lib_name}_LIBRARY
+      NAMES ${search_names}
+      HINTS ${CUDAToolkit_LIBRARY_DIR}
+            ENV CUDA_PATH
+            ${arg_EXTRA_HINTS}
+      PATH_SUFFIXES lib64/stubs lib/x64/stubs lib/stubs stubs
+                    # Support NVHPC splayed math library layout
+                    ../../math_libs/${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}/lib64
+                    ../../math_libs/lib64
+    )
+
+    mark_as_advanced(CUDA_${lib_name}_LIBRARY)
+
+    if(NOT TARGET CUDA::${lib_name} AND CUDA_${lib_name}_LIBRARY)
+      add_library(CUDA::${lib_name} UNKNOWN IMPORTED)
+      set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+          INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+      set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+          INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+      if(DEFINED CUDAToolkit_MATH_INCLUDE_DIR)
+        string(FIND ${CUDA_${lib_name}_LIBRARY} "math_libs" math_libs)
+        if(NOT ${math_libs} EQUAL -1)
+          set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+              INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_MATH_INCLUDE_DIRS}")
+          set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+              INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_MATH_INCLUDE_DIRS}")
+        endif()
+      endif()
+      set_property(TARGET CUDA::${lib_name} PROPERTY IMPORTED_LOCATION "${CUDA_${lib_name}_LIBRARY}")
+      foreach(dep ${arg_DEPS})
+        if(TARGET CUDA::${dep})
+          set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+              INTERFACE_LINK_LIBRARIES CUDA::${dep})
+        endif()
+      endforeach()
+      if(arg_EXTRA_INCLUDE_DIRS)
+        set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+            INTERFACE_INCLUDE_DIRECTORIES "${arg_EXTRA_INCLUDE_DIRS}")
+        set_property(TARGET CUDA::${lib_name} APPEND PROPERTY
+            INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${arg_EXTRA_INCLUDE_DIRS}")
+      endif()
+    endif()
+  endfunction()
+
+  if(NOT TARGET CUDA::toolkit)
+    add_library(CUDA::toolkit IMPORTED INTERFACE)
+    set_property(TARGET CUDA::toolkit APPEND PROPERTY
+        INTERFACE_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+    set_property(TARGET CUDA::toolkit APPEND PROPERTY
+        INTERFACE_SYSTEM_INCLUDE_DIRECTORIES "${CUDAToolkit_INCLUDE_DIRS}")
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(cuda_driver ALT cuda)
+
+  _CUDAToolkit_find_and_add_import_lib(cudart)
+  _CUDAToolkit_find_and_add_import_lib(cudart_static)
+
+  # setup dependencies that are required for cudart_static when building
+  # on linux. These are generally only required when using the CUDA toolkit
+  # when CUDA language is disabled
+  if(NOT TARGET CUDA::cudart_static_deps
+     AND TARGET CUDA::cudart_static)
+
+    add_library(CUDA::cudart_static_deps IMPORTED INTERFACE)
+    set_property(TARGET CUDA::cudart_static APPEND PROPERTY
+        INTERFACE_LINK_LIBRARIES CUDA::cudart_static_deps)
+
+    if(UNIX AND (CMAKE_C_COMPILER OR CMAKE_CXX_COMPILER))
+      find_package(Threads REQUIRED)
+      set_property(TARGET CUDA::cudart_static_deps APPEND PROPERTY
+          INTERFACE_LINK_LIBRARIES Threads::Threads ${CMAKE_DL_LIBS})
+    endif()
+
+    if(UNIX AND NOT APPLE AND NOT (CMAKE_SYSTEM_NAME STREQUAL "QNX"))
+      # On Linux, you must link against librt when using the static cuda runtime.
+      find_library(CUDAToolkit_rt_LIBRARY rt)
+      mark_as_advanced(CUDAToolkit_rt_LIBRARY)
+      if(NOT CUDAToolkit_rt_LIBRARY)
+        message(WARNING "Could not find librt library, needed by CUDA::cudart_static")
+      else()
+        set_property(TARGET CUDA::cudart_static_deps APPEND PROPERTY
+            INTERFACE_LINK_LIBRARIES ${CUDAToolkit_rt_LIBRARY})
+      endif()
+    endif()
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(culibos) # it's a static library
+  foreach(cuda_lib cublasLt cufft curand cusparse nppc nvjpeg)
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib})
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib}_static DEPS culibos)
+  endforeach()
+
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.0.0)
+    # cublas depends on cublasLt
+    # https://docs.nvidia.com/cuda/archive/11.0/cublas/index.html#static-library
+    _CUDAToolkit_find_and_add_import_lib(cublas DEPS cublasLt)
+    _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS cublasLt_static)
+  else()
+    _CUDAToolkit_find_and_add_import_lib(cublas)
+    _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
+  endif()
+
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4)
+    _CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos)
+    _CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos)
+
+    _CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos)
+    _CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos)
+  endif()
+
+  # cuFFTW depends on cuFFT
+  _CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
+  _CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 9.2)
+    _CUDAToolkit_find_and_add_import_lib(cufft_static_nocallback DEPS culibos)
+  endif()
+
+  # cuSOLVER depends on cuBLAS, and cuSPARSE
+  _CUDAToolkit_find_and_add_import_lib(cusolver DEPS cublas cusparse)
+  _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS cublas_static cusparse_static culibos)
+
+
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 10.1.2)
+    # cusolver depends on liblapack_static.a starting with CUDA 10.1 update 2,
+    # https://docs.nvidia.com/cuda/archive/11.5.0/cusolver/index.html#static-link-lapack
+    _CUDAToolkit_find_and_add_import_lib(cusolver_lapack_static ALT lapack_static) # implementation detail static lib
+    _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS cusolver_lapack_static)
+  endif()
+
+  if(CUDAToolkit_VERSION VERSION_GREATER 11.2.1)
+    # cusolver depends on libcusolver_metis and cublasLt
+    # https://docs.nvidia.com/cuda/archive/11.2.2/cusolver/index.html#link-dependency
+    _CUDAToolkit_find_and_add_import_lib(cusolver DEPS cublasLt)
+
+    _CUDAToolkit_find_and_add_import_lib(cusolver_metis_static ALT metis_static) # implementation detail static lib
+    _CUDAToolkit_find_and_add_import_lib(cusolver_static DEPS cusolver_metis_static cublasLt_static)
+  endif()
+
+  # nvGRAPH depends on cuRAND, and cuSOLVER.
+  _CUDAToolkit_find_and_add_import_lib(nvgraph DEPS curand cusolver)
+  _CUDAToolkit_find_and_add_import_lib(nvgraph_static DEPS curand_static cusolver_static)
+
+  # Process the majority of the NPP libraries.
+  foreach(cuda_lib nppial nppicc nppidei nppif nppig nppim nppist nppitc npps nppicom nppisu)
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib} DEPS nppc)
+    _CUDAToolkit_find_and_add_import_lib(${cuda_lib}_static DEPS nppc_static)
+  endforeach()
+
+  find_path(CUDAToolkit_CUPTI_INCLUDE_DIR cupti.h PATHS
+      "${CUDAToolkit_ROOT_DIR}/extras/CUPTI/include"
+      "${CUDAToolkit_INCLUDE_DIR}/../extras/CUPTI/include"
+      "${CUDAToolkit_INCLUDE_DIR}"
+      NO_DEFAULT_PATH)
+  mark_as_advanced(CUDAToolkit_CUPTI_INCLUDE_DIR)
+
+  if(CUDAToolkit_CUPTI_INCLUDE_DIR)
+    _CUDAToolkit_find_and_add_import_lib(cupti
+                                        EXTRA_PATH_SUFFIXES ../extras/CUPTI/lib64/
+                                                            ../extras/CUPTI/lib/
+                                        EXTRA_INCLUDE_DIRS "${CUDAToolkit_CUPTI_INCLUDE_DIR}")
+    _CUDAToolkit_find_and_add_import_lib(cupti_static
+                                        EXTRA_PATH_SUFFIXES ../extras/CUPTI/lib64/
+                                                            ../extras/CUPTI/lib/
+                                        EXTRA_INCLUDE_DIRS "${CUDAToolkit_CUPTI_INCLUDE_DIR}")
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(nvrtc DEPS cuda_driver)
+
+  _CUDAToolkit_find_and_add_import_lib(nvml ALT nvidia-ml nvml)
+
+  # nvtools can be installed outside the CUDA toolkit directory,
+  # so search the NVTOOLSEXT_PATH windows only environment variable
+  set(nvToolsExt_EXTRA_PATH)
+  if(WIN32)
+     set(nvToolsExt_EXTRA_PATH "C:\\Program Files\\NVIDIA Corporation\\NvToolsExt")
+  endif()
+
+  find_path(CUDAToolkit_nvToolsExt_INCLUDE_DIR nvToolsExt.h
+      PATHS "${CUDAToolkit_INCLUDE_DIR}"
+            "${CUDAToolkit_ROOT_DIR}"
+            ENV NVTOOLSEXT_PATH
+            "${nvToolsExt_EXTRA_PATH}"
+      PATH_SUFFIXES include
+      NO_DEFAULT_PATH)
+  mark_as_advanced(CUDAToolkit_nvToolsExt_INCLUDE_DIR)
+
+  if(CUDAToolkit_nvToolsExt_INCLUDE_DIR)
+    _CUDAToolkit_find_and_add_import_lib(nvToolsExt
+        ALT nvToolsExt64 nvToolsExt64_1
+        EXTRA_HINTS ENV NVTOOLSEXT_PATH
+                    "${nvToolsExt_EXTRA_PATH}"
+        EXTRA_INCLUDE_DIRS "${CUDAToolkit_nvToolsExt_INCLUDE_DIR}")
+  endif()
+
+  _CUDAToolkit_find_and_add_import_lib(OpenCL)
+endif()
+
+unset(CUDAToolkit_ROOT_DIR)
+
+if(_CUDAToolkit_Pop_ROOT_PATH)
+  list(REMOVE_AT CMAKE_FIND_ROOT_PATH 0)
+  unset(_CUDAToolkit_Pop_ROOT_PATH)
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..b614e1c492b99f7b3adf456b0b88bdf5cd26fd0b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUDSS.cmake
@@ -0,0 +1,67 @@
+# Find the CUDSS library
+#
+# The following variables are optionally searched for defaults
+#  CUDSS_ROOT: Base directory where CUDSS is found
+#  CUDSS_INCLUDE_DIR: Directory where CUDSS header is searched for
+#  CUDSS_LIBRARY: Directory where CUDSS library is searched for
+#
+# The following are set after configuration is done:
+#  CUDSS_FOUND
+#  CUDSS_INCLUDE_PATH
+#  CUDSS_LIBRARY_PATH
+
+include(FindPackageHandleStandardArgs)
+
+set(CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} CACHE PATH "Folder containing NVIDIA CUDSS")
+if (DEFINED $ENV{CUDSS_ROOT_DIR})
+  message(WARNING "CUDSS_ROOT_DIR is deprecated. Please set CUDSS_ROOT instead.")
+endif()
+list(APPEND CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+
+# Compatible layer for CMake <3.12. CUDSS_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+list(APPEND CMAKE_PREFIX_PATH ${CUDSS_ROOT})
+
+set(CUDSS_INCLUDE_DIR $ENV{CUDSS_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA CUDSS header files")
+
+find_path(CUDSS_INCLUDE_PATH cudss.h
+  HINTS ${CUDSS_INCLUDE_DIR}
+  PATH_SUFFIXES cuda/include cuda include)
+
+set(CUDSS_LIBRARY $ENV{CUDSS_LIBRARY} CACHE PATH "Path to the CUDSS library file (e.g., libcudss.so)")
+
+set(CUDSS_LIBRARY_NAME "libcudss.so")
+if(MSVC)
+  set(CUDSS_LIBRARY_NAME "cudss.lib")
+endif()
+
+find_library(CUDSS_LIBRARY_PATH ${CUDSS_LIBRARY_NAME}
+  PATHS ${CUDSS_LIBRARY}
+  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+
+find_package_handle_standard_args(CUDSS DEFAULT_MSG CUDSS_LIBRARY_PATH CUDSS_INCLUDE_PATH)
+
+if(CUDSS_FOUND)
+  # Get CUDSS version
+  file(READ ${CUDSS_INCLUDE_PATH}/cudss.h CUDSS_HEADER_CONTENTS)
+  string(REGEX MATCH "define CUDSS_VER_MAJOR * +([0-9]+)"
+               CUDSS_VERSION_MAJOR "${CUDSS_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDSS_VER_MAJOR * +([0-9]+)" "\\1"
+               CUDSS_VERSION_MAJOR "${CUDSS_VERSION_MAJOR}")
+  string(REGEX MATCH "define CUDSS_VER_MINOR * +([0-9]+)"
+               CUDSS_VERSION_MINOR "${CUDSS_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDSS_VER_MINOR * +([0-9]+)" "\\1"
+               CUDSS_VERSION_MINOR "${CUDSS_VERSION_MINOR}")
+  string(REGEX MATCH "define CUDSS_VER_PATCH * +([0-9]+)"
+               CUDSS_VERSION_PATCH "${CUDSS_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDSS_VER_PATCH * +([0-9]+)" "\\1"
+               CUDSS_VERSION_PATCH "${CUDSS_VERSION_PATCH}")
+  # Assemble CUDSS version. Use minor version since current major version is 0.
+  if(NOT CUDSS_VERSION_MINOR)
+    set(CUDSS_VERSION "?")
+  else()
+    set(CUDSS_VERSION
+        "${CUDSS_VERSION_MAJOR}.${CUDSS_VERSION_MINOR}.${CUDSS_VERSION_PATCH}")
+  endif()
+endif()
+
+mark_as_advanced(CUDSS_ROOT CUDSS_INCLUDE_DIR CUDSS_LIBRARY CUDSS_VERSION)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..6c15bde147469ddc84980dca0c756e8f26e1ddb1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindCUSPARSELT.cmake
@@ -0,0 +1,67 @@
+# Find the CUSPARSELT library
+#
+# The following variables are optionally searched for defaults
+#  CUSPARSELT_ROOT: Base directory where CUSPARSELT is found
+#  CUSPARSELT_INCLUDE_DIR: Directory where CUSPARSELT header is searched for
+#  CUSPARSELT_LIBRARY: Directory where CUSPARSELT library is searched for
+#
+# The following are set after configuration is done:
+#  CUSPARSELT_FOUND
+#  CUSPARSELT_INCLUDE_PATH
+#  CUSPARSELT_LIBRARY_PATH
+
+include(FindPackageHandleStandardArgs)
+
+set(CUSPARSELT_ROOT $ENV{CUSPARSELT_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuSPARSELt")
+if (DEFINED $ENV{CUSPARSELT_ROOT_DIR})
+  message(WARNING "CUSPARSELT_ROOT_DIR is deprecated. Please set CUSPARSELT_ROOT instead.")
+endif()
+list(APPEND CUSPARSELT_ROOT $ENV{CUSPARSELT_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+
+# Compatible layer for CMake <3.12. CUSPARSELT_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+list(APPEND CMAKE_PREFIX_PATH ${CUSPARSELT_ROOT})
+
+set(CUSPARSELT_INCLUDE_DIR $ENV{CUSPARSELT_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuSPARSELt header files")
+
+find_path(CUSPARSELT_INCLUDE_PATH cusparseLt.h
+  HINTS ${CUSPARSELT_INCLUDE_DIR}
+  PATH_SUFFIXES cuda/include cuda include)
+
+set(CUSPARSELT_LIBRARY $ENV{CUSPARSELT_LIBRARY} CACHE PATH "Path to the cusparselt library file (e.g., libcusparseLt.so)")
+
+set(CUSPARSELT_LIBRARY_NAME "libcusparseLt.so")
+if(MSVC)
+  set(CUSPARSELT_LIBRARY_NAME "cusparseLt.lib")
+endif()
+
+find_library(CUSPARSELT_LIBRARY_PATH ${CUSPARSELT_LIBRARY_NAME}
+  PATHS ${CUSPARSELT_LIBRARY}
+  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+
+find_package_handle_standard_args(CUSPARSELT DEFAULT_MSG CUSPARSELT_LIBRARY_PATH CUSPARSELT_INCLUDE_PATH)
+
+if(CUSPARSELT_FOUND)
+  # Get cuSPARSELt version
+  file(READ ${CUSPARSELT_INCLUDE_PATH}/cusparseLt.h CUSPARSELT_HEADER_CONTENTS)
+  string(REGEX MATCH "define CUSPARSELT_VER_MAJOR * +([0-9]+)"
+               CUSPARSELT_VERSION_MAJOR "${CUSPARSELT_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUSPARSELT_VER_MAJOR * +([0-9]+)" "\\1"
+               CUSPARSELT_VERSION_MAJOR "${CUSPARSELT_VERSION_MAJOR}")
+  string(REGEX MATCH "define CUSPARSELT_VER_MINOR * +([0-9]+)"
+               CUSPARSELT_VERSION_MINOR "${CUSPARSELT_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUSPARSELT_VER_MINOR * +([0-9]+)" "\\1"
+               CUSPARSELT_VERSION_MINOR "${CUSPARSELT_VERSION_MINOR}")
+  string(REGEX MATCH "define CUSPARSELT_VER_PATCH * +([0-9]+)"
+               CUSPARSELT_VERSION_PATCH "${CUSPARSELT_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUSPARSELT_VER_PATCH * +([0-9]+)" "\\1"
+               CUSPARSELT_VERSION_PATCH "${CUSPARSELT_VERSION_PATCH}")
+  # Assemble cuSPARSELt version. Use minor version since current major version is 0.
+  if(NOT CUSPARSELT_VERSION_MINOR)
+    set(CUSPARSELT_VERSION "?")
+  else()
+    set(CUSPARSELT_VERSION
+        "${CUSPARSELT_VERSION_MAJOR}.${CUSPARSELT_VERSION_MINOR}.${CUSPARSELT_VERSION_PATCH}")
+  endif()
+endif()
+
+mark_as_advanced(CUSPARSELT_ROOT CUSPARSELT_INCLUDE_DIR CUSPARSELT_LIBRARY CUSPARSELT_VERSION)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..337afa1bfe4178d1af041c6504c1124b8c31d482
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/FindSYCLToolkit.cmake
@@ -0,0 +1,141 @@
+# This will define the following variables:
+# SYCL_FOUND               : True if the system has the SYCL library.
+# SYCL_INCLUDE_DIR         : Include directories needed to use SYCL.
+# SYCL_LIBRARY_DIR         : The path to the SYCL library.
+# SYCL_LIBRARY             : SYCL library fullname.
+# SYCL_COMPILER_VERSION    : SYCL compiler version.
+
+include(FindPackageHandleStandardArgs)
+
+set(SYCL_ROOT "")
+if(DEFINED ENV{SYCL_ROOT})
+  set(SYCL_ROOT $ENV{SYCL_ROOT})
+elseif(DEFINED ENV{CMPLR_ROOT})
+  set(SYCL_ROOT $ENV{CMPLR_ROOT})
+else()
+  # Use the default path to ensure proper linking with torch::xpurt when the user is working with libtorch.
+  if(CMAKE_SYSTEM_NAME MATCHES "Linux")
+    set(SYCL_ROOT "/opt/intel/oneapi/compiler/latest")
+  elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
+    set(SYCL_ROOT "C:/Program Files (x86)/Intel/oneAPI/compiler/latest")
+  endif()
+  if(NOT EXISTS ${SYCL_ROOT})
+    set(SYCL_ROOT "")
+  endif()
+endif()
+
+string(COMPARE EQUAL "${SYCL_ROOT}" "" nosyclfound)
+if(nosyclfound)
+  set(SYCL_FOUND False)
+  set(SYCL_REASON_FAILURE "SYCL library not set!!")
+  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
+  return()
+endif()
+
+# Find SYCL compiler executable.
+find_program(
+  SYCL_COMPILER
+  NAMES icx
+  PATHS "${SYCL_ROOT}"
+  PATH_SUFFIXES bin bin64
+  NO_DEFAULT_PATH
+  )
+
+function(parse_sycl_compiler_version version_number)
+  # Execute the SYCL compiler with the --version flag to match the version string.
+  execute_process(COMMAND ${SYCL_COMPILER} --version OUTPUT_VARIABLE SYCL_VERSION_STRING)
+  string(REGEX REPLACE "Intel\\(R\\) (.*) Compiler ([0-9]+\\.[0-9]+\\.[0-9]+) (.*)" "\\2"
+               SYCL_VERSION_STRING_MATCH ${SYCL_VERSION_STRING})
+  string(REPLACE "." ";" SYCL_VERSION_LIST ${SYCL_VERSION_STRING_MATCH})
+  # Split the version number list into major, minor, and patch components.
+  list(GET SYCL_VERSION_LIST 0 VERSION_MAJOR)
+  list(GET SYCL_VERSION_LIST 1 VERSION_MINOR)
+  list(GET SYCL_VERSION_LIST 2 VERSION_PATCH)
+  # Calculate the version number in the format XXXXYYZZ, using the formula (major * 10000 + minor * 100 + patch).
+  math(EXPR VERSION_NUMBER_MATCH "${VERSION_MAJOR} * 10000 + ${VERSION_MINOR} * 100 + ${VERSION_PATCH}")
+  set(${version_number} "${VERSION_NUMBER_MATCH}" PARENT_SCOPE)
+endfunction()
+
+if(SYCL_COMPILER)
+  parse_sycl_compiler_version(SYCL_COMPILER_VERSION)
+endif()
+
+if(NOT SYCL_COMPILER_VERSION)
+  set(SYCL_FOUND False)
+  set(SYCL_REASON_FAILURE "Cannot parse sycl compiler version to get SYCL_COMPILER_VERSION!")
+  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
+  return()
+endif()
+
+# Find include path from binary.
+find_file(
+  SYCL_INCLUDE_DIR
+  NAMES include
+  HINTS ${SYCL_ROOT}
+  NO_DEFAULT_PATH
+  )
+
+# Find include/sycl path from include path.
+find_file(
+  SYCL_INCLUDE_SYCL_DIR
+  NAMES sycl
+  HINTS ${SYCL_ROOT}/include/
+  NO_DEFAULT_PATH
+  )
+
+# Due to the unrecognized compilation option `-fsycl` in other compiler.
+list(APPEND SYCL_INCLUDE_DIR ${SYCL_INCLUDE_SYCL_DIR})
+
+# Find library directory from binary.
+find_file(
+  SYCL_LIBRARY_DIR
+  NAMES lib lib64
+  HINTS ${SYCL_ROOT}
+  NO_DEFAULT_PATH
+  )
+
+# Define the old version of SYCL toolkit that is compatible with the current version of PyTorch.
+set(PYTORCH_2_5_SYCL_TOOLKIT_VERSION 20249999)
+
+# By default, we use libsycl.so on Linux and sycl.lib on Windows as the SYCL library name.
+if (SYCL_COMPILER_VERSION VERSION_LESS_EQUAL PYTORCH_2_5_SYCL_TOOLKIT_VERSION)
+  # Don't use if(WIN32) here since this requires cmake>=3.25 and file is installed
+  # and used by other projects.
+  # See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html
+  if(CMAKE_SYSTEM_NAME MATCHES "Windows")
+    # On Windows, the SYCL library is named sycl7.lib until PYTORCH_2_5_SYCL_TOOLKIT_VERSION.
+    # sycl.lib is supported in the later version.
+    set(sycl_lib_suffix "7")
+  endif()
+endif()
+
+# Find SYCL library fullname.
+find_library(
+  SYCL_LIBRARY
+  NAMES "sycl${sycl_lib_suffix}"
+  HINTS ${SYCL_LIBRARY_DIR}
+  NO_DEFAULT_PATH
+)
+
+# Find OpenCL library fullname, which is a dependency of oneDNN.
+find_library(
+  OCL_LIBRARY
+  NAMES OpenCL
+  HINTS ${SYCL_LIBRARY_DIR}
+  NO_DEFAULT_PATH
+)
+
+if((NOT SYCL_LIBRARY) OR (NOT OCL_LIBRARY))
+  set(SYCL_FOUND False)
+  set(SYCL_REASON_FAILURE "SYCL library is incomplete!!")
+  set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
+  return()
+endif()
+
+find_package_handle_standard_args(
+  SYCL
+  FOUND_VAR SYCL_FOUND
+  REQUIRED_VARS SYCL_INCLUDE_DIR SYCL_LIBRARY_DIR SYCL_LIBRARY
+  REASON_FAILURE_MESSAGE "${SYCL_REASON_FAILURE}"
+  VERSION_VAR SYCL_COMPILER_VERSION
+  )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..55c4e83012d820995f59b717ecb676452f9ccbec
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake
@@ -0,0 +1,10 @@
+# This is a wrapper of the upstream `./upstream/FindCUDA.cmake` that
+# automatically includes `./upstream/CMakeInitializeConfigs.cmake` before
+# `./upstream/FindCUDA.cmake`. The `CMakeInitializeConfigs.cmake`, which is
+# absent in old CMake versions, creates some necessary variables for the later
+# to run.
+# See ./README.md for details.
+
+set(UPSTREAM_FIND_CUDA_DIR "${CMAKE_CURRENT_LIST_DIR}/upstream/")
+
+include("${UPSTREAM_FIND_CUDA_DIR}/FindCUDA.cmake")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..82134328c803dc87a89564638540a6cbcfa2d906
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake
@@ -0,0 +1,78 @@
+# Find the CUDNN libraries
+#
+# The following variables are optionally searched for defaults
+#  CUDNN_ROOT: Base directory where CUDNN is found
+#  CUDNN_INCLUDE_DIR: Directory where CUDNN header is searched for
+#  CUDNN_LIBRARY: Directory where CUDNN library is searched for
+#  CUDNN_STATIC: Are we looking for a static library? (default: no)
+#
+# The following are set after configuration is done:
+#  CUDNN_FOUND
+#  CUDNN_INCLUDE_PATH
+#  CUDNN_LIBRARY_PATH
+#
+
+include(FindPackageHandleStandardArgs)
+
+set(CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuDNN")
+if (DEFINED $ENV{CUDNN_ROOT_DIR})
+  message(WARNING "CUDNN_ROOT_DIR is deprecated. Please set CUDNN_ROOT instead.")
+endif()
+list(APPEND CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+
+# Compatible layer for CMake <3.12. CUDNN_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+list(APPEND CMAKE_PREFIX_PATH ${CUDNN_ROOT})
+
+set(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuDNN header files")
+
+find_path(CUDNN_INCLUDE_PATH cudnn.h
+  HINTS ${CUDNN_INCLUDE_DIR}
+  PATH_SUFFIXES cuda/include cuda include)
+
+option(CUDNN_STATIC "Look for static CUDNN" OFF)
+if (CUDNN_STATIC)
+  set(CUDNN_LIBNAME "libcudnn_static.a")
+else()
+  set(CUDNN_LIBNAME "cudnn")
+endif()
+
+set(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY} CACHE PATH "Path to the cudnn library file (e.g., libcudnn.so)")
+if (CUDNN_LIBRARY MATCHES ".*cudnn_static.a" AND NOT CUDNN_STATIC)
+  message(WARNING "CUDNN_LIBRARY points to a static library (${CUDNN_LIBRARY}) but CUDNN_STATIC is OFF.")
+endif()
+
+find_library(CUDNN_LIBRARY_PATH ${CUDNN_LIBNAME}
+  PATHS ${CUDNN_LIBRARY}
+  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
+
+find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_LIBRARY_PATH CUDNN_INCLUDE_PATH)
+
+if(CUDNN_FOUND)
+  # Get cuDNN version
+  if(EXISTS ${CUDNN_INCLUDE_PATH}/cudnn_version.h)
+    file(READ ${CUDNN_INCLUDE_PATH}/cudnn_version.h CUDNN_HEADER_CONTENTS)
+  else()
+    file(READ ${CUDNN_INCLUDE_PATH}/cudnn.h CUDNN_HEADER_CONTENTS)
+  endif()
+  string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
+               CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
+               CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}")
+  string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)"
+               CUDNN_VERSION_MINOR "${CUDNN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1"
+               CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}")
+  string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)"
+               CUDNN_VERSION_PATCH "${CUDNN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1"
+               CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}")
+  # Assemble cuDNN version
+  if(NOT CUDNN_VERSION_MAJOR)
+    set(CUDNN_VERSION "?")
+  else()
+    set(CUDNN_VERSION
+        "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}")
+  endif()
+endif()
+
+mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..5517e8f0624b1e5538b761e1f4891227007d0045
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake
@@ -0,0 +1,40 @@
+# Distributed under the OSI-approved BSD 3-Clause License.  See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+# Present in upstream, but not supported on versions of cmake we need to support
+# include_guard(GLOBAL)
+
+# Initializes `<_PREFIX>_` variables from the corresponding
+# `<_PREFIX>__INIT`, for the configurations currently used.
+function(cmake_initialize_per_config_variable _PREFIX _DOCSTRING)
+  string(STRIP "${${_PREFIX}_INIT}" _INIT)
+  set("${_PREFIX}" "${_INIT}"
+    CACHE STRING "${_DOCSTRING} during all build types.")
+  mark_as_advanced("${_PREFIX}")
+
+  if (NOT CMAKE_NOT_USING_CONFIG_FLAGS)
+    set(_CONFIGS Debug Release MinSizeRel RelWithDebInfo)
+
+    get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
+    if (_GENERATOR_IS_MULTI_CONFIG)
+      list(APPEND _CONFIGS ${CMAKE_CONFIGURATION_TYPES})
+    else()
+      if (NOT CMAKE_NO_BUILD_TYPE)
+        set(CMAKE_BUILD_TYPE "${CMAKE_BUILD_TYPE_INIT}" CACHE STRING
+          "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel ...")
+      endif()
+      list(APPEND _CONFIGS ${CMAKE_BUILD_TYPE})
+    endif()
+
+    list(REMOVE_DUPLICATES _CONFIGS)
+    foreach(_BUILD_TYPE IN LISTS _CONFIGS)
+      if (NOT "${_BUILD_TYPE}" STREQUAL "")
+        string(TOUPPER "${_BUILD_TYPE}" _BUILD_TYPE)
+        string(STRIP "${${_PREFIX}_${_BUILD_TYPE}_INIT}" _INIT)
+        set("${_PREFIX}_${_BUILD_TYPE}" "${_INIT}"
+          CACHE STRING "${_DOCSTRING} during ${_BUILD_TYPE} builds.")
+        mark_as_advanced("${_PREFIX}_${_BUILD_TYPE}")
+      endif()
+    endforeach()
+  endif()
+endfunction()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..411a246656b3bdaba6abc238fd35caf959c9cca0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake
@@ -0,0 +1,1981 @@
+#.rst:
+# FindCUDA
+# --------
+#
+# .. note::
+#
+#   The FindCUDA module has been superseded by first-class support
+#   for the CUDA language in CMake.  It is no longer necessary to
+#   use this module or call ``find_package(CUDA)``.  This module
+#   now exists only for compatibility with projects that have not
+#   been ported.
+#
+#   Instead, list ``CUDA`` among the languages named in the top-level
+#   call to the :command:`project` command, or call the
+#   :command:`enable_language` command with ``CUDA``.
+#   Then one can add CUDA (``.cu``) sources to programs directly
+#   in calls to :command:`add_library` and :command:`add_executable`.
+#
+# Tools for building CUDA C files: libraries and build dependencies.
+#
+# This script locates the NVIDIA CUDA C tools.  It should work on Linux,
+# Windows, and macOS and should be reasonably up to date with CUDA C
+# releases.
+#
+# This script makes use of the standard :command:`find_package` arguments of
+# ````, ``REQUIRED`` and ``QUIET``.  ``CUDA_FOUND`` will report if an
+# acceptable version of CUDA was found.
+#
+# The script will prompt the user to specify ``CUDA_TOOLKIT_ROOT_DIR`` if
+# the prefix cannot be determined by the location of nvcc in the system
+# path and ``REQUIRED`` is specified to :command:`find_package`.  To use
+# a different installed version of the toolkit set the environment variable
+# ``CUDA_BIN_PATH`` before running cmake (e.g.
+# ``CUDA_BIN_PATH=/usr/local/cuda1.0`` instead of the default
+# ``/usr/local/cuda``) or set ``CUDA_TOOLKIT_ROOT_DIR`` after configuring.  If
+# you change the value of ``CUDA_TOOLKIT_ROOT_DIR``, various components that
+# depend on the path will be relocated.
+#
+# It might be necessary to set ``CUDA_TOOLKIT_ROOT_DIR`` manually on certain
+# platforms, or to use a CUDA runtime not installed in the default
+# location.  In newer versions of the toolkit the CUDA library is
+# included with the graphics driver -- be sure that the driver version
+# matches what is needed by the CUDA runtime version.
+#
+# The following variables affect the behavior of the macros in the
+# script (in alphebetical order).  Note that any of these flags can be
+# changed multiple times in the same directory before calling
+# ``CUDA_ADD_EXECUTABLE``, ``CUDA_ADD_LIBRARY``, ``CUDA_COMPILE``,
+# ``CUDA_COMPILE_PTX``, ``CUDA_COMPILE_FATBIN``, ``CUDA_COMPILE_CUBIN``
+# or ``CUDA_WRAP_SRCS``::
+#
+#   CUDA_64_BIT_DEVICE_CODE (Default matches host bit size)
+#   -- Set to ON to compile for 64 bit device code, OFF for 32 bit device code.
+#      Note that making this different from the host code when generating object
+#      or C files from CUDA code just won't work, because size_t gets defined by
+#      nvcc in the generated source.  If you compile to PTX and then load the
+#      file yourself, you can mix bit sizes between device and host.
+#
+#   CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE (Default ON)
+#   -- Set to ON if you want the custom build rule to be attached to the source
+#      file in Visual Studio.  Turn OFF if you add the same cuda file to multiple
+#      targets.
+#
+#      This allows the user to build the target from the CUDA file; however, bad
+#      things can happen if the CUDA source file is added to multiple targets.
+#      When performing parallel builds it is possible for the custom build
+#      command to be run more than once and in parallel causing cryptic build
+#      errors.  VS runs the rules for every source file in the target, and a
+#      source can have only one rule no matter how many projects it is added to.
+#      When the rule is run from multiple targets race conditions can occur on
+#      the generated file.  Eventually everything will get built, but if the user
+#      is unaware of this behavior, there may be confusion.  It would be nice if
+#      this script could detect the reuse of source files across multiple targets
+#      and turn the option off for the user, but no good solution could be found.
+#
+#   CUDA_BUILD_CUBIN (Default OFF)
+#   -- Set to ON to enable and extra compilation pass with the -cubin option in
+#      Device mode. The output is parsed and register, shared memory usage is
+#      printed during build.
+#
+#   CUDA_BUILD_EMULATION (Default OFF for device mode)
+#   -- Set to ON for Emulation mode. -D_DEVICEEMU is defined for CUDA C files
+#      when CUDA_BUILD_EMULATION is TRUE.
+#
+#   CUDA_LINK_LIBRARIES_KEYWORD (Default "")
+#    -- The  keyword to use for internal
+#       target_link_libraries calls. The default is to use no keyword which
+#       uses the old "plain" form of target_link_libraries. Note that is matters
+#       because whatever is used inside the FindCUDA module must also be used
+#       outside - the two forms of target_link_libraries cannot be mixed.
+#
+#   CUDA_GENERATED_OUTPUT_DIR (Default CMAKE_CURRENT_BINARY_DIR)
+#   -- Set to the path you wish to have the generated files placed.  If it is
+#      blank output files will be placed in CMAKE_CURRENT_BINARY_DIR.
+#      Intermediate files will always be placed in
+#      CMAKE_CURRENT_BINARY_DIR/CMakeFiles.
+#
+#   CUDA_HOST_COMPILATION_CPP (Default ON)
+#   -- Set to OFF for C compilation of host code.
+#
+#   CUDA_HOST_COMPILER (Default CMAKE_C_COMPILER)
+#   -- Set the host compiler to be used by nvcc.  Ignored if -ccbin or
+#      --compiler-bindir is already present in the CUDA_NVCC_FLAGS or
+#      CUDA_NVCC_FLAGS_ variables.  For Visual Studio targets,
+#      the host compiler is constructed with one or more visual studio macros
+#      such as $(VCInstallDir), that expands out to the path when
+#      the command is run from within VS.
+#      If the CUDAHOSTCXX environment variable is set it will
+#      be used as the default.
+#
+#   CUDA_NVCC_FLAGS
+#   CUDA_NVCC_FLAGS_
+#   -- Additional NVCC command line arguments.  NOTE: multiple arguments must be
+#      semi-colon delimited (e.g. --compiler-options;-Wall)
+#
+#   CUDA_PROPAGATE_HOST_FLAGS (Default ON)
+#   -- Set to ON to propagate CMAKE_{C,CXX}_FLAGS and their configuration
+#      dependent counterparts (e.g. CMAKE_C_FLAGS_DEBUG) automatically to the
+#      host compiler through nvcc's -Xcompiler flag.  This helps make the
+#      generated host code match the rest of the system better.  Sometimes
+#      certain flags give nvcc problems, and this will help you turn the flag
+#      propagation off.  This does not affect the flags supplied directly to nvcc
+#      via CUDA_NVCC_FLAGS or through the OPTION flags specified through
+#      CUDA_ADD_LIBRARY, CUDA_ADD_EXECUTABLE, or CUDA_WRAP_SRCS.  Flags used for
+#      shared library compilation are not affected by this flag.
+#
+#   CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST (Default "")
+#   -- A list containing the host flags that should not be propagated when
+#      CUDA_PROPAGATE_HOST_FLAGS is ON.
+#
+#   CUDA_SEPARABLE_COMPILATION (Default OFF)
+#   -- If set this will enable separable compilation for all CUDA runtime object
+#      files.  If used outside of CUDA_ADD_EXECUTABLE and CUDA_ADD_LIBRARY
+#      (e.g. calling CUDA_WRAP_SRCS directly),
+#      CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME and
+#      CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS should be called.
+#
+#   CUDA_SOURCE_PROPERTY_FORMAT
+#   -- If this source file property is set, it can override the format specified
+#      to CUDA_WRAP_SRCS (OBJ, PTX, CUBIN, or FATBIN).  If an input source file
+#      is not a .cu file, setting this file will cause it to be treated as a .cu
+#      file. See documentation for set_source_files_properties on how to set
+#      this property.
+#
+#   CUDA_USE_STATIC_CUDA_RUNTIME (Default ON)
+#   -- When enabled the static version of the CUDA runtime library will be used
+#      in CUDA_LIBRARIES.  If the version of CUDA configured doesn't support
+#      this option, then it will be silently disabled.
+#
+#   CUDA_VERBOSE_BUILD (Default OFF)
+#   -- Set to ON to see all the commands used when building the CUDA file.  When
+#      using a Makefile generator the value defaults to VERBOSE (run make
+#      VERBOSE=1 to see output), although setting CUDA_VERBOSE_BUILD to ON will
+#      always print the output.
+#
+# The script creates the following macros (in alphebetical order)::
+#
+#   CUDA_ADD_CUFFT_TO_TARGET( cuda_target )
+#   -- Adds the cufft library to the target (can be any target).  Handles whether
+#      you are in emulation mode or not.
+#
+#   CUDA_ADD_CUBLAS_TO_TARGET( cuda_target )
+#   -- Adds the cublas library to the target (can be any target).  Handles
+#      whether you are in emulation mode or not.
+#
+#   CUDA_ADD_EXECUTABLE( cuda_target file0 file1 ...
+#                        [WIN32] [MACOSX_BUNDLE] [EXCLUDE_FROM_ALL] [OPTIONS ...] )
+#   -- Creates an executable "cuda_target" which is made up of the files
+#      specified.  All of the non CUDA C files are compiled using the standard
+#      build rules specified by CMAKE and the cuda files are compiled to object
+#      files using nvcc and the host compiler.  In addition CUDA_INCLUDE_DIRS is
+#      added automatically to include_directories().  Some standard CMake target
+#      calls can be used on the target after calling this macro
+#      (e.g. set_target_properties and target_link_libraries), but setting
+#      properties that adjust compilation flags will not affect code compiled by
+#      nvcc.  Such flags should be modified before calling CUDA_ADD_EXECUTABLE,
+#      CUDA_ADD_LIBRARY or CUDA_WRAP_SRCS.
+#
+#   CUDA_ADD_LIBRARY( cuda_target file0 file1 ...
+#                     [STATIC | SHARED | MODULE] [EXCLUDE_FROM_ALL] [OPTIONS ...] )
+#   -- Same as CUDA_ADD_EXECUTABLE except that a library is created.
+#
+#   CUDA_BUILD_CLEAN_TARGET()
+#   -- Creates a convenience target that deletes all the dependency files
+#      generated.  You should make clean after running this target to ensure the
+#      dependency files get regenerated.
+#
+#   CUDA_COMPILE( generated_files file0 file1 ... [STATIC | SHARED | MODULE]
+#                 [OPTIONS ...] )
+#   -- Returns a list of generated files from the input source files to be used
+#      with ADD_LIBRARY or ADD_EXECUTABLE.
+#
+#   CUDA_COMPILE_PTX( generated_files file0 file1 ... [OPTIONS ...] )
+#   -- Returns a list of PTX files generated from the input source files.
+#
+#   CUDA_COMPILE_FATBIN( generated_files file0 file1 ... [OPTIONS ...] )
+#   -- Returns a list of FATBIN files generated from the input source files.
+#
+#   CUDA_COMPILE_CUBIN( generated_files file0 file1 ... [OPTIONS ...] )
+#   -- Returns a list of CUBIN files generated from the input source files.
+#
+#   CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME( output_file_var
+#                                                        cuda_target
+#                                                        object_files )
+#   -- Compute the name of the intermediate link file used for separable
+#      compilation.  This file name is typically passed into
+#      CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS.  output_file_var is produced
+#      based on cuda_target the list of objects files that need separable
+#      compilation as specified by object_files.  If the object_files list is
+#      empty, then output_file_var will be empty.  This function is called
+#      automatically for CUDA_ADD_LIBRARY and CUDA_ADD_EXECUTABLE.  Note that
+#      this is a function and not a macro.
+#
+#   CUDA_INCLUDE_DIRECTORIES( path0 path1 ... )
+#   -- Sets the directories that should be passed to nvcc
+#      (e.g. nvcc -Ipath0 -Ipath1 ... ). These paths usually contain other .cu
+#      files.
+#
+#
+#   CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS( output_file_var cuda_target
+#                                            nvcc_flags object_files)
+#   -- Generates the link object required by separable compilation from the given
+#      object files.  This is called automatically for CUDA_ADD_EXECUTABLE and
+#      CUDA_ADD_LIBRARY, but can be called manually when using CUDA_WRAP_SRCS
+#      directly.  When called from CUDA_ADD_LIBRARY or CUDA_ADD_EXECUTABLE the
+#      nvcc_flags passed in are the same as the flags passed in via the OPTIONS
+#      argument.  The only nvcc flag added automatically is the bitness flag as
+#      specified by CUDA_64_BIT_DEVICE_CODE.  Note that this is a function
+#      instead of a macro.
+#
+#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
+#   -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
+#      target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
+#       - "Auto" detects local machine GPU compute arch at runtime.
+#       - "Common" and "All" cover common and entire subsets of architectures
+#      ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
+#      NAME: Kepler Maxwell Kepler+Tesla Maxwell+Tegra Pascal Volta Turing
+#      NUM: Any number. Only those pairs are currently accepted by NVCC though:
+#            3.5 3.7 5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5
+#      Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
+#      Additionally, sets ${out_variable}_readable to the resulting numeric list
+#      Example:
+#       CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
+#        LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
+#
+#      More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
+#      Note that this is a function instead of a macro.
+#
+#   CUDA_WRAP_SRCS ( cuda_target format generated_files file0 file1 ...
+#                    [STATIC | SHARED | MODULE] [OPTIONS ...] )
+#   -- This is where all the magic happens.  CUDA_ADD_EXECUTABLE,
+#      CUDA_ADD_LIBRARY, CUDA_COMPILE, and CUDA_COMPILE_PTX all call this
+#      function under the hood.
+#
+#      Given the list of files (file0 file1 ... fileN) this macro generates
+#      custom commands that generate either PTX or linkable objects (use "PTX" or
+#      "OBJ" for the format argument to switch).  Files that don't end with .cu
+#      or have the HEADER_FILE_ONLY property are ignored.
+#
+#      The arguments passed in after OPTIONS are extra command line options to
+#      give to nvcc.  You can also specify per configuration options by
+#      specifying the name of the configuration followed by the options.  General
+#      options must precede configuration specific options.  Not all
+#      configurations need to be specified, only the ones provided will be used.
+#
+#         OPTIONS -DFLAG=2 "-DFLAG_OTHER=space in flag"
+#         DEBUG -g
+#         RELEASE --use_fast_math
+#         RELWITHDEBINFO --use_fast_math;-g
+#         MINSIZEREL --use_fast_math
+#
+#      For certain configurations (namely VS generating object files with
+#      CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE set to ON), no generated file will
+#      be produced for the given cuda file.  This is because when you add the
+#      cuda file to Visual Studio it knows that this file produces an object file
+#      and will link in the resulting object file automatically.
+#
+#      This script will also generate a separate cmake script that is used at
+#      build time to invoke nvcc.  This is for several reasons.
+#
+#        1. nvcc can return negative numbers as return values which confuses
+#        Visual Studio into thinking that the command succeeded.  The script now
+#        checks the error codes and produces errors when there was a problem.
+#
+#        2. nvcc has been known to not delete incomplete results when it
+#        encounters problems.  This confuses build systems into thinking the
+#        target was generated when in fact an unusable file exists.  The script
+#        now deletes the output files if there was an error.
+#
+#        3. By putting all the options that affect the build into a file and then
+#        make the build rule dependent on the file, the output files will be
+#        regenerated when the options change.
+#
+#      This script also looks at optional arguments STATIC, SHARED, or MODULE to
+#      determine when to target the object compilation for a shared library.
+#      BUILD_SHARED_LIBS is ignored in CUDA_WRAP_SRCS, but it is respected in
+#      CUDA_ADD_LIBRARY.  On some systems special flags are added for building
+#      objects intended for shared libraries.  A preprocessor macro,
+#      _EXPORTS is defined when a shared library compilation is
+#      detected.
+#
+#      Flags passed into add_definitions with -D or /D are passed along to nvcc.
+#
+#
+#
+# The script defines the following variables::
+#
+#   CUDA_VERSION_MAJOR    -- The major version of cuda as reported by nvcc.
+#   CUDA_VERSION_MINOR    -- The minor version.
+#   CUDA_VERSION
+#   CUDA_VERSION_STRING   -- CUDA_VERSION_MAJOR.CUDA_VERSION_MINOR
+#   CUDA_HAS_FP16         -- Whether a short float (float16,fp16) is supported.
+#
+#   CUDA_TOOLKIT_ROOT_DIR -- Path to the CUDA Toolkit (defined if not set).
+#   CUDA_SDK_ROOT_DIR     -- Path to the CUDA SDK.  Use this to find files in the
+#                            SDK.  This script will not directly support finding
+#                            specific libraries or headers, as that isn't
+#                            supported by NVIDIA.  If you want to change
+#                            libraries when the path changes see the
+#                            FindCUDA.cmake script for an example of how to clear
+#                            these variables.  There are also examples of how to
+#                            use the CUDA_SDK_ROOT_DIR to locate headers or
+#                            libraries, if you so choose (at your own risk).
+#   CUDA_INCLUDE_DIRS     -- Include directory for cuda headers.  Added automatically
+#                            for CUDA_ADD_EXECUTABLE and CUDA_ADD_LIBRARY.
+#   CUDA_LIBRARIES        -- Cuda RT library.
+#   CUDA_CUFFT_LIBRARIES  -- Device or emulation library for the Cuda FFT
+#                            implementation (alternative to:
+#                            CUDA_ADD_CUFFT_TO_TARGET macro)
+#   CUDA_CUBLAS_LIBRARIES -- Device or emulation library for the Cuda BLAS
+#                            implementation (alternative to:
+#                            CUDA_ADD_CUBLAS_TO_TARGET macro).
+#   CUDA_cudart_static_LIBRARY -- Statically linkable cuda runtime library.
+#                                 Only available for CUDA version 5.5+
+#   CUDA_cudadevrt_LIBRARY -- Device runtime library.
+#                             Required for separable compilation.
+#   CUDA_cupti_LIBRARY    -- CUDA Profiling Tools Interface library.
+#                            Only available for CUDA version 4.0+.
+#   CUDA_curand_LIBRARY   -- CUDA Random Number Generation library.
+#                            Only available for CUDA version 3.2+.
+#   CUDA_cusolver_LIBRARY -- CUDA Direct Solver library.
+#                            Only available for CUDA version 7.0+.
+#   CUDA_cusparse_LIBRARY -- CUDA Sparse Matrix library.
+#                            Only available for CUDA version 3.2+.
+#   CUDA_npp_LIBRARY      -- NVIDIA Performance Primitives lib.
+#                            Only available for CUDA version 4.0+.
+#   CUDA_nppc_LIBRARY     -- NVIDIA Performance Primitives lib (core).
+#                            Only available for CUDA version 5.5+.
+#   CUDA_nppi_LIBRARY     -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 5.5 - 8.0.
+#   CUDA_nppial_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppicc_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppicom_LIBRARY  -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppidei_LIBRARY  -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppif_LIBRARY    -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppig_LIBRARY    -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppim_LIBRARY    -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppist_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppisu_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_nppitc_LIBRARY   -- NVIDIA Performance Primitives lib (image processing).
+#                            Only available for CUDA version 9.0.
+#   CUDA_npps_LIBRARY     -- NVIDIA Performance Primitives lib (signal processing).
+#                            Only available for CUDA version 5.5+.
+#   CUDA_nvcuvenc_LIBRARY -- CUDA Video Encoder library.
+#                            Only available for CUDA version 3.2+.
+#                            Windows only.
+#   CUDA_nvcuvid_LIBRARY  -- CUDA Video Decoder library.
+#                            Only available for CUDA version 3.2+.
+#                            Windows only.
+#
+
+#   James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#   Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#   Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#   Copyright (c) 2007-2009
+#   Scientific Computing and Imaging Institute, University of Utah
+#
+#   This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#   for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+###############################################################################
+
+# FindCUDA.cmake
+
+include(FindPackageHandleStandardArgs)
+# This macro helps us find the location of helper files we will need the full path to
+macro(CUDA_FIND_HELPER_FILE _name _extension)
+  set(_full_name "${_name}.${_extension}")
+  # CMAKE_CURRENT_LIST_FILE contains the full path to the file currently being
+  # processed.  Using this variable, we can pull out the current path, and
+  # provide a way to get access to the other files we need local to here.
+  get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+  set(CUDA_${_name} "${CMAKE_CURRENT_LIST_DIR}/FindCUDA/${_full_name}")
+  if(NOT EXISTS "${CUDA_${_name}}")
+    set(error_message "${_full_name} not found in ${CMAKE_CURRENT_LIST_DIR}/FindCUDA")
+    if(CUDA_FIND_REQUIRED)
+      message(FATAL_ERROR "${error_message}")
+    else()
+      if(NOT CUDA_FIND_QUIETLY)
+        message(STATUS "${error_message}")
+      endif()
+    endif()
+  endif()
+  # Set this variable as internal, so the user isn't bugged with it.
+  set(CUDA_${_name} ${CUDA_${_name}} CACHE INTERNAL "Location of ${_full_name}" FORCE)
+endmacro()
+
+#####################################################################
+## CUDA_INCLUDE_NVCC_DEPENDENCIES
+##
+
+# So we want to try and include the dependency file if it exists.  If
+# it doesn't exist then we need to create an empty one, so we can
+# include it.
+
+# If it does exist, then we need to check to see if all the files it
+# depends on exist.  If they don't then we should clear the dependency
+# file and regenerate it later.  This covers the case where a header
+# file has disappeared or moved.
+
+macro(CUDA_INCLUDE_NVCC_DEPENDENCIES dependency_file)
+  set(CUDA_NVCC_DEPEND)
+  set(CUDA_NVCC_DEPEND_REGENERATE FALSE)
+
+
+  # Include the dependency file.  Create it first if it doesn't exist .  The
+  # INCLUDE puts a dependency that will force CMake to rerun and bring in the
+  # new info when it changes.  DO NOT REMOVE THIS (as I did and spent a few
+  # hours figuring out why it didn't work.
+  if(NOT EXISTS ${dependency_file})
+    file(WRITE ${dependency_file} "#FindCUDA.cmake generated file.  Do not edit.\n")
+  endif()
+  # Always include this file to force CMake to run again next
+  # invocation and rebuild the dependencies.
+  #message("including dependency_file = ${dependency_file}")
+  include(${dependency_file})
+
+  # Now we need to verify the existence of all the included files
+  # here.  If they aren't there we need to just blank this variable and
+  # make the file regenerate again.
+#   if(DEFINED CUDA_NVCC_DEPEND)
+#     message("CUDA_NVCC_DEPEND set")
+#   else()
+#     message("CUDA_NVCC_DEPEND NOT set")
+#   endif()
+  if(CUDA_NVCC_DEPEND)
+    #message("CUDA_NVCC_DEPEND found")
+    foreach(f ${CUDA_NVCC_DEPEND})
+      # message("searching for ${f}")
+      if(NOT EXISTS ${f})
+        #message("file ${f} not found")
+        set(CUDA_NVCC_DEPEND_REGENERATE TRUE)
+      endif()
+    endforeach()
+  else()
+    #message("CUDA_NVCC_DEPEND false")
+    # No dependencies, so regenerate the file.
+    set(CUDA_NVCC_DEPEND_REGENERATE TRUE)
+  endif()
+
+  #message("CUDA_NVCC_DEPEND_REGENERATE = ${CUDA_NVCC_DEPEND_REGENERATE}")
+  # No incoming dependencies, so we need to generate them.  Make the
+  # output depend on the dependency file itself, which should cause the
+  # rule to re-run.
+  if(CUDA_NVCC_DEPEND_REGENERATE)
+    set(CUDA_NVCC_DEPEND ${dependency_file})
+    #message("Generating an empty dependency_file: ${dependency_file}")
+    file(WRITE ${dependency_file} "#FindCUDA.cmake generated file.  Do not edit.\n")
+  endif()
+
+endmacro()
+
+###############################################################################
+###############################################################################
+# Setup variables' defaults
+###############################################################################
+###############################################################################
+
+# Allow the user to specify if the device code is supposed to be 32 or 64 bit.
+if(CMAKE_SIZEOF_VOID_P EQUAL 8)
+  set(CUDA_64_BIT_DEVICE_CODE_DEFAULT ON)
+else()
+  set(CUDA_64_BIT_DEVICE_CODE_DEFAULT OFF)
+endif()
+option(CUDA_64_BIT_DEVICE_CODE "Compile device code in 64 bit mode" ${CUDA_64_BIT_DEVICE_CODE_DEFAULT})
+
+# Attach the build rule to the source file in VS.  This option
+option(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE "Attach the build rule to the CUDA source file.  Enable only when the CUDA source file is added to at most one target." ON)
+
+# Prints out extra information about the cuda file during compilation
+option(CUDA_BUILD_CUBIN "Generate and parse .cubin files in Device mode." OFF)
+
+# Set whether we are using emulation or device mode.
+option(CUDA_BUILD_EMULATION "Build in Emulation mode" OFF)
+
+# Where to put the generated output.
+set(CUDA_GENERATED_OUTPUT_DIR "" CACHE PATH "Directory to put all the output files.  If blank it will default to the CMAKE_CURRENT_BINARY_DIR")
+
+# Parse HOST_COMPILATION mode.
+option(CUDA_HOST_COMPILATION_CPP "Generated file extension" ON)
+
+# Extra user settable flags
+cmake_initialize_per_config_variable(CUDA_NVCC_FLAGS "Semi-colon delimit multiple arguments.")
+
+if(DEFINED ENV{CUDAHOSTCXX})
+  set(CUDA_HOST_COMPILER "$ENV{CUDAHOSTCXX}" CACHE FILEPATH "Host side compiler used by NVCC")
+elseif(CMAKE_GENERATOR MATCHES "Visual Studio")
+  set(_CUDA_MSVC_HOST_COMPILER "$(VCInstallDir)Tools/MSVC/$(VCToolsVersion)/bin/Host$(Platform)/$(PlatformTarget)")
+  if(MSVC_VERSION LESS 1910)
+   set(_CUDA_MSVC_HOST_COMPILER "$(VCInstallDir)bin")
+  endif()
+
+  set(CUDA_HOST_COMPILER "${_CUDA_MSVC_HOST_COMPILER}" CACHE FILEPATH "Host side compiler used by NVCC")
+
+else()
+  if(APPLE
+      AND "${CMAKE_C_COMPILER_ID}" MATCHES "Clang"
+      AND "${CMAKE_C_COMPILER}" MATCHES "/cc$")
+    # Using cc which is symlink to clang may let NVCC think it is GCC and issue
+    # unhandled -dumpspecs option to clang. Also in case neither
+    # CMAKE_C_COMPILER is defined (project does not use C language) nor
+    # CUDA_HOST_COMPILER is specified manually we should skip -ccbin and let
+    # nvcc use its own default C compiler.
+    # Only care about this on APPLE with clang to avoid
+    # following symlinks to things like ccache
+    if(DEFINED CMAKE_C_COMPILER AND NOT DEFINED CUDA_HOST_COMPILER)
+      get_filename_component(c_compiler_realpath "${CMAKE_C_COMPILER}" REALPATH)
+      # if the real path does not end up being clang then
+      # go back to using CMAKE_C_COMPILER
+      if(NOT "${c_compiler_realpath}" MATCHES "/clang$")
+        set(c_compiler_realpath "${CMAKE_C_COMPILER}")
+      endif()
+    else()
+      set(c_compiler_realpath "")
+    endif()
+    set(CUDA_HOST_COMPILER "${c_compiler_realpath}" CACHE FILEPATH "Host side compiler used by NVCC")
+  elseif(MSVC AND "${CMAKE_C_COMPILER}" MATCHES "clcache|sccache")
+    # NVCC does not think it will work if it is passed clcache.exe or sccache.exe
+    # as the host compiler, which means that builds with CC=cl.exe won't work.
+    # Best to just feed it whatever the actual cl.exe is as the host compiler.
+    set(CUDA_HOST_COMPILER "cl.exe" CACHE FILEPATH "Host side compiler used by NVCC")
+  else()
+    set(CUDA_HOST_COMPILER "${CMAKE_C_COMPILER}"
+      CACHE FILEPATH "Host side compiler used by NVCC")
+  endif()
+endif()
+
+# Propagate the host flags to the host compiler via -Xcompiler
+option(CUDA_PROPAGATE_HOST_FLAGS "Propagate C/CXX_FLAGS and friends to the host compiler via -Xcompile" ON)
+
+# Blacklisted flags to prevent propagation
+set(CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST  "" CACHE STRING "Blacklisted flags to prevent propagation")
+
+# Enable CUDA_SEPARABLE_COMPILATION
+option(CUDA_SEPARABLE_COMPILATION "Compile CUDA objects with separable compilation enabled.  Requires CUDA 5.0+" OFF)
+
+# Specifies whether the commands used when compiling the .cu file will be printed out.
+option(CUDA_VERBOSE_BUILD "Print out the commands run while compiling the CUDA source file.  With the Makefile generator this defaults to VERBOSE variable specified on the command line, but can be forced on with this option." OFF)
+
+mark_as_advanced(
+  CUDA_64_BIT_DEVICE_CODE
+  CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE
+  CUDA_GENERATED_OUTPUT_DIR
+  CUDA_HOST_COMPILATION_CPP
+  CUDA_NVCC_FLAGS
+  CUDA_PROPAGATE_HOST_FLAGS
+  CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST
+  CUDA_BUILD_CUBIN
+  CUDA_BUILD_EMULATION
+  CUDA_VERBOSE_BUILD
+  CUDA_SEPARABLE_COMPILATION
+  )
+
+# Single config generators like Makefiles or Ninja don't usually have
+# CMAKE_CONFIGURATION_TYPES defined (but note that it can be defined if set by
+# projects or developers). Even CMAKE_BUILD_TYPE might not be defined for
+# single config generators (and should not be defined for multi-config
+# generators). To ensure we get a complete superset of all possible
+# configurations, we combine CMAKE_CONFIGURATION_TYPES, CMAKE_BUILD_TYPE and
+# all of the standard configurations, then weed out duplicates with
+# list(REMOVE_DUPLICATES). Looping over the unique set then ensures we have
+# each configuration-specific set of nvcc flags defined and marked as advanced.
+set(CUDA_configuration_types ${CMAKE_CONFIGURATION_TYPES} ${CMAKE_BUILD_TYPE} Debug MinSizeRel Release RelWithDebInfo)
+list(REMOVE_DUPLICATES CUDA_configuration_types)
+
+###############################################################################
+###############################################################################
+# Locate CUDA, Set Build Type, etc.
+###############################################################################
+###############################################################################
+
+macro(cuda_unset_include_and_libraries)
+  unset(CUDA_TOOLKIT_INCLUDE CACHE)
+  unset(CUDA_CUDART_LIBRARY CACHE)
+  unset(CUDA_CUDA_LIBRARY CACHE)
+  # Make sure you run this before you unset CUDA_VERSION.
+  unset(CUDA_cudart_static_LIBRARY CACHE)
+  unset(CUDA_cudadevrt_LIBRARY CACHE)
+  unset(CUDA_cublas_LIBRARY CACHE)
+  unset(CUDA_cublas_device_LIBRARY CACHE)
+  unset(CUDA_cublasemu_LIBRARY CACHE)
+  unset(CUDA_cublasLt_LIBRARY CACHE)
+  unset(CUDA_cufft_LIBRARY CACHE)
+  unset(CUDA_cufftemu_LIBRARY CACHE)
+  unset(CUDA_cupti_LIBRARY CACHE)
+  unset(CUDA_curand_LIBRARY CACHE)
+  unset(CUDA_cusolver_LIBRARY CACHE)
+  unset(CUDA_cusparse_LIBRARY CACHE)
+  unset(CUDA_npp_LIBRARY CACHE)
+  unset(CUDA_nppc_LIBRARY CACHE)
+  unset(CUDA_nppi_LIBRARY CACHE)
+  unset(CUDA_npps_LIBRARY CACHE)
+  unset(CUDA_nvcuvenc_LIBRARY CACHE)
+  unset(CUDA_nvcuvid_LIBRARY CACHE)
+  unset(CUDA_GPU_DETECT_OUTPUT CACHE)
+endmacro()
+
+# Check to see if the CUDA_TOOLKIT_ROOT_DIR and CUDA_SDK_ROOT_DIR have changed,
+# if they have then clear the cache variables, so that will be detected again.
+if(NOT "${CUDA_TOOLKIT_ROOT_DIR}" STREQUAL "${CUDA_TOOLKIT_ROOT_DIR_INTERNAL}")
+  unset(CUDA_TOOLKIT_TARGET_DIR CACHE)
+  unset(CUDA_NVCC_EXECUTABLE CACHE)
+  cuda_unset_include_and_libraries()
+  unset(CUDA_VERSION CACHE)
+endif()
+
+if(NOT "${CUDA_TOOLKIT_TARGET_DIR}" STREQUAL "${CUDA_TOOLKIT_TARGET_DIR_INTERNAL}")
+  cuda_unset_include_and_libraries()
+endif()
+
+#
+#  End of unset()
+#
+
+#
+#  Start looking for things
+#
+
+# Search for the cuda distribution.
+if(NOT CUDA_TOOLKIT_ROOT_DIR AND NOT CMAKE_CROSSCOMPILING)
+  # Search in the CUDA_BIN_PATH first.
+  find_program(CUDA_TOOLKIT_ROOT_DIR_NVCC
+    NAMES nvcc nvcc.exe
+    PATHS
+      ENV CUDA_TOOLKIT_ROOT
+      ENV CUDA_PATH
+      ENV CUDA_BIN_PATH
+    PATH_SUFFIXES bin bin64
+    DOC "Toolkit location."
+    NO_DEFAULT_PATH
+    )
+
+  # Now search default paths
+  find_program(CUDA_TOOLKIT_ROOT_DIR_NVCC
+    NAMES nvcc nvcc.exe
+    PATHS /opt/cuda/bin
+    PATH_SUFFIXES cuda/bin
+    DOC "Toolkit location."
+    )
+
+  if (CUDA_TOOLKIT_ROOT_DIR_NVCC)
+    get_filename_component(CUDA_TOOLKIT_ROOT_DIR_NVCC_PAR "${CUDA_TOOLKIT_ROOT_DIR_NVCC}" DIRECTORY)
+    get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDA_TOOLKIT_ROOT_DIR_NVCC_PAR}" DIRECTORY CACHE)
+    string(REGEX REPLACE "[/\\\\]?bin[64]*[/\\\\]?$" "" CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT_DIR})
+    # We need to force this back into the cache.
+    set(CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT_DIR} CACHE PATH "Toolkit location." FORCE)
+    set(CUDA_TOOLKIT_TARGET_DIR ${CUDA_TOOLKIT_ROOT_DIR})
+  endif()
+  unset(CUDA_TOOLKIT_ROOT_DIR_NVCC CACHE)
+
+  if (NOT EXISTS ${CUDA_TOOLKIT_ROOT_DIR})
+    if(CUDA_FIND_REQUIRED)
+      message(FATAL_ERROR "Specify CUDA_TOOLKIT_ROOT_DIR")
+    elseif(NOT CUDA_FIND_QUIETLY)
+      message("CUDA_TOOLKIT_ROOT_DIR not found or specified")
+    endif()
+  endif ()
+endif ()
+
+if(CMAKE_CROSSCOMPILING)
+  SET (CUDA_TOOLKIT_ROOT $ENV{CUDA_TOOLKIT_ROOT})
+  if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7-a")
+    # Support for NVPACK
+    set (CUDA_TOOLKIT_TARGET_NAMES "armv7-linux-androideabi")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "arm")
+    # Support for arm cross compilation
+    set(CUDA_TOOLKIT_TARGET_NAMES "armv7-linux-gnueabihf")
+  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+    # Support for aarch64 cross compilation
+    if (ANDROID_ARCH_NAME STREQUAL "arm64")
+      set(CUDA_TOOLKIT_TARGET_NAMES "aarch64-linux-androideabi")
+    else()
+      set(CUDA_TOOLKIT_TARGET_NAMES "aarch64-linux" "sbsa-linux")
+    endif (ANDROID_ARCH_NAME STREQUAL "arm64")
+  endif()
+
+  foreach(CUDA_TOOLKIT_TARGET_NAME IN LISTS CUDA_TOOLKIT_TARGET_NAMES)
+    if (EXISTS "${CUDA_TOOLKIT_ROOT}/targets/${CUDA_TOOLKIT_TARGET_NAME}")
+      set(CUDA_TOOLKIT_TARGET_DIR "${CUDA_TOOLKIT_ROOT}/targets/${CUDA_TOOLKIT_TARGET_NAME}" CACHE PATH "CUDA Toolkit target location.")
+      SET (CUDA_TOOLKIT_ROOT_DIR ${CUDA_TOOLKIT_ROOT} CACHE PATH "Toolkit location." FORCE)
+      mark_as_advanced(CUDA_TOOLKIT_TARGET_DIR)
+      break()
+    endif()
+  endforeach()
+
+  # add known CUDA targetr root path to the set of directories we search for programs, libraries and headers
+  set( CMAKE_FIND_ROOT_PATH "${CUDA_TOOLKIT_TARGET_DIR};${CMAKE_FIND_ROOT_PATH}")
+  macro( cuda_find_host_program )
+    if (COMMAND find_host_program)
+      find_host_program( ${ARGN} )
+    else()
+      find_program( ${ARGN} )
+    endif()
+  endmacro()
+else()
+  # for non-cross-compile, find_host_program == find_program and CUDA_TOOLKIT_TARGET_DIR == CUDA_TOOLKIT_ROOT_DIR
+  macro( cuda_find_host_program )
+    find_program( ${ARGN} )
+  endmacro()
+  SET (CUDA_TOOLKIT_TARGET_DIR ${CUDA_TOOLKIT_ROOT_DIR})
+endif()
+
+
+# CUDA_NVCC_EXECUTABLE
+if(DEFINED ENV{CUDA_NVCC_EXECUTABLE})
+  set(CUDA_NVCC_EXECUTABLE "$ENV{CUDA_NVCC_EXECUTABLE}" CACHE FILEPATH "The CUDA compiler")
+else()
+  cuda_find_host_program(CUDA_NVCC_EXECUTABLE
+    NAMES nvcc
+    PATHS "${CUDA_TOOLKIT_ROOT_DIR}"
+    ENV CUDA_PATH
+    ENV CUDA_BIN_PATH
+    PATH_SUFFIXES bin bin64
+    NO_DEFAULT_PATH
+    )
+  # Search default search paths, after we search our own set of paths.
+  cuda_find_host_program(CUDA_NVCC_EXECUTABLE nvcc)
+endif()
+
+if(CUDA_NVCC_EXECUTABLE AND NOT CUDA_VERSION)
+  # Compute the version.
+  execute_process(COMMAND ${CUDA_NVCC_EXECUTABLE} "--version"
+    OUTPUT_VARIABLE NVCC_OUT
+    RESULT_VARIABLE NVCC_RC)
+  if(NOT (${NVCC_RC} EQUAL 0))
+    message(WARNING "Failed to execute '${CUDA_NVCC_EXECUTABLE} --version'")
+    set(CUDA_FOUND FALSE)
+    return()
+  endif()
+  string(REGEX REPLACE ".*release ([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR ${NVCC_OUT})
+  string(REGEX REPLACE ".*release ([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR ${NVCC_OUT})
+  set(CUDA_VERSION "${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}" CACHE STRING "Version of CUDA as computed from nvcc.")
+  mark_as_advanced(CUDA_VERSION)
+else()
+  # Need to set these based off of the cached value
+  string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR "${CUDA_VERSION}")
+  string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR "${CUDA_VERSION}")
+endif()
+
+# Always set this convenience variable
+set(CUDA_VERSION_STRING "${CUDA_VERSION}")
+
+# CUDA_TOOLKIT_INCLUDE
+find_path(CUDA_TOOLKIT_INCLUDE
+  device_functions.h # Header included in toolkit
+  PATHS ${CUDA_TOOLKIT_TARGET_DIR}
+  ENV CUDA_PATH
+  ENV CUDA_INC_PATH
+  PATH_SUFFIXES include
+  NO_DEFAULT_PATH
+  )
+# Search default search paths, after we search our own set of paths.
+find_path(CUDA_TOOLKIT_INCLUDE device_functions.h)
+mark_as_advanced(CUDA_TOOLKIT_INCLUDE)
+
+set(CUDA_HAS_FP16 TRUE)
+
+# Set the user list of include dir to nothing to initialize it.
+set (CUDA_NVCC_INCLUDE_DIRS_USER "")
+set (CUDA_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
+
+macro(cuda_find_library_local_first_with_path_ext _var _names _doc _path_ext )
+  if(CMAKE_SIZEOF_VOID_P EQUAL 8)
+    # CUDA 3.2+ on Windows moved the library directories, so we need the new
+    # and old paths.
+    set(_cuda_64bit_lib_dir "${_path_ext}lib/x64" "${_path_ext}lib64" "${_path_ext}libx64" )
+  endif()
+  # CUDA 3.2+ on Windows moved the library directories, so we need to new
+  # (lib/Win32) and the old path (lib).
+  find_library(${_var}
+    NAMES ${_names}
+    PATHS "${CUDA_TOOLKIT_TARGET_DIR}"
+    ENV CUDA_PATH
+    ENV CUDA_LIB_PATH
+    PATH_SUFFIXES ${_cuda_64bit_lib_dir} "${_path_ext}lib/Win32" "${_path_ext}lib" "${_path_ext}libWin32"
+    DOC ${_doc}
+    NO_DEFAULT_PATH
+    )
+  if (NOT CMAKE_CROSSCOMPILING)
+    # Search default search paths, after we search our own set of paths.
+    find_library(${_var}
+      NAMES ${_names}
+      PATHS "/usr/lib/nvidia-current"
+      DOC ${_doc}
+      )
+  endif()
+endmacro()
+
+macro(cuda_find_library_local_first _var _names _doc)
+  cuda_find_library_local_first_with_path_ext( "${_var}" "${_names}" "${_doc}" "" )
+endmacro()
+
+macro(find_library_local_first _var _names _doc )
+  cuda_find_library_local_first( "${_var}" "${_names}" "${_doc}" "" )
+endmacro()
+
+
+# CUDA_LIBRARIES
+cuda_find_library_local_first(CUDA_CUDART_LIBRARY cudart "\"cudart\" library")
+
+cuda_find_library_local_first(CUDA_cudart_static_LIBRARY cudart_static "static CUDA runtime library")
+mark_as_advanced(CUDA_cudart_static_LIBRARY)
+
+
+if(CUDA_cudart_static_LIBRARY)
+  # If static cudart available, use it by default, but provide a user-visible option to disable it.
+  option(CUDA_USE_STATIC_CUDA_RUNTIME "Use the static version of the CUDA runtime library if available" ON)
+else()
+  # If not available, silently disable the option.
+  set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "")
+endif()
+
+if(CUDA_USE_STATIC_CUDA_RUNTIME)
+  set(CUDA_CUDART_LIBRARY_VAR CUDA_cudart_static_LIBRARY)
+else()
+  set(CUDA_CUDART_LIBRARY_VAR CUDA_CUDART_LIBRARY)
+endif()
+
+cuda_find_library_local_first(CUDA_cudadevrt_LIBRARY cudadevrt "\"cudadevrt\" library")
+mark_as_advanced(CUDA_cudadevrt_LIBRARY)
+
+if(CUDA_USE_STATIC_CUDA_RUNTIME)
+  if(UNIX)
+    # Check for the dependent libraries.  Here we look for pthreads.
+    if (DEFINED CMAKE_THREAD_PREFER_PTHREAD)
+      set(_cuda_cmake_thread_prefer_pthread ${CMAKE_THREAD_PREFER_PTHREAD})
+    endif()
+    set(CMAKE_THREAD_PREFER_PTHREAD 1)
+
+    # Many of the FindXYZ CMake comes with makes use of try_compile with int main(){return 0;}
+    # as the source file.  Unfortunately this causes a warning with -Wstrict-prototypes and
+    # -Werror causes the try_compile to fail.  We will just temporarily disable other flags
+    # when doing the find_package command here.
+    set(_cuda_cmake_c_flags ${CMAKE_C_FLAGS})
+    set(CMAKE_C_FLAGS "-fPIC")
+    find_package(Threads REQUIRED)
+    set(CMAKE_C_FLAGS ${_cuda_cmake_c_flags})
+
+    if (DEFINED _cuda_cmake_thread_prefer_pthread)
+      set(CMAKE_THREAD_PREFER_PTHREAD ${_cuda_cmake_thread_prefer_pthread})
+      unset(_cuda_cmake_thread_prefer_pthread)
+    else()
+      unset(CMAKE_THREAD_PREFER_PTHREAD)
+    endif()
+
+    if(NOT APPLE)
+      #On Linux, you must link against librt when using the static cuda runtime.
+      find_library(CUDA_rt_LIBRARY rt)
+      if (NOT CUDA_rt_LIBRARY)
+        message(WARNING "Expecting to find librt for libcudart_static, but didn't find it.")
+      endif()
+    endif()
+  endif()
+endif()
+
+cuda_find_library_local_first_with_path_ext(CUDA_cupti_LIBRARY cupti "\"cupti\" library" "extras/CUPTI/")
+mark_as_advanced(CUDA_cupti_LIBRARY)
+
+# Set the CUDA_LIBRARIES variable.  This is the set of stuff to link against if you are
+# using the CUDA runtime.  For the dynamic version of the runtime, most of the
+# dependencies are brought in, but for the static version there are additional libraries
+# and linker commands needed.
+# Initialize to empty
+set(CUDA_LIBRARIES)
+
+# If we are using emulation mode and we found the cudartemu library then use
+# that one instead of cudart.
+if(CUDA_BUILD_EMULATION AND CUDA_CUDARTEMU_LIBRARY)
+  list(APPEND CUDA_LIBRARIES ${CUDA_CUDARTEMU_LIBRARY})
+elseif(CUDA_USE_STATIC_CUDA_RUNTIME AND CUDA_cudart_static_LIBRARY)
+  list(APPEND CUDA_LIBRARIES ${CUDA_cudart_static_LIBRARY} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS})
+  if (CUDA_rt_LIBRARY)
+    list(APPEND CUDA_LIBRARIES ${CUDA_rt_LIBRARY})
+  endif()
+  if(APPLE)
+    # We need to add the default path to the driver (libcuda.dylib) as an rpath, so that
+    # the static cuda runtime can find it at runtime.
+    list(APPEND CUDA_LIBRARIES -Wl,-rpath,/usr/local/cuda/lib)
+  endif()
+else()
+  list(APPEND CUDA_LIBRARIES ${CUDA_CUDART_LIBRARY})
+endif()
+
+# 1.1 toolkit on linux doesn't appear to have a separate library on
+# some platforms.
+cuda_find_library_local_first(CUDA_CUDA_LIBRARY cuda "\"cuda\" library (older versions only).")
+
+mark_as_advanced(
+  CUDA_CUDA_LIBRARY
+  CUDA_CUDART_LIBRARY
+  )
+
+#######################
+# Look for some of the toolkit helper libraries
+macro(FIND_CUDA_HELPER_LIBS _name)
+  cuda_find_library_local_first(CUDA_${_name}_LIBRARY ${_name} "\"${_name}\" library")
+  mark_as_advanced(CUDA_${_name}_LIBRARY)
+endmacro()
+
+if(CUDA_BUILD_EMULATION)
+  message(FATAL_ERROR "CUDA_BUILD_EMULATION is not supported in version 3.1 and onwards.  You must disable it to proceed.  You have version ${CUDA_VERSION}.")
+endif()
+
+find_cuda_helper_libs(cufft)
+find_cuda_helper_libs(cublas)
+find_cuda_helper_libs(cublasLt)
+# cusparse showed up in version 3.2
+find_cuda_helper_libs(cusparse)
+find_cuda_helper_libs(curand)
+if (WIN32)
+  find_cuda_helper_libs(nvcuvenc)
+  find_cuda_helper_libs(nvcuvid)
+endif()
+
+# In CUDA 9.0 NPP was nppi was removed
+find_cuda_helper_libs(nppc)
+find_cuda_helper_libs(nppial)
+find_cuda_helper_libs(nppicc)
+find_cuda_helper_libs(nppicom)
+find_cuda_helper_libs(nppidei)
+find_cuda_helper_libs(nppif)
+find_cuda_helper_libs(nppig)
+find_cuda_helper_libs(nppim)
+find_cuda_helper_libs(nppist)
+find_cuda_helper_libs(nppisu)
+find_cuda_helper_libs(nppitc)
+find_cuda_helper_libs(npps)
+set(CUDA_npp_LIBRARY "${CUDA_nppc_LIBRARY};${CUDA_nppial_LIBRARY};${CUDA_nppicc_LIBRARY};${CUDA_nppicom_LIBRARY};${CUDA_nppidei_LIBRARY};${CUDA_nppif_LIBRARY};${CUDA_nppig_LIBRARY};${CUDA_nppim_LIBRARY};${CUDA_nppist_LIBRARY};${CUDA_nppisu_LIBRARY};${CUDA_nppitc_LIBRARY};${CUDA_npps_LIBRARY}")
+# cusolver showed up in version 7.0
+find_cuda_helper_libs(cusolver)
+
+if (CUDA_BUILD_EMULATION)
+  set(CUDA_CUFFT_LIBRARIES ${CUDA_cufftemu_LIBRARY})
+  set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublasemu_LIBRARY})
+else()
+  set(CUDA_CUFFT_LIBRARIES ${CUDA_cufft_LIBRARY})
+  set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
+endif()
+
+########################
+# Look for the SDK stuff.  As of CUDA 3.0 NVSDKCUDA_ROOT has been replaced with
+# NVSDKCOMPUTE_ROOT with the old CUDA C contents moved into the C subdirectory
+find_path(CUDA_SDK_ROOT_DIR common/inc/cutil.h
+ HINTS
+  "$ENV{NVSDKCOMPUTE_ROOT}/C"
+  ENV NVSDKCUDA_ROOT
+  "[HKEY_LOCAL_MACHINE\\SOFTWARE\\NVIDIA Corporation\\Installed Products\\NVIDIA SDK 10\\Compute;InstallDir]"
+ PATHS
+  "/Developer/GPU\ Computing/C"
+  )
+
+# Keep the CUDA_SDK_ROOT_DIR first in order to be able to override the
+# environment variables.
+set(CUDA_SDK_SEARCH_PATH
+  "${CUDA_SDK_ROOT_DIR}"
+  "${CUDA_TOOLKIT_ROOT_DIR}/local/NVSDK0.2"
+  "${CUDA_TOOLKIT_ROOT_DIR}/NVSDK0.2"
+  "${CUDA_TOOLKIT_ROOT_DIR}/NV_CUDA_SDK"
+  "$ENV{HOME}/NVIDIA_CUDA_SDK"
+  "$ENV{HOME}/NVIDIA_CUDA_SDK_MACOSX"
+  "/Developer/CUDA"
+  )
+
+# Example of how to find an include file from the CUDA_SDK_ROOT_DIR
+
+# find_path(CUDA_CUT_INCLUDE_DIR
+#   cutil.h
+#   PATHS ${CUDA_SDK_SEARCH_PATH}
+#   PATH_SUFFIXES "common/inc"
+#   DOC "Location of cutil.h"
+#   NO_DEFAULT_PATH
+#   )
+# # Now search system paths
+# find_path(CUDA_CUT_INCLUDE_DIR cutil.h DOC "Location of cutil.h")
+
+# mark_as_advanced(CUDA_CUT_INCLUDE_DIR)
+
+
+# Example of how to find a library in the CUDA_SDK_ROOT_DIR
+
+# # cutil library is called cutil64 for 64 bit builds on windows.  We don't want
+# # to get these confused, so we are setting the name based on the word size of
+# # the build.
+
+# if(CMAKE_SIZEOF_VOID_P EQUAL 8)
+#   set(cuda_cutil_name cutil64)
+# else()
+#   set(cuda_cutil_name cutil32)
+# endif()
+
+# find_library(CUDA_CUT_LIBRARY
+#   NAMES cutil ${cuda_cutil_name}
+#   PATHS ${CUDA_SDK_SEARCH_PATH}
+#   # The new version of the sdk shows up in common/lib, but the old one is in lib
+#   PATH_SUFFIXES "common/lib" "lib"
+#   DOC "Location of cutil library"
+#   NO_DEFAULT_PATH
+#   )
+# # Now search system paths
+# find_library(CUDA_CUT_LIBRARY NAMES cutil ${cuda_cutil_name} DOC "Location of cutil library")
+# mark_as_advanced(CUDA_CUT_LIBRARY)
+# set(CUDA_CUT_LIBRARIES ${CUDA_CUT_LIBRARY})
+
+
+
+#############################
+# Check for required components
+set(CUDA_FOUND TRUE)
+
+set(CUDA_TOOLKIT_ROOT_DIR_INTERNAL "${CUDA_TOOLKIT_ROOT_DIR}" CACHE INTERNAL
+  "This is the value of the last time CUDA_TOOLKIT_ROOT_DIR was set successfully." FORCE)
+set(CUDA_TOOLKIT_TARGET_DIR_INTERNAL "${CUDA_TOOLKIT_TARGET_DIR}" CACHE INTERNAL
+  "This is the value of the last time CUDA_TOOLKIT_TARGET_DIR was set successfully." FORCE)
+set(CUDA_SDK_ROOT_DIR_INTERNAL "${CUDA_SDK_ROOT_DIR}" CACHE INTERNAL
+  "This is the value of the last time CUDA_SDK_ROOT_DIR was set successfully." FORCE)
+
+find_package_handle_standard_args(CUDA
+  REQUIRED_VARS
+    CUDA_TOOLKIT_ROOT_DIR
+    CUDA_NVCC_EXECUTABLE
+    CUDA_INCLUDE_DIRS
+    ${CUDA_CUDART_LIBRARY_VAR}
+  VERSION_VAR
+    CUDA_VERSION
+  )
+
+
+
+###############################################################################
+###############################################################################
+# Macros
+###############################################################################
+###############################################################################
+
+###############################################################################
+# Add include directories to pass to the nvcc command.
+macro(CUDA_INCLUDE_DIRECTORIES)
+  foreach(dir ${ARGN})
+    list(APPEND CUDA_NVCC_INCLUDE_DIRS_USER ${dir})
+  endforeach()
+endmacro()
+
+
+##############################################################################
+cuda_find_helper_file(parse_cubin cmake)
+cuda_find_helper_file(make2cmake cmake)
+cuda_find_helper_file(run_nvcc cmake)
+include("${CMAKE_CURRENT_LIST_DIR}/FindCUDA/select_compute_arch.cmake")
+
+##############################################################################
+# Separate the OPTIONS out from the sources
+#
+macro(CUDA_GET_SOURCES_AND_OPTIONS _sources _cmake_options _options)
+  set( ${_sources} )
+  set( ${_cmake_options} )
+  set( ${_options} )
+  set( _found_options FALSE )
+  foreach(arg ${ARGN})
+    if("x${arg}" STREQUAL "xOPTIONS")
+      set( _found_options TRUE )
+    elseif(
+        "x${arg}" STREQUAL "xWIN32" OR
+        "x${arg}" STREQUAL "xMACOSX_BUNDLE" OR
+        "x${arg}" STREQUAL "xEXCLUDE_FROM_ALL" OR
+        "x${arg}" STREQUAL "xSTATIC" OR
+        "x${arg}" STREQUAL "xSHARED" OR
+        "x${arg}" STREQUAL "xMODULE"
+        )
+      list(APPEND ${_cmake_options} ${arg})
+    else()
+      if ( _found_options )
+        list(APPEND ${_options} ${arg})
+      else()
+        # Assume this is a file
+        list(APPEND ${_sources} ${arg})
+      endif()
+    endif()
+  endforeach()
+endmacro()
+
+##############################################################################
+# Parse the OPTIONS from ARGN and set the variables prefixed by _option_prefix
+#
+macro(CUDA_PARSE_NVCC_OPTIONS _option_prefix)
+  set( _found_config )
+  foreach(arg ${ARGN})
+    # Determine if we are dealing with a perconfiguration flag
+    foreach(config ${CUDA_configuration_types})
+      string(TOUPPER ${config} config_upper)
+      if (arg STREQUAL "${config_upper}")
+        set( _found_config _${arg})
+        # Set arg to nothing to keep it from being processed further
+        set( arg )
+      endif()
+    endforeach()
+
+    if ( arg )
+      list(APPEND ${_option_prefix}${_found_config} "${arg}")
+    endif()
+  endforeach()
+endmacro()
+
+##############################################################################
+# Helper to add the include directory for CUDA only once
+function(CUDA_ADD_CUDA_INCLUDE_ONCE)
+  get_directory_property(_include_directories INCLUDE_DIRECTORIES)
+  set(_add TRUE)
+  if(_include_directories)
+    foreach(dir ${_include_directories})
+      if("${dir}" STREQUAL "${CUDA_INCLUDE_DIRS}")
+        set(_add FALSE)
+      endif()
+    endforeach()
+  endif()
+  if(_add)
+    include_directories(${CUDA_INCLUDE_DIRS})
+  endif()
+endfunction()
+
+function(CUDA_BUILD_SHARED_LIBRARY shared_flag)
+  set(cmake_args ${ARGN})
+  # If SHARED, MODULE, or STATIC aren't already in the list of arguments, then
+  # add SHARED or STATIC based on the value of BUILD_SHARED_LIBS.
+  list(FIND cmake_args SHARED _cuda_found_SHARED)
+  list(FIND cmake_args MODULE _cuda_found_MODULE)
+  list(FIND cmake_args STATIC _cuda_found_STATIC)
+  if( _cuda_found_SHARED GREATER -1 OR
+      _cuda_found_MODULE GREATER -1 OR
+      _cuda_found_STATIC GREATER -1)
+    set(_cuda_build_shared_libs)
+  else()
+    if (BUILD_SHARED_LIBS)
+      set(_cuda_build_shared_libs SHARED)
+    else()
+      set(_cuda_build_shared_libs STATIC)
+    endif()
+  endif()
+  set(${shared_flag} ${_cuda_build_shared_libs} PARENT_SCOPE)
+endfunction()
+
+##############################################################################
+# Helper to avoid clashes of files with the same basename but different paths.
+# This doesn't attempt to do exactly what CMake internals do, which is to only
+# add this path when there is a conflict, since by the time a second collision
+# in names is detected it's already too late to fix the first one.  For
+# consistency sake the relative path will be added to all files.
+function(CUDA_COMPUTE_BUILD_PATH path build_path)
+  #message("CUDA_COMPUTE_BUILD_PATH([${path}] ${build_path})")
+  # Only deal with CMake style paths from here on out
+  file(TO_CMAKE_PATH "${path}" bpath)
+  if (IS_ABSOLUTE "${bpath}")
+    # Absolute paths are generally unnecessary, especially if something like
+    # file(GLOB_RECURSE) is used to pick up the files.
+
+    string(FIND "${bpath}" "${CMAKE_CURRENT_BINARY_DIR}" _binary_dir_pos)
+    if (_binary_dir_pos EQUAL 0)
+      file(RELATIVE_PATH bpath "${CMAKE_CURRENT_BINARY_DIR}" "${bpath}")
+    else()
+      file(RELATIVE_PATH bpath "${CMAKE_CURRENT_SOURCE_DIR}" "${bpath}")
+    endif()
+  endif()
+
+  # This recipe is from cmLocalGenerator::CreateSafeUniqueObjectFileName in the
+  # CMake source.
+
+  # Remove leading /
+  string(REGEX REPLACE "^[/]+" "" bpath "${bpath}")
+  # Avoid absolute paths by removing ':'
+  string(REPLACE ":" "_" bpath "${bpath}")
+  # Avoid relative paths that go up the tree
+  string(REPLACE "../" "__/" bpath "${bpath}")
+  # Avoid spaces
+  string(REPLACE " " "_" bpath "${bpath}")
+
+  # Strip off the filename.  I wait until here to do it, since removing the
+  # basename can make a path that looked like path/../basename turn into
+  # path/.. (notice the trailing slash).
+  get_filename_component(bpath "${bpath}" PATH)
+
+  set(${build_path} "${bpath}" PARENT_SCOPE)
+  #message("${build_path} = ${bpath}")
+endfunction()
+
+##############################################################################
+# This helper macro populates the following variables and setups up custom
+# commands and targets to invoke the nvcc compiler to generate C or PTX source
+# dependent upon the format parameter.  The compiler is invoked once with -M
+# to generate a dependency file and a second time with -cuda or -ptx to generate
+# a .cpp or .ptx file.
+# INPUT:
+#   cuda_target         - Target name
+#   format              - PTX, CUBIN, FATBIN or OBJ
+#   FILE1 .. FILEN      - The remaining arguments are the sources to be wrapped.
+#   OPTIONS             - Extra options to NVCC
+# OUTPUT:
+#   generated_files     - List of generated files
+##############################################################################
+##############################################################################
+
+macro(CUDA_WRAP_SRCS cuda_target format generated_files)
+
+  # Put optional arguments in list.
+  set(_argn_list "${ARGN}")
+  # If one of the given optional arguments is "PHONY", make a note of it, then
+  # remove it from the list.
+  list(FIND _argn_list "PHONY" _phony_idx)
+  if("${_phony_idx}" GREATER "-1")
+    set(_target_is_phony true)
+    list(REMOVE_AT _argn_list ${_phony_idx})
+  else()
+    set(_target_is_phony false)
+  endif()
+
+  # If CMake doesn't support separable compilation, complain
+  if(CUDA_SEPARABLE_COMPILATION AND CMAKE_VERSION VERSION_LESS "2.8.10.1")
+    message(SEND_ERROR "CUDA_SEPARABLE_COMPILATION isn't supported for CMake versions less than 2.8.10.1")
+  endif()
+
+  # Set up all the command line flags here, so that they can be overridden on a per target basis.
+
+  set(nvcc_flags "")
+
+  # Emulation if the card isn't present.
+  if (CUDA_BUILD_EMULATION)
+    # Emulation.
+    set(nvcc_flags ${nvcc_flags} --device-emulation -D_DEVICEEMU -g)
+  else()
+    # Device mode.  No flags necessary.
+  endif()
+
+  if(CUDA_HOST_COMPILATION_CPP)
+    set(CUDA_C_OR_CXX CXX)
+  else()
+    message(WARNING "--host-compilation flag is deprecated in CUDA version >= 3.0.  Removing --host-compilation C flag" )
+    set(CUDA_C_OR_CXX C)
+  endif()
+
+  set(generated_extension ${CMAKE_${CUDA_C_OR_CXX}_OUTPUT_EXTENSION})
+
+  if(CUDA_64_BIT_DEVICE_CODE)
+    set(nvcc_flags ${nvcc_flags} -m64)
+  else()
+    set(nvcc_flags ${nvcc_flags} -m32)
+  endif()
+
+  if(CUDA_TARGET_CPU_ARCH)
+    set(nvcc_flags ${nvcc_flags} "--target-cpu-architecture=${CUDA_TARGET_CPU_ARCH}")
+  endif()
+
+  # This needs to be passed in at this stage, because VS needs to fill out the
+  # various macros from within VS.  Note that CCBIN is only used if
+  # -ccbin or --compiler-bindir isn't used and CUDA_HOST_COMPILER matches
+  # _CUDA_MSVC_HOST_COMPILER
+  if(CMAKE_GENERATOR MATCHES "Visual Studio")
+    set(ccbin_flags -D "\"CCBIN:PATH=${_CUDA_MSVC_HOST_COMPILER}\"" )
+  else()
+    set(ccbin_flags)
+  endif()
+
+  # Figure out which configure we will use and pass that in as an argument to
+  # the script.  We need to defer the decision until compilation time, because
+  # for VS projects we won't know if we are making a debug or release build
+  # until build time.
+  if(CMAKE_GENERATOR MATCHES "Visual Studio")
+    set( CUDA_build_configuration "$(ConfigurationName)" )
+  else()
+    set( CUDA_build_configuration "${CMAKE_BUILD_TYPE}")
+  endif()
+
+  # Initialize our list of includes with the user ones followed by the CUDA system ones.
+  set(CUDA_NVCC_INCLUDE_DIRS ${CUDA_NVCC_INCLUDE_DIRS_USER} "${CUDA_INCLUDE_DIRS}")
+  if(_target_is_phony)
+    # If the passed in target name isn't a real target (i.e., this is from a call to one of the
+    # cuda_compile_* functions), need to query directory properties to get include directories
+    # and compile definitions.
+    get_directory_property(_dir_include_dirs INCLUDE_DIRECTORIES)
+    get_directory_property(_dir_compile_defs COMPILE_DEFINITIONS)
+
+    list(APPEND CUDA_NVCC_INCLUDE_DIRS "${_dir_include_dirs}")
+    set(CUDA_NVCC_COMPILE_DEFINITIONS "${_dir_compile_defs}")
+  else()
+    # Append the include directories for this target via generator expression, which is
+    # expanded by the FILE(GENERATE) call below.  This generator expression captures all
+    # include dirs set by the user, whether via directory properties or target properties
+    list(APPEND CUDA_NVCC_INCLUDE_DIRS "$")
+
+    # Do the same thing with compile definitions
+    set(CUDA_NVCC_COMPILE_DEFINITIONS "$")
+  endif()
+
+
+  # Reset these variables
+  set(CUDA_WRAP_OPTION_NVCC_FLAGS)
+  foreach(config ${CUDA_configuration_types})
+    string(TOUPPER ${config} config_upper)
+    set(CUDA_WRAP_OPTION_NVCC_FLAGS_${config_upper})
+  endforeach()
+
+  CUDA_GET_SOURCES_AND_OPTIONS(_cuda_wrap_sources _cuda_wrap_cmake_options _cuda_wrap_options ${_argn_list})
+  CUDA_PARSE_NVCC_OPTIONS(CUDA_WRAP_OPTION_NVCC_FLAGS ${_cuda_wrap_options})
+
+  # Figure out if we are building a shared library.  BUILD_SHARED_LIBS is
+  # respected in CUDA_ADD_LIBRARY.
+  set(_cuda_build_shared_libs FALSE)
+  # SHARED, MODULE
+  list(FIND _cuda_wrap_cmake_options SHARED _cuda_found_SHARED)
+  list(FIND _cuda_wrap_cmake_options MODULE _cuda_found_MODULE)
+  if(_cuda_found_SHARED GREATER -1 OR _cuda_found_MODULE GREATER -1)
+    set(_cuda_build_shared_libs TRUE)
+  endif()
+  # STATIC
+  list(FIND _cuda_wrap_cmake_options STATIC _cuda_found_STATIC)
+  if(_cuda_found_STATIC GREATER -1)
+    set(_cuda_build_shared_libs FALSE)
+  endif()
+
+  # CUDA_HOST_FLAGS
+  if(_cuda_build_shared_libs)
+    # If we are setting up code for a shared library, then we need to add extra flags for
+    # compiling objects for shared libraries.
+    set(CUDA_HOST_SHARED_FLAGS ${CMAKE_SHARED_LIBRARY_${CUDA_C_OR_CXX}_FLAGS})
+  else()
+    set(CUDA_HOST_SHARED_FLAGS)
+  endif()
+
+  macro(_filter_blocklisted_host_flags CUDA_FLAGS)
+    string(REGEX REPLACE "[ \t]+" ";" ${CUDA_FLAGS} "${${CUDA_FLAGS}}")
+    foreach(_blacklisted ${CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST})
+      list(REMOVE_ITEM ${CUDA_FLAGS} "${_blacklisted}")
+    endforeach()
+    string(REPLACE ";" " " ${CUDA_FLAGS} "${${CUDA_FLAGS}}")
+  endmacro()
+
+  # Only add the CMAKE_{C,CXX}_FLAGS if we are propagating host flags.  We
+  # always need to set the SHARED_FLAGS, though.
+  if(CUDA_PROPAGATE_HOST_FLAGS)
+    set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}")
+    _filter_blocklisted_host_flags(_cuda_C_FLAGS)
+    set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${_cuda_C_FLAGS} ${CUDA_HOST_SHARED_FLAGS})")
+  else()
+    set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${CUDA_HOST_SHARED_FLAGS})")
+  endif()
+
+  set(_cuda_nvcc_flags_config "# Build specific configuration flags")
+  # Loop over all the configuration types to generate appropriate flags for run_nvcc.cmake
+  foreach(config ${CUDA_configuration_types})
+    string(TOUPPER ${config} config_upper)
+    # CMAKE_FLAGS are strings and not lists.  By not putting quotes around CMAKE_FLAGS
+    # we convert the strings to lists (like we want).
+
+    if(CUDA_PROPAGATE_HOST_FLAGS)
+      # nvcc chokes on -g3 in versions previous to 3.0, so replace it with -g
+      set(_cuda_fix_g3 FALSE)
+
+      set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}")
+      _filter_blocklisted_host_flags(_cuda_C_FLAGS)
+      if(_cuda_fix_g3)
+        string(REPLACE "-g3" "-g" _cuda_C_FLAGS "${_cuda_C_FLAGS}")
+      endif()
+
+      string(APPEND _cuda_host_flags "\nset(CMAKE_HOST_FLAGS_${config_upper} ${_cuda_C_FLAGS})")
+    endif()
+
+    # Note that if we ever want CUDA_NVCC_FLAGS_ to be string (instead of a list
+    # like it is currently), we can remove the quotes around the
+    # ${CUDA_NVCC_FLAGS_${config_upper}} variable like the CMAKE_HOST_FLAGS_ variable.
+    string(APPEND _cuda_nvcc_flags_config "\nset(CUDA_NVCC_FLAGS_${config_upper} ${CUDA_NVCC_FLAGS_${config_upper}} ;; ${CUDA_WRAP_OPTION_NVCC_FLAGS_${config_upper}})")
+  endforeach()
+
+  # Process the C++14 flag.  If the host sets the flag, we need to add it to nvcc and
+  # remove it from the host. This is because -Xcompile -std=c++ will choke nvcc (it uses
+  # the C preprocessor).  In order to get this to work correctly, we need to use nvcc's
+  # specific c++14 flag.
+  if( "${_cuda_host_flags}" MATCHES "-std=c\\+\\+11")
+    # Add the c++14 flag to nvcc if it isn't already present.  Note that we only look at
+    # the main flag instead of the configuration specific flags.
+    if( NOT "${CUDA_NVCC_FLAGS}" MATCHES "-std=c\\+\\+14" )
+      list(APPEND nvcc_flags --std c++14)
+    endif()
+    string(REGEX REPLACE "[-]+std=c\\+\\+14" "" _cuda_host_flags "${_cuda_host_flags}")
+  endif()
+
+  if(_cuda_build_shared_libs)
+    list(APPEND nvcc_flags "-D${cuda_target}_EXPORTS")
+  endif()
+
+  # Reset the output variable
+  set(_cuda_wrap_generated_files "")
+
+  # Iterate over the macro arguments and create custom
+  # commands for all the .cu files.
+  foreach(file ${_argn_list})
+    # Ignore any file marked as a HEADER_FILE_ONLY
+    get_source_file_property(_is_header ${file} HEADER_FILE_ONLY)
+    # Allow per source file overrides of the format.  Also allows compiling non-.cu files.
+    get_source_file_property(_cuda_source_format ${file} CUDA_SOURCE_PROPERTY_FORMAT)
+    if((${file} MATCHES "\\.cu$" OR _cuda_source_format) AND NOT _is_header)
+
+      if(NOT _cuda_source_format)
+        set(_cuda_source_format ${format})
+      endif()
+      # If file isn't a .cu file, we need to tell nvcc to treat it as such.
+      if(NOT file MATCHES "\\.cu$")
+        set(cuda_language_flag -x=cu)
+      else()
+        set(cuda_language_flag)
+      endif()
+
+      if( ${_cuda_source_format} MATCHES "OBJ")
+        set( cuda_compile_to_external_module OFF )
+      else()
+        set( cuda_compile_to_external_module ON )
+        if( ${_cuda_source_format} MATCHES "PTX" )
+          set( cuda_compile_to_external_module_type "ptx" )
+        elseif( ${_cuda_source_format} MATCHES "CUBIN")
+          set( cuda_compile_to_external_module_type "cubin" )
+        elseif( ${_cuda_source_format} MATCHES "FATBIN")
+          set( cuda_compile_to_external_module_type "fatbin" )
+        else()
+          message( FATAL_ERROR "Invalid format flag passed to CUDA_WRAP_SRCS or set with CUDA_SOURCE_PROPERTY_FORMAT file property for file '${file}': '${_cuda_source_format}'.  Use OBJ, PTX, CUBIN or FATBIN.")
+        endif()
+      endif()
+
+      if(cuda_compile_to_external_module)
+        # Don't use any of the host compilation flags for PTX targets.
+        set(CUDA_HOST_FLAGS)
+        set(CUDA_NVCC_FLAGS_CONFIG)
+      else()
+        set(CUDA_HOST_FLAGS ${_cuda_host_flags})
+        set(CUDA_NVCC_FLAGS_CONFIG ${_cuda_nvcc_flags_config})
+      endif()
+
+      # Determine output directory
+      cuda_compute_build_path("${file}" cuda_build_path)
+      set(cuda_compile_intermediate_directory "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${cuda_target}.dir/${cuda_build_path}")
+      if(CUDA_GENERATED_OUTPUT_DIR)
+        set(cuda_compile_output_dir "${CUDA_GENERATED_OUTPUT_DIR}")
+      else()
+        if ( cuda_compile_to_external_module )
+          set(cuda_compile_output_dir "${CMAKE_CURRENT_BINARY_DIR}")
+        else()
+          set(cuda_compile_output_dir "${cuda_compile_intermediate_directory}")
+        endif()
+      endif()
+
+      # Add a custom target to generate a c or ptx file. ######################
+
+      get_filename_component( basename ${file} NAME )
+      if( cuda_compile_to_external_module )
+        set(generated_file_path "${cuda_compile_output_dir}")
+        set(generated_file_basename "${cuda_target}_generated_${basename}.${cuda_compile_to_external_module_type}")
+        set(format_flag "-${cuda_compile_to_external_module_type}")
+        file(MAKE_DIRECTORY "${cuda_compile_output_dir}")
+      else()
+        set(generated_file_path "${cuda_compile_output_dir}/${CMAKE_CFG_INTDIR}")
+        set(generated_file_basename "${cuda_target}_generated_${basename}${generated_extension}")
+        if(CUDA_SEPARABLE_COMPILATION)
+          set(format_flag "-dc")
+        else()
+          set(format_flag "-c")
+        endif()
+      endif()
+
+      # Set all of our file names.  Make sure that whatever filenames that have
+      # generated_file_path in them get passed in through as a command line
+      # argument, so that the ${CMAKE_CFG_INTDIR} gets expanded at run time
+      # instead of configure time.
+      set(generated_file "${generated_file_path}/${generated_file_basename}")
+      set(cmake_dependency_file "${cuda_compile_intermediate_directory}/${generated_file_basename}.depend")
+      set(NVCC_generated_dependency_file "${cuda_compile_intermediate_directory}/${generated_file_basename}.NVCC-depend")
+      set(generated_cubin_file "${generated_file_path}/${generated_file_basename}.cubin.txt")
+      set(custom_target_script_pregen "${cuda_compile_intermediate_directory}/${generated_file_basename}.cmake.pre-gen")
+      set(custom_target_script "${cuda_compile_intermediate_directory}/${generated_file_basename}$<$>:.$>.cmake")
+
+      # Setup properties for obj files:
+      if( NOT cuda_compile_to_external_module )
+        set_source_files_properties("${generated_file}"
+          PROPERTIES
+          EXTERNAL_OBJECT true # This is an object file not to be compiled, but only be linked.
+          )
+      endif()
+
+      # Don't add CMAKE_CURRENT_SOURCE_DIR if the path is already an absolute path.
+      get_filename_component(file_path "${file}" PATH)
+      if(IS_ABSOLUTE "${file_path}")
+        set(source_file "${file}")
+      else()
+        set(source_file "${CMAKE_CURRENT_SOURCE_DIR}/${file}")
+      endif()
+
+      if( NOT cuda_compile_to_external_module AND CUDA_SEPARABLE_COMPILATION)
+        list(APPEND ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS "${generated_file}")
+      endif()
+
+      # Bring in the dependencies.  Creates a variable CUDA_NVCC_DEPEND #######
+      cuda_include_nvcc_dependencies(${cmake_dependency_file})
+
+      # Convenience string for output #########################################
+      if(CUDA_BUILD_EMULATION)
+        set(cuda_build_type "Emulation")
+      else()
+        set(cuda_build_type "Device")
+      endif()
+
+      # Build the NVCC made dependency file ###################################
+      set(build_cubin OFF)
+      if ( NOT CUDA_BUILD_EMULATION AND CUDA_BUILD_CUBIN )
+         if ( NOT cuda_compile_to_external_module )
+           set ( build_cubin ON )
+         endif()
+      endif()
+
+      # Configure the build script
+      configure_file("${CUDA_run_nvcc}" "${custom_target_script_pregen}" @ONLY)
+      file(GENERATE
+        OUTPUT "${custom_target_script}"
+        INPUT "${custom_target_script_pregen}"
+        )
+
+      # So if a user specifies the same cuda file as input more than once, you
+      # can have bad things happen with dependencies.  Here we check an option
+      # to see if this is the behavior they want.
+      if(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE)
+        set(main_dep MAIN_DEPENDENCY ${source_file})
+      else()
+        set(main_dep DEPENDS ${source_file})
+      endif()
+
+      if(CUDA_VERBOSE_BUILD)
+        set(verbose_output ON)
+      elseif(CMAKE_GENERATOR MATCHES "Makefiles")
+        set(verbose_output "$(VERBOSE)")
+      # This condition lets us also turn on verbose output when someone
+      # specifies CMAKE_VERBOSE_MAKEFILE, even if the generator isn't
+      # the Makefiles generator (this is important for us, Ninja users.)
+      elseif(CMAKE_VERBOSE_MAKEFILE)
+        set(verbose_output ON)
+      else()
+        set(verbose_output OFF)
+      endif()
+
+      # Create up the comment string
+      file(RELATIVE_PATH generated_file_relative_path "${CMAKE_BINARY_DIR}" "${generated_file}")
+      if(cuda_compile_to_external_module)
+        set(cuda_build_comment_string "Building NVCC ${cuda_compile_to_external_module_type} file ${generated_file_relative_path}")
+      else()
+        set(cuda_build_comment_string "Building NVCC (${cuda_build_type}) object ${generated_file_relative_path}")
+      endif()
+
+      set(_verbatim VERBATIM)
+      if(ccbin_flags MATCHES "\\$\\(VCInstallDir\\)")
+        set(_verbatim "")
+      endif()
+
+      # Build the generated file and dependency file ##########################
+      add_custom_command(
+        OUTPUT ${generated_file}
+        # These output files depend on the source_file and the contents of cmake_dependency_file
+        ${main_dep}
+        DEPENDS ${CUDA_NVCC_DEPEND}
+        DEPENDS ${custom_target_script}
+        # Make sure the output directory exists before trying to write to it.
+        COMMAND ${CMAKE_COMMAND} -E make_directory "${generated_file_path}"
+        COMMAND ${CMAKE_COMMAND} ARGS
+          -D verbose:BOOL=${verbose_output}
+          ${ccbin_flags}
+          -D build_configuration:STRING=${CUDA_build_configuration}
+          -D "generated_file:STRING=${generated_file}"
+          -D "generated_cubin_file:STRING=${generated_cubin_file}"
+          -P "${custom_target_script}"
+        WORKING_DIRECTORY "${cuda_compile_intermediate_directory}"
+        COMMENT "${cuda_build_comment_string}"
+        ${_verbatim}
+        )
+
+      # Make sure the build system knows the file is generated.
+      set_source_files_properties(${generated_file} PROPERTIES GENERATED TRUE)
+
+      list(APPEND _cuda_wrap_generated_files ${generated_file})
+
+      # Add the other files that we want cmake to clean on a cleanup ##########
+      list(APPEND CUDA_ADDITIONAL_CLEAN_FILES "${cmake_dependency_file}")
+      list(REMOVE_DUPLICATES CUDA_ADDITIONAL_CLEAN_FILES)
+      set(CUDA_ADDITIONAL_CLEAN_FILES ${CUDA_ADDITIONAL_CLEAN_FILES} CACHE INTERNAL "List of intermediate files that are part of the cuda dependency scanning.")
+
+    endif()
+  endforeach()
+
+  # Set the return parameter
+  set(${generated_files} ${_cuda_wrap_generated_files})
+endmacro()
+
+function(_cuda_get_important_host_flags important_flags flag_string)
+  if(CMAKE_GENERATOR MATCHES "Visual Studio")
+    string(REGEX MATCHALL "/M[DT][d]?" flags "${flag_string}")
+    list(APPEND ${important_flags} ${flags})
+  else()
+    string(REGEX MATCHALL "-fPIC" flags "${flag_string}")
+    list(APPEND ${important_flags} ${flags})
+  endif()
+  set(${important_flags} ${${important_flags}} PARENT_SCOPE)
+endfunction()
+
+###############################################################################
+###############################################################################
+# Separable Compilation Link
+###############################################################################
+###############################################################################
+
+# Compute the filename to be used by CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS
+function(CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME output_file_var cuda_target object_files)
+  if (object_files)
+    set(generated_extension ${CMAKE_${CUDA_C_OR_CXX}_OUTPUT_EXTENSION})
+    set(output_file "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${cuda_target}.dir/${CMAKE_CFG_INTDIR}/${cuda_target}_intermediate_link${generated_extension}")
+  else()
+    set(output_file)
+  endif()
+
+  set(${output_file_var} "${output_file}" PARENT_SCOPE)
+endfunction()
+
+# Setup the build rule for the separable compilation intermediate link file.
+function(CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS output_file cuda_target options object_files)
+  if (object_files)
+
+    set_source_files_properties("${output_file}"
+      PROPERTIES
+      EXTERNAL_OBJECT TRUE # This is an object file not to be compiled, but only
+                           # be linked.
+      GENERATED TRUE       # This file is generated during the build
+      )
+
+    # For now we are ignoring all the configuration specific flags.
+    set(nvcc_flags)
+    CUDA_PARSE_NVCC_OPTIONS(nvcc_flags ${options})
+    if(CUDA_64_BIT_DEVICE_CODE)
+      list(APPEND nvcc_flags -m64)
+    else()
+      list(APPEND nvcc_flags -m32)
+    endif()
+    # If -ccbin, --compiler-bindir has been specified, don't do anything.  Otherwise add it here.
+    list( FIND nvcc_flags "-ccbin" ccbin_found0 )
+    list( FIND nvcc_flags "--compiler-bindir" ccbin_found1 )
+    if( ccbin_found0 LESS 0 AND ccbin_found1 LESS 0 AND CUDA_HOST_COMPILER )
+      # Match VERBATIM check below.
+      if(CUDA_HOST_COMPILER MATCHES "\\$\\(VCInstallDir\\)")
+        list(APPEND nvcc_flags -ccbin "\"${CUDA_HOST_COMPILER}\"")
+      else()
+        list(APPEND nvcc_flags -ccbin "${CUDA_HOST_COMPILER}")
+      endif()
+    endif()
+
+    # Create a list of flags specified by CUDA_NVCC_FLAGS_${CONFIG} and CMAKE_${CUDA_C_OR_CXX}_FLAGS*
+    set(config_specific_flags)
+    set(flags)
+    foreach(config ${CUDA_configuration_types})
+      string(TOUPPER ${config} config_upper)
+      # Add config specific flags
+      foreach(f ${CUDA_NVCC_FLAGS_${config_upper}})
+        list(APPEND config_specific_flags $<$:${f}>)
+      endforeach()
+      set(important_host_flags)
+      _cuda_get_important_host_flags(important_host_flags "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}")
+      foreach(f ${important_host_flags})
+        list(APPEND flags $<$:-Xcompiler> $<$:${f}>)
+      endforeach()
+    endforeach()
+    # Add CMAKE_${CUDA_C_OR_CXX}_FLAGS
+    set(important_host_flags)
+    _cuda_get_important_host_flags(important_host_flags "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}")
+    foreach(f ${important_host_flags})
+      list(APPEND flags -Xcompiler ${f})
+    endforeach()
+
+    # Add our general CUDA_NVCC_FLAGS with the configuration specific flags
+    set(nvcc_flags ${CUDA_NVCC_FLAGS} ${config_specific_flags} ${nvcc_flags})
+
+    file(RELATIVE_PATH output_file_relative_path "${CMAKE_BINARY_DIR}" "${output_file}")
+
+    # Some generators don't handle the multiple levels of custom command
+    # dependencies correctly (obj1 depends on file1, obj2 depends on obj1), so
+    # we work around that issue by compiling the intermediate link object as a
+    # pre-link custom command in that situation.
+    set(do_obj_build_rule TRUE)
+    if (MSVC_VERSION GREATER 1599 AND MSVC_VERSION LESS 1800)
+      # VS 2010 and 2012 have this problem.
+      set(do_obj_build_rule FALSE)
+    endif()
+
+    set(_verbatim VERBATIM)
+    if(nvcc_flags MATCHES "\\$\\(VCInstallDir\\)")
+      set(_verbatim "")
+    endif()
+
+    if (do_obj_build_rule)
+      add_custom_command(
+        OUTPUT ${output_file}
+        DEPENDS ${object_files}
+        COMMAND ${CUDA_NVCC_EXECUTABLE} ${nvcc_flags} -dlink ${object_files} -o ${output_file}
+        ${flags}
+        COMMENT "Building NVCC intermediate link file ${output_file_relative_path}"
+        COMMAND_EXPAND_LISTS
+        ${_verbatim}
+        )
+    else()
+      get_filename_component(output_file_dir "${output_file}" DIRECTORY)
+      add_custom_command(
+        TARGET ${cuda_target}
+        PRE_LINK
+        COMMAND ${CMAKE_COMMAND} -E echo "Building NVCC intermediate link file ${output_file_relative_path}"
+        COMMAND ${CMAKE_COMMAND} -E make_directory "${output_file_dir}"
+        COMMAND ${CUDA_NVCC_EXECUTABLE} ${nvcc_flags} ${flags} -dlink ${object_files} -o "${output_file}"
+        COMMAND_EXPAND_LISTS
+        ${_verbatim}
+        )
+    endif()
+ endif()
+endfunction()
+
+###############################################################################
+###############################################################################
+# ADD LIBRARY
+###############################################################################
+###############################################################################
+macro(CUDA_ADD_LIBRARY cuda_target)
+
+  CUDA_ADD_CUDA_INCLUDE_ONCE()
+
+  # Separate the sources from the options
+  CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN})
+  CUDA_BUILD_SHARED_LIBRARY(_cuda_shared_flag ${ARGN})
+  # Create custom commands and targets for each file.
+  CUDA_WRAP_SRCS( ${cuda_target} OBJ _generated_files ${_sources}
+    ${_cmake_options} ${_cuda_shared_flag}
+    OPTIONS ${_options} )
+
+  # Compute the file name of the intermedate link file used for separable
+  # compilation.
+  CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME(link_file ${cuda_target} "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+
+  # Add the library.
+  add_library(${cuda_target} ${_cmake_options}
+    ${_generated_files}
+    ${_sources}
+    ${link_file}
+    )
+
+  # Add a link phase for the separable compilation if it has been enabled.  If
+  # it has been enabled then the ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS
+  # variable will have been defined.
+  CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS("${link_file}" ${cuda_target} "${_options}" "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+
+  target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD}
+    ${CUDA_LIBRARIES}
+    )
+
+  if(CUDA_SEPARABLE_COMPILATION)
+    target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD}
+      ${CUDA_cudadevrt_LIBRARY}
+      )
+  endif()
+
+  # We need to set the linker language based on what the expected generated file
+  # would be. CUDA_C_OR_CXX is computed based on CUDA_HOST_COMPILATION_CPP.
+  set_target_properties(${cuda_target}
+    PROPERTIES
+    LINKER_LANGUAGE ${CUDA_C_OR_CXX}
+    )
+
+endmacro()
+
+
+###############################################################################
+###############################################################################
+# ADD EXECUTABLE
+###############################################################################
+###############################################################################
+macro(CUDA_ADD_EXECUTABLE cuda_target)
+
+  CUDA_ADD_CUDA_INCLUDE_ONCE()
+
+  # Separate the sources from the options
+  CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN})
+  # Create custom commands and targets for each file.
+  CUDA_WRAP_SRCS( ${cuda_target} OBJ _generated_files ${_sources} OPTIONS ${_options} )
+
+  # Compute the file name of the intermedate link file used for separable
+  # compilation.
+  CUDA_COMPUTE_SEPARABLE_COMPILATION_OBJECT_FILE_NAME(link_file ${cuda_target} "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+
+  # Add the library.
+  add_executable(${cuda_target} ${_cmake_options}
+    ${_generated_files}
+    ${_sources}
+    ${link_file}
+    )
+
+  # Add a link phase for the separable compilation if it has been enabled.  If
+  # it has been enabled then the ${cuda_target}_SEPARABLE_COMPILATION_OBJECTS
+  # variable will have been defined.
+  CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS("${link_file}" ${cuda_target} "${_options}" "${${cuda_target}_SEPARABLE_COMPILATION_OBJECTS}")
+
+  target_link_libraries(${cuda_target} ${CUDA_LINK_LIBRARIES_KEYWORD}
+    ${CUDA_LIBRARIES}
+    )
+
+  # We need to set the linker language based on what the expected generated file
+  # would be. CUDA_C_OR_CXX is computed based on CUDA_HOST_COMPILATION_CPP.
+  set_target_properties(${cuda_target}
+    PROPERTIES
+    LINKER_LANGUAGE ${CUDA_C_OR_CXX}
+    )
+
+endmacro()
+
+
+###############################################################################
+###############################################################################
+# (Internal) helper for manually added cuda source files with specific targets
+###############################################################################
+###############################################################################
+macro(cuda_compile_base cuda_target format generated_files)
+  # Update a counter in this directory, to keep phony target names unique.
+  set(_cuda_target "${cuda_target}")
+  get_property(_counter DIRECTORY PROPERTY _cuda_internal_phony_counter)
+  if(_counter)
+    math(EXPR _counter "${_counter} + 1")
+  else()
+    set(_counter 1)
+  endif()
+  string(APPEND _cuda_target "_${_counter}")
+  set_property(DIRECTORY PROPERTY _cuda_internal_phony_counter ${_counter})
+
+  # Separate the sources from the options
+  CUDA_GET_SOURCES_AND_OPTIONS(_sources _cmake_options _options ${ARGN})
+
+  # Create custom commands and targets for each file.
+  CUDA_WRAP_SRCS( ${_cuda_target} ${format} _generated_files ${_sources}
+                  ${_cmake_options} OPTIONS ${_options} PHONY)
+
+  set( ${generated_files} ${_generated_files})
+
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA COMPILE
+###############################################################################
+###############################################################################
+macro(CUDA_COMPILE generated_files)
+  cuda_compile_base(cuda_compile OBJ ${generated_files} ${ARGN})
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA COMPILE PTX
+###############################################################################
+###############################################################################
+macro(CUDA_COMPILE_PTX generated_files)
+  cuda_compile_base(cuda_compile_ptx PTX ${generated_files} ${ARGN})
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA COMPILE FATBIN
+###############################################################################
+###############################################################################
+macro(CUDA_COMPILE_FATBIN generated_files)
+  cuda_compile_base(cuda_compile_fatbin FATBIN ${generated_files} ${ARGN})
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA COMPILE CUBIN
+###############################################################################
+###############################################################################
+macro(CUDA_COMPILE_CUBIN generated_files)
+  cuda_compile_base(cuda_compile_cubin CUBIN ${generated_files} ${ARGN})
+endmacro()
+
+
+###############################################################################
+###############################################################################
+# CUDA ADD CUFFT TO TARGET
+###############################################################################
+###############################################################################
+macro(CUDA_ADD_CUFFT_TO_TARGET target)
+  if (CUDA_BUILD_EMULATION)
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cufftemu_LIBRARY})
+  else()
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cufft_LIBRARY})
+  endif()
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA ADD CUBLAS TO TARGET
+###############################################################################
+###############################################################################
+macro(CUDA_ADD_CUBLAS_TO_TARGET target)
+  if (CUDA_BUILD_EMULATION)
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublasemu_LIBRARY})
+  else()
+    target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
+  endif()
+endmacro()
+
+###############################################################################
+###############################################################################
+# CUDA BUILD CLEAN TARGET
+###############################################################################
+###############################################################################
+macro(CUDA_BUILD_CLEAN_TARGET)
+  # Call this after you add all your CUDA targets, and you will get a
+  # convenience target.  You should also make clean after running this target
+  # to get the build system to generate all the code again.
+
+  set(cuda_clean_target_name clean_cuda_depends)
+  if (CMAKE_GENERATOR MATCHES "Visual Studio")
+    string(TOUPPER ${cuda_clean_target_name} cuda_clean_target_name)
+  endif()
+  add_custom_target(${cuda_clean_target_name}
+    COMMAND ${CMAKE_COMMAND} -E remove ${CUDA_ADDITIONAL_CLEAN_FILES})
+
+  # Clear out the variable, so the next time we configure it will be empty.
+  # This is useful so that the files won't persist in the list after targets
+  # have been removed.
+  set(CUDA_ADDITIONAL_CLEAN_FILES "" CACHE INTERNAL "List of intermediate files that are part of the cuda dependency scanning.")
+endmacro()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..580f24a400d8c5662ec572c4631db9e3e47645d9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake
@@ -0,0 +1,106 @@
+#  James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#  Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#  Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#  Copyright (c) 2007-2009
+#  Scientific Computing and Imaging Institute, University of Utah
+#
+#  This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#  for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+
+#######################################################################
+# This converts a file written in makefile syntax into one that can be included
+# by CMake.
+
+# Input variables
+#
+# verbose:BOOL=<>          OFF: Be as quiet as possible (default)
+#                          ON : Extra output
+#
+# input_file:FILEPATH=<>   Path to dependency file in makefile format
+#
+# output_file:FILEPATH=<>  Path to file with dependencies in CMake readable variable
+#
+
+file(READ ${input_file} depend_text)
+
+if (NOT "${depend_text}" STREQUAL "")
+
+  # message("FOUND DEPENDS")
+
+  string(REPLACE "\\ " " " depend_text ${depend_text})
+
+  # This works for the nvcc -M generated dependency files.
+  string(REGEX REPLACE "^.* : " "" depend_text ${depend_text})
+  string(REGEX REPLACE "[ \\\\]*\n" ";" depend_text ${depend_text})
+
+  set(dependency_list "")
+
+  foreach(file ${depend_text})
+
+    string(REGEX REPLACE "^ +" "" file ${file})
+
+    # OK, now if we had a UNC path, nvcc has a tendency to only output the first '/'
+    # instead of '//'.  Here we will test to see if the file exists, if it doesn't then
+    # try to prepend another '/' to the path and test again.  If it still fails remove the
+    # path.
+
+    if(NOT EXISTS "${file}")
+      if (EXISTS "/${file}")
+        set(file "/${file}")
+      else()
+        if(verbose)
+          message(WARNING " Removing non-existent dependency file: ${file}")
+        endif()
+        set(file "")
+      endif()
+    endif()
+
+    # Make sure we check to see if we have a file, before asking if it is not a directory.
+    # if(NOT IS_DIRECTORY "") will return TRUE.
+    if(file AND NOT IS_DIRECTORY "${file}")
+      # If softlinks start to matter, we should change this to REALPATH.  For now we need
+      # to flatten paths, because nvcc can generate stuff like /bin/../include instead of
+      # just /include.
+      get_filename_component(file_absolute "${file}" ABSOLUTE)
+      list(APPEND dependency_list "${file_absolute}")
+    endif()
+
+  endforeach()
+
+else()
+  # message("FOUND NO DEPENDS")
+endif()
+
+# Remove the duplicate entries and sort them.
+list(REMOVE_DUPLICATES dependency_list)
+list(SORT dependency_list)
+
+foreach(file ${dependency_list})
+  string(APPEND cuda_nvcc_depend " \"${file}\"\n")
+endforeach()
+
+file(WRITE ${output_file} "# Generated by: make2cmake.cmake\nSET(CUDA_NVCC_DEPEND\n ${cuda_nvcc_depend})\n\n")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..25ceb49f3dd8e684e35cac49834c4db0aa5c338a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake
@@ -0,0 +1,109 @@
+#  James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#  Abe Stephens, SCI Institute -- http://www.sci.utah.edu/~abe/FindCuda.html
+#
+#  Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#  Copyright (c) 2007-2009
+#  Scientific Computing and Imaging Institute, University of Utah
+#
+#  This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#  for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+#
+
+#######################################################################
+# Parses a .cubin file produced by nvcc and reports statistics about the file.
+
+
+file(READ ${input_file} file_text)
+
+if (NOT "${file_text}" STREQUAL "")
+
+  string(REPLACE ";" "\\;" file_text ${file_text})
+  string(REPLACE "\ncode" ";code" file_text ${file_text})
+
+  list(LENGTH file_text len)
+
+  foreach(line ${file_text})
+
+    # Only look at "code { }" blocks.
+    if(line MATCHES "^code")
+
+      # Break into individual lines.
+      string(REGEX REPLACE "\n" ";" line ${line})
+
+      foreach(entry ${line})
+
+        # Extract kernel names.
+        if (${entry} MATCHES "[^g]name = ([^ ]+)")
+          set(entry "${CMAKE_MATCH_1}")
+
+          # Check to see if the kernel name starts with "_"
+          set(skip FALSE)
+          # if (${entry} MATCHES "^_")
+            # Skip the rest of this block.
+            # message("Skipping ${entry}")
+            # set(skip TRUE)
+          # else ()
+            message("Kernel:    ${entry}")
+          # endif ()
+
+        endif()
+
+        # Skip the rest of the block if necessary
+        if(NOT skip)
+
+          # Registers
+          if (${entry} MATCHES "reg([ ]+)=([ ]+)([^ ]+)")
+            set(entry "${CMAKE_MATCH_3}")
+            message("Registers: ${entry}")
+          endif()
+
+          # Local memory
+          if (${entry} MATCHES "lmem([ ]+)=([ ]+)([^ ]+)")
+            set(entry "${CMAKE_MATCH_3}")
+            message("Local:     ${entry}")
+          endif()
+
+          # Shared memory
+          if (${entry} MATCHES "smem([ ]+)=([ ]+)([^ ]+)")
+            set(entry "${CMAKE_MATCH_3}")
+            message("Shared:    ${entry}")
+          endif()
+
+          if (${entry} MATCHES "^}")
+            message("")
+          endif()
+
+        endif()
+
+
+      endforeach()
+
+    endif()
+
+  endforeach()
+
+else()
+  # message("FOUND NO DEPENDS")
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..59c5c11a1091f34df89b681a926db602a1c75caa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake
@@ -0,0 +1,303 @@
+#  James Bigler, NVIDIA Corp (nvidia.com - jbigler)
+#
+#  Copyright (c) 2008 - 2009 NVIDIA Corporation.  All rights reserved.
+#
+#  This code is licensed under the MIT License.  See the FindCUDA.cmake script
+#  for the text of the license.
+
+# The MIT License
+#
+# License for the specific language governing rights and limitations under
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+
+##########################################################################
+# This file runs the nvcc commands to produce the desired output file along with
+# the dependency file needed by CMake to compute dependencies.  In addition the
+# file checks the output of each command and if the command fails it deletes the
+# output files.
+
+# Input variables
+#
+# verbose:BOOL=<>          OFF: Be as quiet as possible (default)
+#                          ON : Describe each step
+#
+# build_configuration:STRING=<> Typically one of Debug, MinSizeRel, Release, or
+#                               RelWithDebInfo, but it should match one of the
+#                               entries in CUDA_HOST_FLAGS. This is the build
+#                               configuration used when compiling the code.  If
+#                               blank or unspecified Debug is assumed as this is
+#                               what CMake does.
+#
+# generated_file:STRING=<> File to generate.  This argument must be passed in.
+#
+# generated_cubin_file:STRING=<> File to generate.  This argument must be passed
+#                                                   in if build_cubin is true.
+
+cmake_policy(PUSH)
+cmake_policy(SET CMP0007 NEW)
+cmake_policy(SET CMP0010 NEW)
+if(NOT generated_file)
+  message(FATAL_ERROR "You must specify generated_file on the command line")
+endif()
+
+# Set these up as variables to make reading the generated file easier
+set(CMAKE_COMMAND "@CMAKE_COMMAND@") # path
+set(source_file "@source_file@") # path
+set(NVCC_generated_dependency_file "@NVCC_generated_dependency_file@") # path
+set(cmake_dependency_file "@cmake_dependency_file@") # path
+set(CUDA_make2cmake "@CUDA_make2cmake@") # path
+set(CUDA_parse_cubin "@CUDA_parse_cubin@") # path
+set(build_cubin @build_cubin@) # bool
+set(CUDA_HOST_COMPILER "@CUDA_HOST_COMPILER@") # path
+# We won't actually use these variables for now, but we need to set this, in
+# order to force this file to be run again if it changes.
+set(generated_file_path "@generated_file_path@") # path
+set(generated_file_internal "@generated_file@") # path
+set(generated_cubin_file_internal "@generated_cubin_file@") # path
+
+set(CUDA_NVCC_EXECUTABLE "@CUDA_NVCC_EXECUTABLE@") # path
+set(CUDA_NVCC_FLAGS @CUDA_NVCC_FLAGS@ ;; @CUDA_WRAP_OPTION_NVCC_FLAGS@) # list
+@CUDA_NVCC_FLAGS_CONFIG@
+set(nvcc_flags @nvcc_flags@) # list
+set(CUDA_NVCC_INCLUDE_DIRS [==[@CUDA_NVCC_INCLUDE_DIRS@]==]) # list (needs to be in lua quotes to address backslashes)
+string(REPLACE "\\" "/" CUDA_NVCC_INCLUDE_DIRS "${CUDA_NVCC_INCLUDE_DIRS}")
+set(CUDA_NVCC_COMPILE_DEFINITIONS [==[@CUDA_NVCC_COMPILE_DEFINITIONS@]==]) # list (needs to be in lua quotes see #16510 ).
+set(format_flag "@format_flag@") # string
+set(cuda_language_flag @cuda_language_flag@) # list
+
+# Clean up list of include directories and add -I flags
+list(REMOVE_DUPLICATES CUDA_NVCC_INCLUDE_DIRS)
+set(CUDA_NVCC_INCLUDE_ARGS)
+foreach(dir ${CUDA_NVCC_INCLUDE_DIRS})
+  # Extra quotes are added around each flag to help nvcc parse out flags with spaces.
+  list(APPEND CUDA_NVCC_INCLUDE_ARGS "-I${dir}")
+endforeach()
+
+# Clean up list of compile definitions, add -D flags, and append to nvcc_flags
+list(REMOVE_DUPLICATES CUDA_NVCC_COMPILE_DEFINITIONS)
+foreach(def ${CUDA_NVCC_COMPILE_DEFINITIONS})
+  list(APPEND nvcc_flags "-D${def}")
+endforeach()
+
+if(build_cubin AND NOT generated_cubin_file)
+  message(FATAL_ERROR "You must specify generated_cubin_file on the command line")
+endif()
+
+# This is the list of host compilation flags.  It C or CXX should already have
+# been chosen by FindCUDA.cmake.
+@CUDA_HOST_FLAGS@
+
+# Take the compiler flags and package them up to be sent to the compiler via -Xcompiler
+set(nvcc_host_compiler_flags "")
+# If we weren't given a build_configuration, use Debug.
+if(NOT build_configuration)
+  set(build_configuration Debug)
+endif()
+string(TOUPPER "${build_configuration}" build_configuration)
+#message("CUDA_NVCC_HOST_COMPILER_FLAGS = ${CUDA_NVCC_HOST_COMPILER_FLAGS}")
+foreach(flag ${CMAKE_HOST_FLAGS} ${CMAKE_HOST_FLAGS_${build_configuration}})
+  # Extra quotes are added around each flag to help nvcc parse out flags with spaces.
+  string(APPEND nvcc_host_compiler_flags ",\"${flag}\"")
+endforeach()
+if (nvcc_host_compiler_flags)
+  set(nvcc_host_compiler_flags "-Xcompiler" ${nvcc_host_compiler_flags})
+endif()
+#message("nvcc_host_compiler_flags = \"${nvcc_host_compiler_flags}\"")
+# Add the build specific configuration flags
+list(APPEND CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS_${build_configuration}})
+
+# Any -ccbin existing in CUDA_NVCC_FLAGS gets highest priority
+list( FIND CUDA_NVCC_FLAGS "-ccbin" ccbin_found0 )
+list( FIND CUDA_NVCC_FLAGS "--compiler-bindir" ccbin_found1 )
+if( ccbin_found0 LESS 0 AND ccbin_found1 LESS 0 AND CUDA_HOST_COMPILER )
+  if (CUDA_HOST_COMPILER STREQUAL "@_CUDA_MSVC_HOST_COMPILER@" AND DEFINED CCBIN)
+    set(CCBIN -ccbin "${CCBIN}")
+  else()
+    set(CCBIN -ccbin "${CUDA_HOST_COMPILER}")
+  endif()
+endif()
+
+# cuda_execute_process - Executes a command with optional command echo and status message.
+#
+#   status  - Status message to print if verbose is true
+#   command - COMMAND argument from the usual execute_process argument structure
+#   ARGN    - Remaining arguments are the command with arguments
+#
+#   CUDA_result - return value from running the command
+#
+# Make this a macro instead of a function, so that things like RESULT_VARIABLE
+# and other return variables are present after executing the process.
+macro(cuda_execute_process status command)
+  set(_command ${command})
+  if(NOT "x${_command}" STREQUAL "xCOMMAND")
+    message(FATAL_ERROR "Malformed call to cuda_execute_process.  Missing COMMAND as second argument. (command = ${command})")
+  endif()
+  if(verbose)
+    execute_process(COMMAND "${CMAKE_COMMAND}" -E echo -- ${status})
+    # Now we need to build up our command string.  We are accounting for quotes
+    # and spaces, anything else is left up to the user to fix if they want to
+    # copy and paste a runnable command line.
+    set(cuda_execute_process_string)
+    foreach(arg ${ARGN})
+      # If there are quotes, escape them, so they come through.
+      string(REPLACE "\"" "\\\"" arg ${arg})
+      # Args with spaces need quotes around them to get them to be parsed as a single argument.
+      if(arg MATCHES " ")
+        list(APPEND cuda_execute_process_string "\"${arg}\"")
+      else()
+        list(APPEND cuda_execute_process_string ${arg})
+      endif()
+    endforeach()
+    # Echo the command
+    execute_process(COMMAND ${CMAKE_COMMAND} -E echo ${cuda_execute_process_string})
+  endif()
+  # Run the command
+  execute_process(COMMAND ${ARGN} RESULT_VARIABLE CUDA_result )
+endmacro()
+
+# Delete the target file
+cuda_execute_process(
+  "Removing ${generated_file}"
+  COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}"
+  )
+
+# For CUDA 2.3 and below, -G -M doesn't work, so remove the -G flag
+# for dependency generation and hope for the best.
+set(depends_CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
+set(CUDA_VERSION @CUDA_VERSION@)
+
+# nvcc doesn't define __CUDACC__ for some reason when generating dependency files.  This
+# can cause incorrect dependencies when #including files based on this macro which is
+# defined in the generating passes of nvcc invocation.  We will go ahead and manually
+# define this for now until a future version fixes this bug.
+set(CUDACC_DEFINE -D__CUDACC__)
+
+# Generate the dependency file
+cuda_execute_process(
+  "Generating dependency file: ${NVCC_generated_dependency_file}"
+  COMMAND "${CUDA_NVCC_EXECUTABLE}"
+  -M
+  ${CUDACC_DEFINE}
+  "${source_file}"
+  -o "${NVCC_generated_dependency_file}"
+  ${CCBIN}
+  ${nvcc_flags}
+  ${nvcc_host_compiler_flags}
+  ${depends_CUDA_NVCC_FLAGS}
+  -DNVCC
+  ${CUDA_NVCC_INCLUDE_ARGS}
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Generate the cmake readable dependency file to a temp file.  Don't put the
+# quotes just around the filenames for the input_file and output_file variables.
+# CMake will pass the quotes through and not be able to find the file.
+cuda_execute_process(
+  "Generating temporary cmake readable file: ${cmake_dependency_file}.tmp"
+  COMMAND "${CMAKE_COMMAND}"
+  -D "input_file:FILEPATH=${NVCC_generated_dependency_file}"
+  -D "output_file:FILEPATH=${cmake_dependency_file}.tmp"
+  -D "verbose=${verbose}"
+  -P "${CUDA_make2cmake}"
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Copy the file if it is different
+cuda_execute_process(
+  "Copy if different ${cmake_dependency_file}.tmp to ${cmake_dependency_file}"
+  COMMAND "${CMAKE_COMMAND}" -E copy_if_different "${cmake_dependency_file}.tmp" "${cmake_dependency_file}"
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Delete the temporary file
+cuda_execute_process(
+  "Removing ${cmake_dependency_file}.tmp and ${NVCC_generated_dependency_file}"
+  COMMAND "${CMAKE_COMMAND}" -E remove "${cmake_dependency_file}.tmp" "${NVCC_generated_dependency_file}"
+  )
+
+if(CUDA_result)
+  message(FATAL_ERROR "Error generating ${generated_file}")
+endif()
+
+# Generate the code
+cuda_execute_process(
+  "Generating ${generated_file}"
+  COMMAND "${CUDA_NVCC_EXECUTABLE}"
+  "${source_file}"
+  ${cuda_language_flag}
+  ${format_flag} -o "${generated_file}"
+  ${CCBIN}
+  ${nvcc_flags}
+  ${nvcc_host_compiler_flags}
+  ${CUDA_NVCC_FLAGS}
+  -DNVCC
+  ${CUDA_NVCC_INCLUDE_ARGS}
+  )
+
+if(CUDA_result)
+  # Since nvcc can sometimes leave half done files make sure that we delete the output file.
+  cuda_execute_process(
+    "Removing ${generated_file}"
+    COMMAND "${CMAKE_COMMAND}" -E remove "${generated_file}"
+    )
+  message(FATAL_ERROR "Error generating file ${generated_file}")
+else()
+  if(verbose)
+    message("Generated ${generated_file} successfully.")
+  endif()
+endif()
+
+# Cubin resource report commands.
+if( build_cubin )
+  # Run with -cubin to produce resource usage report.
+  cuda_execute_process(
+    "Generating ${generated_cubin_file}"
+    COMMAND "${CUDA_NVCC_EXECUTABLE}"
+    "${source_file}"
+    ${CUDA_NVCC_FLAGS}
+    ${nvcc_flags}
+    ${CCBIN}
+    ${nvcc_host_compiler_flags}
+    -DNVCC
+    -cubin
+    -o "${generated_cubin_file}"
+    ${CUDA_NVCC_INCLUDE_ARGS}
+    )
+
+  # Execute the parser script.
+  cuda_execute_process(
+    "Executing the parser script"
+    COMMAND  "${CMAKE_COMMAND}"
+    -D "input_file:STRING=${generated_cubin_file}"
+    -P "${CUDA_parse_cubin}"
+    )
+
+endif()
+
+cmake_policy(POP)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bf7edd69ccd13990b24350fdf217b156343724f4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
@@ -0,0 +1,300 @@
+# Synopsis:
+#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
+#   -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
+#      target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
+#       - "Auto" detects local machine GPU compute arch at runtime.
+#       - "Common" and "All" cover common and entire subsets of architectures
+#      ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
+#      NAME: Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere
+#      NUM: Any number. Only those pairs are currently accepted by NVCC though:
+#            3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5 8.0
+#      Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
+#      Additionally, sets ${out_variable}_readable to the resulting numeric list
+#      Example:
+#       CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
+#        LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
+#
+#      More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
+#
+
+if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+  if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA"
+      AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)")
+    set(CUDA_VERSION "${CMAKE_MATCH_1}")
+  endif()
+endif()
+
+# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
+
+# This list will be used for CUDA_ARCH_NAME = All option
+set(CUDA_KNOWN_GPU_ARCHITECTURES  "Kepler" "Maxwell")
+
+# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
+set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0")
+
+# This list is used to filter CUDA archs when autodetecting
+set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0")
+
+if(CUDA_VERSION VERSION_GREATER "10.5")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")
+
+  if(CUDA_VERSION VERSION_LESS "11.1")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX")
+  endif()
+endif()
+
+if(NOT CUDA_VERSION VERSION_LESS "11.1")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
+  set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6")
+
+  if(CUDA_VERSION VERSION_LESS "11.8")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX")
+  endif()
+endif()
+
+if(NOT CUDA_VERSION VERSION_LESS "11.8")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0")
+
+  if(CUDA_VERSION VERSION_LESS "12.0")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX")
+  endif()
+endif()
+
+if(NOT CUDA_VERSION VERSION_LESS "12.0")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a")
+  list(REMOVE_ITEM CUDA_COMMON_GPU_ARCHITECTURES "3.5")
+  list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5")
+endif()
+
+if(CUDA_VERSION VERSION_GREATER "12.6")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0a")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.1a")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.1a")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0a")
+endif()
+
+
+################################################################################################
+# A function for automatic detection of GPUs installed  (if autodetection is enabled)
+# Usage:
+#   CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
+#
+function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
+  if(NOT CUDA_GPU_DETECT_OUTPUT)
+    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+      set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu")
+    else()
+      set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp")
+    endif()
+
+    file(WRITE ${file} ""
+      "#include \n"
+      "#include \n"
+      "int main()\n"
+      "{\n"
+      "  int count = 0;\n"
+      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
+      "  if (count == 0) return -1;\n"
+      "  for (int device = 0; device < count; ++device)\n"
+      "  {\n"
+      "    cudaDeviceProp prop;\n"
+      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
+      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
+      "  }\n"
+      "  return 0;\n"
+      "}\n")
+
+    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+      try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
+              RUN_OUTPUT_VARIABLE compute_capabilities)
+    else()
+      try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
+              CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
+              LINK_LIBRARIES ${CUDA_LIBRARIES}
+              RUN_OUTPUT_VARIABLE compute_capabilities)
+    endif()
+
+    # Filter unrelated content out of the output.
+    string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}")
+
+    if(run_result EQUAL 0)
+      string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}")
+      set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities}
+        CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
+    endif()
+  endif()
+
+  if(NOT CUDA_GPU_DETECT_OUTPUT)
+    message(STATUS "Automatic GPU detection failed. Building for common architectures.")
+    set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
+  else()
+    # Filter based on CUDA version supported archs
+    set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
+    separate_arguments(CUDA_GPU_DETECT_OUTPUT)
+    foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
+        if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR
+                                            ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE))
+        list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
+        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
+      else()
+        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}")
+      endif()
+    endforeach()
+
+    set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE)
+  endif()
+endfunction()
+
+
+################################################################################################
+# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
+# Usage:
+#   SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
+function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
+  set(CUDA_ARCH_LIST "${ARGN}")
+
+  if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
+    set(CUDA_ARCH_LIST "Auto")
+  endif()
+
+  set(cuda_arch_bin)
+  set(cuda_arch_ptx)
+
+  if("${CUDA_ARCH_LIST}" STREQUAL "All")
+    set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
+  elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
+    set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
+  elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
+    CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
+    message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
+  endif()
+
+  # Now process the list and look for names
+  string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
+  list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
+  foreach(arch_name ${CUDA_ARCH_LIST})
+    set(arch_bin)
+    set(arch_ptx)
+    set(add_ptx FALSE)
+    # Check to see if we are compiling PTX
+    if(arch_name MATCHES "(.*)\\+PTX$")
+      set(add_ptx TRUE)
+      set(arch_name ${CMAKE_MATCH_1})
+    endif()
+    if(arch_name MATCHES "^([0-9]+\\.[0-9][af]?(\\([0-9]+\\.[0-9]\\))?)$")
+      set(arch_bin ${CMAKE_MATCH_1})
+      set(arch_ptx ${arch_bin})
+    else()
+      # Look for it in our list of known architectures
+      if(${arch_name} STREQUAL "Kepler+Tesla")
+        set(arch_bin 3.7)
+      elseif(${arch_name} STREQUAL "Kepler")
+        set(arch_bin 3.5)
+        set(arch_ptx 3.5)
+      elseif(${arch_name} STREQUAL "Maxwell+Tegra")
+        set(arch_bin 5.3)
+      elseif(${arch_name} STREQUAL "Maxwell")
+        set(arch_bin 5.0 5.2)
+        set(arch_ptx 5.2)
+      elseif(${arch_name} STREQUAL "Pascal")
+        set(arch_bin 6.0 6.1)
+        set(arch_ptx 6.1)
+     elseif(${arch_name} STREQUAL "Volta+Tegra")
+        set(arch_bin 7.2)
+      elseif(${arch_name} STREQUAL "Volta")
+        set(arch_bin 7.0 7.0)
+        set(arch_ptx 7.0)
+      elseif(${arch_name} STREQUAL "Turing")
+        set(arch_bin 7.5)
+        set(arch_ptx 7.5)
+      elseif(${arch_name} STREQUAL "Ampere+Tegra")
+        set(arch_bin 8.7)
+      elseif(${arch_name} STREQUAL "Ampere")
+        set(arch_bin 8.0 8.6)
+        set(arch_ptx 8.0 8.6)
+      elseif(${arch_name} STREQUAL "Ada")
+        set(arch_bin 8.9)
+        set(arch_ptx 8.9)
+      elseif(${arch_name} STREQUAL "Hopper")
+        set(arch_bin 9.0)
+        set(arch_ptx 9.0)
+      elseif(${arch_name} STREQUAL "Blackwell+Tegra")
+        set(arch_bin 10.1)
+      elseif(${arch_name} STREQUAL "Blackwell")
+        set(arch_bin 10.0 12.0)
+        set(arch_ptx 10.0 12.0)
+      else()
+        message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ")
+      endif()
+    endif()
+    if(NOT arch_bin)
+      message(SEND_ERROR "arch_bin wasn't set for some reason")
+    endif()
+    list(APPEND cuda_arch_bin ${arch_bin})
+    if(add_ptx)
+      if (NOT arch_ptx)
+        set(arch_ptx ${arch_bin})
+      endif()
+      list(APPEND cuda_arch_ptx ${arch_ptx})
+    endif()
+  endforeach()
+
+  # remove dots and convert to lists
+  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
+  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
+  string(REGEX MATCHALL "[0-9()]+[af]?" cuda_arch_bin "${cuda_arch_bin}")
+  string(REGEX MATCHALL "[0-9]+[af]?"   cuda_arch_ptx "${cuda_arch_ptx}")
+
+  if(cuda_arch_bin)
+    list(REMOVE_DUPLICATES cuda_arch_bin)
+  endif()
+  if(cuda_arch_ptx)
+    list(REMOVE_DUPLICATES cuda_arch_ptx)
+  endif()
+
+  set(nvcc_flags "")
+  set(nvcc_archs_readable "")
+
+  # Tell NVCC to add binaries for the specified GPUs
+  foreach(arch ${cuda_arch_bin})
+    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
+      # User explicitly specified ARCH for the concrete CODE
+      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
+      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
+    else()
+      # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE
+      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
+      list(APPEND nvcc_archs_readable sm_${arch})
+    endif()
+  endforeach()
+
+  # Tell NVCC to add PTX intermediate code for the specified architectures
+  foreach(arch ${cuda_arch_ptx})
+    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
+    list(APPEND nvcc_archs_readable compute_${arch})
+  endforeach()
+
+  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
+  set(${out_variable}          ${nvcc_flags}          PARENT_SCOPE)
+  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
+endfunction()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..6821cee4f77a9d84c74f2c140870a2163ae5a5f0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake
@@ -0,0 +1,47 @@
+# Distributed under the OSI-approved BSD 3-Clause License.  See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+
+#.rst:
+# FindPackageMessage
+# ------------------
+#
+#
+#
+# FIND_PACKAGE_MESSAGE( "message for user" "find result details")
+#
+# This macro is intended to be used in FindXXX.cmake modules files.  It
+# will print a message once for each unique find result.  This is useful
+# for telling the user where a package was found.  The first argument
+# specifies the name (XXX) of the package.  The second argument
+# specifies the message to display.  The third argument lists details
+# about the find result so that if they change the message will be
+# displayed again.  The macro also obeys the QUIET argument to the
+# find_package command.
+#
+# Example:
+#
+# ::
+#
+#   if(X11_FOUND)
+#     FIND_PACKAGE_MESSAGE(X11 "Found X11: ${X11_X11_LIB}"
+#       "[${X11_X11_LIB}][${X11_INCLUDE_DIR}]")
+#   else()
+#    ...
+#   endif()
+
+function(FIND_PACKAGE_MESSAGE pkg msg details)
+  # Avoid printing a message repeatedly for the same find result.
+  if(NOT ${pkg}_FIND_QUIETLY)
+    string(REPLACE "\n" "" details "${details}")
+    set(DETAILS_VAR FIND_PACKAGE_MESSAGE_DETAILS_${pkg})
+    if(NOT "${details}" STREQUAL "${${DETAILS_VAR}}")
+      # The message has not yet been printed.
+      message(STATUS "${msg}")
+
+      # Save the find details in the cache to avoid printing the same
+      # message again.
+      set("${DETAILS_VAR}" "${details}"
+        CACHE INTERNAL "Details about finding ${pkg}")
+    endif()
+  endif()
+endfunction()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..7ecaff5109f42efb336b30a6ef0ad429a30051d3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/LoadHIP.cmake
@@ -0,0 +1,257 @@
+set(PYTORCH_FOUND_HIP FALSE)
+
+# If ROCM_PATH is set, assume intention is to compile with
+# ROCm support and error out if the ROCM_PATH does not exist.
+# Else ROCM_PATH does not exist, assume a default of /opt/rocm
+# In the latter case, if /opt/rocm does not exist emit status
+# message and return.
+if(DEFINED ENV{ROCM_PATH})
+  file(TO_CMAKE_PATH "$ENV{ROCM_PATH}" ROCM_PATH)
+  if(NOT EXISTS ${ROCM_PATH})
+    message(FATAL_ERROR
+      "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n"
+      "Set a valid ROCM_PATH or unset ROCM_PATH environment variable to fix.")
+  endif()
+else()
+  if(UNIX)
+    set(ROCM_PATH /opt/rocm)
+  else() # Win32
+    set(ROCM_PATH C:/opt/rocm)
+  endif()
+  if(NOT EXISTS ${ROCM_PATH})
+    message(STATUS
+        "ROCM_PATH environment variable is not set and ${ROCM_PATH} does not exist.\n"
+        "Building without ROCm support.")
+    return()
+  endif()
+endif()
+
+# MAGMA_HOME
+if(NOT DEFINED ENV{MAGMA_HOME})
+  set(MAGMA_HOME ${ROCM_PATH}/magma)
+  set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma)
+else()
+  file(TO_CMAKE_PATH "$ENV{MAGMA_HOME}" MAGMA_HOME)
+endif()
+
+# MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different
+# installation directory.
+if(WIN32)
+  if(NOT DEFINED ENV{MIOPEN_PATH})
+    set(miopen_DIR C:/opt/miopen/lib/cmake/miopen)
+  else()
+    set(miopen_DIR $ENV{MIOPEN_PATH}/lib/cmake/miopen)
+  endif()
+endif()
+
+torch_hip_get_arch_list(PYTORCH_ROCM_ARCH)
+if(PYTORCH_ROCM_ARCH STREQUAL "")
+  message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.")
+endif()
+message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")
+
+# Add HIP to the CMAKE Module Path
+# needed because the find_package call to this module uses the Module mode search
+# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes
+if(UNIX)
+  set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})
+else() # Win32
+  set(CMAKE_MODULE_PATH ${ROCM_PATH}/cmake/ ${CMAKE_MODULE_PATH})
+endif()
+
+# Add ROCM_PATH to CMAKE_PREFIX_PATH, needed because the find_package
+# call to individual ROCM components uses the Config mode search
+list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+
+macro(find_package_and_print_version PACKAGE_NAME)
+  find_package("${PACKAGE_NAME}" ${ARGN})
+  if(NOT ${PACKAGE_NAME}_FOUND)
+    message("Optional package ${PACKAGE_NAME} not found")
+  else()
+    message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
+    if(${PACKAGE_NAME}_INCLUDE_DIR)
+      list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
+    endif()
+  endif()
+endmacro()
+
+# Find the HIP Package
+# MODULE argument is added for clarity that CMake is searching
+# for FindHIP.cmake in Module mode
+find_package_and_print_version(HIP 1.0 MODULE)
+
+if(HIP_FOUND)
+  set(PYTORCH_FOUND_HIP TRUE)
+  find_package_and_print_version(hip REQUIRED CONFIG)
+  if(HIP_VERSION)
+    # Check if HIP_VERSION contains a dash (e.g., "7.1.25421-32f9fa6ca5")
+    # and strip everything after it to get clean numeric version
+    string(FIND "${HIP_VERSION}" "-" DASH_POS)
+    if(NOT DASH_POS EQUAL -1)
+      string(SUBSTRING "${HIP_VERSION}" 0 ${DASH_POS} HIP_VERSION_CLEAN)
+      set(HIP_VERSION "${HIP_VERSION_CLEAN}")
+  endif()
+  message("HIP version: ${HIP_VERSION}")
+endif()
+
+# The rocm-core package was only introduced in ROCm 6.4, so we make it optional.
+  find_package(rocm-core CONFIG)
+
+  # Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow
+  # falling back to the hip version, which everyone should have.
+  # rocm_version.h lives in the rocm-core package and hip_version.h lives in the
+  # hip (lower-case) package. Both are probed above and will be in
+  # ROCM_INCLUDE_DIRS if available.
+  find_file(ROCM_VERSION_HEADER_PATH
+    NAMES rocm-core/rocm_version.h hip/hip_version.h
+    NO_DEFAULT_PATH
+    PATHS ${ROCM_INCLUDE_DIRS}
+  )
+  if(ROCM_VERSION_HEADER_PATH MATCHES "rocm-core/rocm_version.h$")
+    set(ROCM_LIB_NAME "ROCM")
+  else()
+    set(ROCM_LIB_NAME "HIP")
+  endif()
+
+  if(NOT ROCM_VERSION_HEADER_PATH)
+    message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}")
+  endif()
+  get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)
+
+  if(EXISTS ${ROCM_VERSION_HEADER_PATH})
+    set(ROCM_HEADER_FILE ${ROCM_VERSION_HEADER_PATH})
+  else()
+    message(FATAL_ERROR "********************* ${ROCM_HEADER_NAME} could not be found ******************\n")
+  endif()
+
+  # Read the ROCM headerfile into a variable
+  message(STATUS "Reading ROCM version from: ${ROCM_HEADER_FILE}")
+  message(STATUS "Content: ${ROCM_HEADER_CONTENT}")
+  file(READ "${ROCM_HEADER_FILE}" ROCM_HEADER_CONTENT)
+
+  # Below we use a RegEx to find ROCM version numbers.
+  # Note that CMake does not support \s for blank space. That is
+  # why in the regular expressions below we have a blank space in
+  # the square brackets.
+  # There are three steps:
+  # 1. Match regular expression
+  # 2. Strip the non-numerical part of the string
+  # 3. Strip leading and trailing spaces
+
+  string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
+  string(REPLACE "${ROCM_LIB_NAME}_VERSION_MAJOR" "" TEMP2 ${TEMP1})
+  string(STRIP ${TEMP2} ROCM_VERSION_DEV_MAJOR)
+  string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
+  string(REPLACE "${ROCM_LIB_NAME}_VERSION_MINOR" "" TEMP2 ${TEMP1})
+  string(STRIP ${TEMP2} ROCM_VERSION_DEV_MINOR)
+  string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
+  string(REPLACE "${ROCM_LIB_NAME}_VERSION_PATCH" "" TEMP2 ${TEMP1})
+  string(STRIP ${TEMP2} ROCM_VERSION_DEV_PATCH)
+
+  # Create ROCM_VERSION_DEV_INT which is later used as a preprocessor macros
+  set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
+  math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
+
+  message("\n***** ROCm version from ${ROCM_HEADER_NAME} ****\n")
+  message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
+  message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
+  message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
+  message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
+  message("ROCM_VERSION_DEV_INT:   ${ROCM_VERSION_DEV_INT}")
+
+  math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
+  message("HIP_VERSION_MAJOR: ${HIP_VERSION_MAJOR}")
+  message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}")
+  message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}")
+
+  # Find ROCM components using Config mode
+  # These components will be searced for recursively in ${ROCM_PATH}
+  message("\n***** Library versions from cmake find_package *****\n")
+  find_package_and_print_version(amd_comgr REQUIRED)
+  find_package_and_print_version(rocrand REQUIRED)
+  find_package_and_print_version(hiprand REQUIRED)
+  find_package_and_print_version(rocblas REQUIRED)
+  find_package_and_print_version(hipblas REQUIRED)
+  find_package_and_print_version(miopen REQUIRED)
+  find_package_and_print_version(hipfft REQUIRED)
+  find_package_and_print_version(hipsparse REQUIRED)
+  find_package_and_print_version(rocprim REQUIRED)
+  find_package_and_print_version(hipcub REQUIRED)
+  find_package_and_print_version(rocthrust REQUIRED)
+  find_package_and_print_version(hipsolver REQUIRED)
+  find_package_and_print_version(rocsolver REQUIRED)
+  # workaround cmake 4 build issue
+  if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0")
+    message(WARNING "Work around hiprtc cmake failure for cmake >= 4")
+    set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
+    find_package_and_print_version(hiprtc REQUIRED)
+    unset(CMAKE_POLICY_VERSION_MINIMUM)
+  else()
+    find_package_and_print_version(hiprtc REQUIRED)
+  endif()
+  find_package_and_print_version(hipblaslt REQUIRED)
+
+  if(UNIX)
+    find_package_and_print_version(rccl)
+    find_package_and_print_version(hsa-runtime64 REQUIRED)
+  endif()
+
+  # Optional components.
+  find_package_and_print_version(hipsparselt)  # Will be required when ready.
+
+  list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)
+
+  if(UNIX)
+    # roctx is part of roctracer
+    find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
+
+    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
+
+    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+      # check whether hipblaslt provides HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
+      set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_outer_vec.cc")
+      file(WRITE ${file} ""
+        "#define LEGACY_HIPBLAS_DIRECT\n"
+        "#include \n"
+        "int main() {\n"
+        "    hipblasLtMatmulMatrixScale_t attr = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;\n"
+        "    return 0;\n"
+        "}\n"
+        )
+      try_compile(hipblaslt_compile_result_outer_vec ${PROJECT_RANDOM_BINARY_DIR} ${file}
+        CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+        COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+        OUTPUT_VARIABLE hipblaslt_compile_output_outer_vec)
+
+      # check whether hipblaslt provides HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT
+      set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_vec_ext.cc")
+      file(WRITE ${file} ""
+        "#define LEGACY_HIPBLAS_DIRECT\n"
+        "#include \n"
+        "int main() {\n"
+        "    hipblasLtMatmulDescAttributes_t attr = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;\n"
+        "    return 0;\n"
+        "}\n"
+        )
+      try_compile(hipblaslt_compile_result_vec_ext ${PROJECT_RANDOM_BINARY_DIR} ${file}
+        CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
+        COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
+        OUTPUT_VARIABLE hipblaslt_compile_output_vec_ext)
+
+      if(hipblaslt_compile_result_outer_vec)
+        set(HIPBLASLT_OUTER_VEC ON)
+        set(HIPBLASLT_VEC_EXT OFF)
+        message("hipblaslt is using scale pointer outer vec")
+      elseif(hipblaslt_compile_result_vec_ext)
+        set(HIPBLASLT_OUTER_VEC OFF)
+        set(HIPBLASLT_VEC_EXT ON)
+        message("hipblaslt is using scale pointer vec ext")
+      else()
+        set(HIPBLASLT_OUTER_VEC OFF)
+        set(HIPBLASLT_VEC_EXT OFF)
+        message("hipblaslt is NOT using scale pointer outer vec: ${hipblaslt_compile_output_outer_vec}")
+        message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output_vec_ext}")
+      endif()
+    endif()
+  endif()
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bc8855d23e61fbbe5979beae22ab6086a388ba1f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake
@@ -0,0 +1,391 @@
+# ---[ cuda
+
+# Poor man's include guard
+if(TARGET torch::cudart)
+  return()
+endif()
+
+# sccache is only supported in CMake master and not in the newest official
+# release (3.11.3) yet. Hence we need our own Modules_CUDA_fix to enable sccache.
+list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/../Modules_CUDA_fix)
+
+# We don't want to statically link cudart, because we rely on it's dynamic linkage in
+# python (follow along torch/cuda/__init__.py and usage of cudaGetErrorName).
+# Technically, we can link cudart here statically, and link libtorch_python.so
+# to a dynamic libcudart.so, but that's just wasteful.
+# However, on Windows, if this one gets switched off, the error "cuda: unknown error"
+# will be raised when running the following code:
+# >>> import torch
+# >>> torch.cuda.is_available()
+# >>> torch.cuda.current_device()
+# More details can be found in the following links.
+# https://github.com/pytorch/pytorch/issues/20635
+# https://github.com/pytorch/pytorch/issues/17108
+if(NOT MSVC)
+  set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "")
+endif()
+
+# Find CUDA.
+find_package(CUDA)
+if(NOT CUDA_FOUND)
+  # If user explicitly set USE_CUDA=1, error out instead of falling back
+  if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA)
+    message(FATAL_ERROR
+      "PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. "
+      "Please check your CUDA installation, ensure CUDA toolkit is installed, "
+      "and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. "
+      "If you want to build without CUDA, please set USE_CUDA=0.")
+  endif()
+
+  message(WARNING
+    "PyTorch: CUDA cannot be found. Depending on whether you are building "
+    "PyTorch or a PyTorch dependent library, the next warning / error will "
+    "give you more info.")
+  set(CAFFE2_USE_CUDA OFF)
+  return()
+endif()
+
+# Enable CUDA language support
+set(CUDAToolkit_ROOT "${CUDA_TOOLKIT_ROOT_DIR}")
+# Pass clang as host compiler, which according to the docs
+# Must be done before CUDA language is enabled, see
+# https://cmake.org/cmake/help/v3.15/variable/CMAKE_CUDA_HOST_COMPILER.html
+if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
+  set(CMAKE_CUDA_HOST_COMPILER "${CMAKE_CXX_COMPILER}")
+endif()
+enable_language(CUDA)
+if("X${CMAKE_CUDA_STANDARD}" STREQUAL "X" )
+  set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})
+endif()
+set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+
+# CMP0074 - find_package will respect _ROOT variables
+cmake_policy(PUSH)
+if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.12.0)
+  cmake_policy(SET CMP0074 NEW)
+endif()
+
+find_package(CUDAToolkit REQUIRED)
+
+cmake_policy(POP)
+
+if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_EQUAL CUDAToolkit_VERSION)
+  message(FATAL_ERROR "Found two conflicting CUDA versions:\n"
+                      "V${CMAKE_CUDA_COMPILER_VERSION} in '${CUDA_INCLUDE_DIRS}' and\n"
+                      "V${CUDAToolkit_VERSION} in '${CUDAToolkit_INCLUDE_DIRS}'")
+endif()
+
+message(STATUS "PyTorch: CUDA detected: " ${CUDA_VERSION})
+message(STATUS "PyTorch: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE})
+message(STATUS "PyTorch: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR})
+if(CUDA_VERSION VERSION_LESS 12.0)
+  message(FATAL_ERROR "PyTorch requires CUDA 12.0 or above.")
+endif()
+
+if(CUDA_FOUND)
+  # Sometimes, we may mismatch nvcc with the CUDA headers we are
+  # compiling with, e.g., if a ccache nvcc is fed to us by CUDA_NVCC_EXECUTABLE
+  # but the PATH is not consistent with CUDA_HOME.  It's better safe
+  # than sorry: make sure everything is consistent.
+  if(MSVC AND CMAKE_GENERATOR MATCHES "Visual Studio")
+    # When using Visual Studio, it attempts to lock the whole binary dir when
+    # `try_run` is called, which will cause the build to fail.
+    string(RANDOM BUILD_SUFFIX)
+    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}/${BUILD_SUFFIX}")
+  else()
+    set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
+  endif()
+  set(file "${PROJECT_BINARY_DIR}/detect_cuda_version.cc")
+  file(WRITE ${file} ""
+    "#include \n"
+    "#include \n"
+    "int main() {\n"
+    "  printf(\"%d.%d\", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);\n"
+    "  return 0;\n"
+    "}\n"
+    )
+  if(NOT CMAKE_CROSSCOMPILING)
+    try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
+      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
+      LINK_LIBRARIES ${CUDA_LIBRARIES}
+      RUN_OUTPUT_VARIABLE cuda_version_from_header
+      COMPILE_OUTPUT_VARIABLE output_var
+      )
+    if(NOT compile_result)
+      message(FATAL_ERROR "PyTorch: Couldn't determine version from header: " ${output_var})
+    endif()
+    message(STATUS "PyTorch: Header version is: " ${cuda_version_from_header})
+    if(NOT cuda_version_from_header STREQUAL ${CUDA_VERSION_STRING})
+      # Force CUDA to be processed for again next time
+      # TODO: I'm not sure if this counts as an implementation detail of
+      # FindCUDA
+      set(cuda_version_from_findcuda ${CUDA_VERSION_STRING})
+      unset(CUDA_TOOLKIT_ROOT_DIR_INTERNAL CACHE)
+      # Not strictly necessary, but for good luck.
+      unset(CUDA_VERSION CACHE)
+      # Error out
+      message(FATAL_ERROR "FindCUDA says CUDA version is ${cuda_version_from_findcuda} (usually determined by nvcc), "
+        "but the CUDA headers say the version is ${cuda_version_from_header}.  This often occurs "
+        "when you set both CUDA_HOME and CUDA_NVCC_EXECUTABLE to "
+        "non-standard locations, without also setting PATH to point to the correct nvcc.  "
+        "Perhaps, try re-running this command again with PATH=${CUDA_TOOLKIT_ROOT_DIR}/bin:$PATH.  "
+        "See above log messages for more diagnostics, and see https://github.com/pytorch/pytorch/issues/8092 for more details.")
+    endif()
+  endif()
+endif()
+
+# ---[ CUDA libraries wrapper
+
+# find lbnvrtc.so
+set(CUDA_NVRTC_LIB "${CUDA_nvrtc_LIBRARY}" CACHE FILEPATH "")
+if(CUDA_NVRTC_LIB AND NOT CUDA_NVRTC_SHORTHASH)
+  find_package(Python COMPONENTS Interpreter)
+  execute_process(
+    COMMAND Python::Interpreter -c
+    "import hashlib;hash=hashlib.sha256();hash.update(open('${CUDA_NVRTC_LIB}','rb').read());print(hash.hexdigest()[:8])"
+    RESULT_VARIABLE _retval
+    OUTPUT_VARIABLE CUDA_NVRTC_SHORTHASH)
+  if(NOT _retval EQUAL 0)
+    message(WARNING "Failed to compute shorthash for libnvrtc.so")
+    set(CUDA_NVRTC_SHORTHASH "XXXXXXXX")
+  else()
+    string(STRIP "${CUDA_NVRTC_SHORTHASH}" CUDA_NVRTC_SHORTHASH)
+    message(STATUS "${CUDA_NVRTC_LIB} shorthash is ${CUDA_NVRTC_SHORTHASH}")
+  endif()
+endif()
+
+# Create new style imported libraries.
+# Several of these libraries have a hardcoded path if CAFFE2_STATIC_LINK_CUDA
+# is set. This path is where sane CUDA installations have their static
+# libraries installed. This flag should only be used for binary builds, so
+# end-users should never have this flag set.
+
+# cuda
+add_library(caffe2::cuda INTERFACE IMPORTED)
+set_property(
+    TARGET caffe2::cuda PROPERTY INTERFACE_LINK_LIBRARIES
+    CUDA::cuda_driver)
+
+# cudart
+add_library(torch::cudart INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA)
+    set_property(
+        TARGET torch::cudart PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cudart_static)
+else()
+    set_property(
+        TARGET torch::cudart PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cudart)
+endif()
+
+
+# cublas
+add_library(caffe2::cublas INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+    set_property(
+        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
+        # NOTE: cublas is always linked dynamically
+        CUDA::cublas CUDA::cublasLt)
+    set_property(
+        TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cudart_static rt)
+else()
+    set_property(
+        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cublas CUDA::cublasLt)
+endif()
+
+# cudnn interface
+# static linking is handled by USE_STATIC_CUDNN environment variable
+if(CAFFE2_USE_CUDNN)
+  if(USE_STATIC_CUDNN)
+    set(CUDNN_STATIC ON CACHE BOOL "")
+  else()
+    set(CUDNN_STATIC OFF CACHE BOOL "")
+  endif()
+
+  find_package(CUDNN)
+
+  if(NOT CUDNN_FOUND)
+    message(WARNING
+      "Cannot find cuDNN library. Turning the option off")
+    set(CAFFE2_USE_CUDNN OFF)
+  else()
+    if(CUDNN_VERSION VERSION_LESS "8.1.0")
+      message(FATAL_ERROR "PyTorch requires cuDNN 8.1 and above.")
+    endif()
+  endif()
+
+  add_library(torch::cudnn INTERFACE IMPORTED)
+  target_include_directories(torch::cudnn INTERFACE ${CUDNN_INCLUDE_PATH})
+  if(CUDNN_STATIC AND NOT WIN32)
+    target_link_options(torch::cudnn INTERFACE
+        "-Wl,--exclude-libs,libcudnn_static.a")
+  else()
+    target_link_libraries(torch::cudnn INTERFACE ${CUDNN_LIBRARY_PATH})
+  endif()
+else()
+  message(STATUS "USE_CUDNN is set to 0. Compiling without cuDNN support")
+endif()
+
+if(CAFFE2_USE_CUSPARSELT)
+  find_package(CUSPARSELT)
+
+  if(NOT CUSPARSELT_FOUND)
+    message(WARNING
+      "Cannot find cuSPARSELt library. Turning the option off")
+    set(CAFFE2_USE_CUSPARSELT OFF)
+  else()
+    add_library(torch::cusparselt INTERFACE IMPORTED)
+    target_include_directories(torch::cusparselt INTERFACE ${CUSPARSELT_INCLUDE_PATH})
+    target_link_libraries(torch::cusparselt INTERFACE ${CUSPARSELT_LIBRARY_PATH})
+  endif()
+else()
+  message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
+endif()
+
+if(USE_CUDSS)
+  find_package(CUDSS)
+
+  if(NOT CUDSS_FOUND)
+    message(WARNING
+      "Cannot find CUDSS library. Turning the option off")
+    set(USE_CUDSS OFF)
+  else()
+    add_library(torch::cudss INTERFACE IMPORTED)
+    target_include_directories(torch::cudss INTERFACE ${CUDSS_INCLUDE_PATH})
+    target_link_libraries(torch::cudss INTERFACE ${CUDSS_LIBRARY_PATH})
+  endif()
+else()
+  message(STATUS "USE_CUDSS is set to 0. Compiling without cuDSS support")
+endif()
+
+# cufile
+if(CAFFE2_USE_CUFILE)
+  add_library(torch::cufile INTERFACE IMPORTED)
+  if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+      set_property(
+          TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cuFile_static)
+  else()
+      set_property(
+          TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cuFile)
+  endif()
+else()
+  message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
+endif()
+
+# curand
+add_library(caffe2::curand INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+    set_property(
+        TARGET caffe2::curand PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::curand_static)
+else()
+    set_property(
+        TARGET caffe2::curand PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::curand)
+endif()
+
+# cufft
+add_library(caffe2::cufft INTERFACE IMPORTED)
+if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
+    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
+      set_property(
+          TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cufft_static_nocallback)
+    else()
+      set_property(
+          TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES
+          CUDA::cufft_static)
+    endif()
+else()
+    set_property(
+        TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES
+        CUDA::cufft)
+endif()
+
+# nvrtc
+add_library(caffe2::nvrtc INTERFACE IMPORTED)
+set_property(
+    TARGET caffe2::nvrtc PROPERTY INTERFACE_LINK_LIBRARIES
+    CUDA::nvrtc caffe2::cuda)
+
+# Add onnx namespace definition to nvcc
+if(ONNX_NAMESPACE)
+  list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=${ONNX_NAMESPACE}")
+else()
+  list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=onnx_c2")
+endif()
+
+# Don't activate VC env again for Ninja generators with MSVC on Windows if CUDAHOSTCXX is not defined
+# by adding --use-local-env.
+if(MSVC AND CMAKE_GENERATOR STREQUAL "Ninja" AND NOT DEFINED ENV{CUDAHOSTCXX})
+  list(APPEND CUDA_NVCC_FLAGS "--use-local-env")
+endif()
+
+# setting nvcc arch flags
+torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA)
+# CMake 3.18 adds integrated support for architecture selection, but we can't rely on it
+if(DEFINED CMAKE_CUDA_ARCHITECTURES)
+  message(WARNING
+          "pytorch is not compatible with `CMAKE_CUDA_ARCHITECTURES` and will ignore its value. "
+          "Please configure `TORCH_CUDA_ARCH_LIST` instead.")
+  set(CMAKE_CUDA_ARCHITECTURES OFF)
+endif()
+
+list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
+message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA}")
+
+# disable some nvcc diagnostic that appears in boost, glog, glags, opencv, etc.
+foreach(diag cc_clobber_ignored
+             field_without_dll_interface
+             base_class_has_different_dll_interface
+             dll_interface_conflict_none_assumed
+             dll_interface_conflict_dllexport_assumed
+             bad_friend_decl)
+  list(APPEND SUPPRESS_WARNING_FLAGS --diag_suppress=${diag})
+endforeach()
+string(REPLACE ";" "," SUPPRESS_WARNING_FLAGS "${SUPPRESS_WARNING_FLAGS}")
+list(APPEND CUDA_NVCC_FLAGS -Xcudafe ${SUPPRESS_WARNING_FLAGS})
+
+set(CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Werror")
+if(MSVC)
+  list(APPEND CUDA_NVCC_FLAGS "--Werror" "cross-execution-space-call")
+  list(APPEND CUDA_NVCC_FLAGS "--no-host-device-move-forward")
+endif()
+
+# Debug and Release symbol support
+if(MSVC)
+  if(${CAFFE2_USE_MSVC_STATIC_RUNTIME})
+    string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MTd")
+    string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MT")
+    string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MT")
+    string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MT")
+  else()
+    string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -Xcompiler /MDd")
+    string(APPEND CMAKE_CUDA_FLAGS_MINSIZEREL " -Xcompiler /MD")
+    string(APPEND CMAKE_CUDA_FLAGS_RELEASE " -Xcompiler /MD")
+    string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -Xcompiler /MD")
+  endif()
+  if(CUDA_NVCC_FLAGS MATCHES "Zi")
+    list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "-FS")
+  endif()
+elseif(CUDA_DEVICE_DEBUG)
+  list(APPEND CUDA_NVCC_FLAGS "-g" "-G")  # -G enables device code debugging symbols
+endif()
+
+# Set expt-relaxed-constexpr to suppress Eigen warnings
+list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
+
+# Set expt-extended-lambda to support lambda on device
+list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
+
+foreach(FLAG ${CUDA_NVCC_FLAGS})
+  string(FIND "${FLAG}" " " flag_space_position)
+  if(NOT flag_space_position EQUAL -1)
+    message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'")
+  endif()
+  string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}")
+endforeach()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..186cda1a909ab79431114d1c61de895069255389
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/gflags.cmake
@@ -0,0 +1,83 @@
+# ---[ gflags
+
+# We will try to use the config mode first, and then manual find.
+find_package(gflags CONFIG QUIET)
+if(NOT TARGET gflags)
+  find_package(gflags MODULE QUIET)
+endif()
+
+if(TARGET gflags)
+  message(STATUS "Caffe2: Found gflags with new-style gflags target.")
+elseif(GFLAGS_FOUND)
+  message(STATUS "Caffe2: Found gflags with old-style gflag starget.")
+  add_library(gflags UNKNOWN IMPORTED)
+  set_property(
+      TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARY})
+  set_property(
+      TARGET gflags PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+      ${GFLAGS_INCLUDE_DIR})
+else()
+  message(STATUS
+      "Caffe2: Cannot find gflags automatically. Using legacy find.")
+
+  # - Try to find GFLAGS in the legacy way.
+  #
+  # The following variables are optionally searched for defaults
+  #  GFLAGS_ROOT_DIR: Base directory where all GFLAGS components are found
+  #
+  # The following are set after configuration is done:
+  #  GFLAGS_FOUND
+  #  GFLAGS_INCLUDE_DIRS
+  #  GFLAGS_LIBRARIES
+  #  GFLAGS_LIBRARYRARY_DIRS
+  include(FindPackageHandleStandardArgs)
+  set(GFLAGS_ROOT_DIR "" CACHE PATH "Folder contains Gflags")
+
+  # We are testing only a couple of files in the include directories
+  if(WIN32)
+    find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h
+        PATHS ${GFLAGS_ROOT_DIR}/src/windows)
+  else()
+    find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h
+        PATHS ${GFLAGS_ROOT_DIR})
+  endif()
+
+  if(WIN32)
+    find_library(GFLAGS_LIBRARY_RELEASE
+        NAMES libgflags
+        PATHS ${GFLAGS_ROOT_DIR}
+        PATH_SUFFIXES Release)
+
+    find_library(GFLAGS_LIBRARY_DEBUG
+        NAMES libgflags-debug
+        PATHS ${GFLAGS_ROOT_DIR}
+        PATH_SUFFIXES Debug)
+    set(GFLAGS_LIBRARY optimized ${GFLAGS_LIBRARY_RELEASE} debug ${GFLAGS_LIBRARY_DEBUG})
+  else()
+    find_library(GFLAGS_LIBRARY gflags)
+  endif()
+
+  find_package_handle_standard_args(
+      gflags DEFAULT_MSG GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY)
+
+  if(GFLAGS_FOUND)
+    message(
+        STATUS
+        "Caffe2: Found gflags  (include: ${GFLAGS_INCLUDE_DIR}, "
+        "library: ${GFLAGS_LIBRARY})")
+    add_library(gflags UNKNOWN IMPORTED)
+    set_property(
+        TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARY})
+    set_property(
+        TARGET gflags PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+        ${GFLAGS_INCLUDE_DIR})
+  endif()
+endif()
+
+# After above, we should have the gflags target now.
+if(NOT TARGET gflags)
+  message(WARNING
+      "Caffe2: gflags cannot be found. Depending on whether you are building "
+      "Caffe2 or a Caffe2 dependent library, the next warning / error will "
+      "give you more info.")
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/glog.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/glog.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bb03e81f29e3afed43ba95260cc5c298be881f72
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/glog.cmake
@@ -0,0 +1,70 @@
+# ---[ glog
+
+# We will try to use the config mode first, and then manual find.
+find_package(glog CONFIG QUIET)
+if(NOT TARGET glog::glog)
+  find_package(glog MODULE QUIET)
+endif()
+
+if(TARGET glog::glog)
+  message(STATUS "Caffe2: Found glog with new-style glog target.")
+elseif(GLOG_FOUND)
+  message(
+      STATUS
+      "Caffe2: Found glog with old-style glog starget. Glog never shipped "
+      "old style glog targets, so somewhere in your cmake path there might "
+      "be a custom Findglog.cmake file that got triggered. We will make a "
+      "best effort to create the new style glog target for you.")
+  add_library(glog::glog UNKNOWN IMPORTED)
+  set_property(
+      TARGET glog::glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARY})
+  set_property(
+      TARGET glog::glog PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+      ${GLOG_INCLUDE_DIR})
+else()
+  message(STATUS "Caffe2: Cannot find glog automatically. Using legacy find.")
+
+  # - Try to find Glog
+  #
+  # The following variables are optionally searched for defaults
+  #  GLOG_ROOT_DIR: Base directory where all GLOG components are found
+  #
+  # The following are set after configuration is done:
+  #  GLOG_FOUND
+  #  GLOG_INCLUDE_DIRS
+  #  GLOG_LIBRARIES
+  #  GLOG_LIBRARYRARY_DIRS
+
+  include(FindPackageHandleStandardArgs)
+  set(GLOG_ROOT_DIR "" CACHE PATH "Folder contains Google glog")
+  if(NOT WIN32)
+      find_path(GLOG_INCLUDE_DIR glog/logging.h
+          PATHS ${GLOG_ROOT_DIR})
+  endif()
+
+  find_library(GLOG_LIBRARY glog
+      PATHS ${GLOG_ROOT_DIR}
+      PATH_SUFFIXES lib lib64)
+
+  find_package_handle_standard_args(glog DEFAULT_MSG GLOG_INCLUDE_DIR GLOG_LIBRARY)
+
+  if(GLOG_FOUND)
+    message(STATUS
+        "Caffe2: Found glog (include: ${GLOG_INCLUDE_DIR}, "
+        "library: ${GLOG_LIBRARY})")
+    add_library(glog::glog UNKNOWN IMPORTED)
+    set_property(
+        TARGET glog::glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARY})
+    set_property(
+        TARGET glog::glog PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+        ${GLOG_INCLUDE_DIR})
+  endif()
+endif()
+
+# After above, we should have the glog::glog target now.
+if(NOT TARGET glog::glog)
+  message(WARNING
+      "Caffe2: glog cannot be found. Depending on whether you are building "
+      "Caffe2 or a Caffe2 dependent library, the next warning / error will "
+      "give you more info.")
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..2f6d1fd905aa303cc240b058318acdfb2483e9ad
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkl.cmake
@@ -0,0 +1,40 @@
+find_package(MKL QUIET)
+
+if(TARGET caffe2::mkl)
+  return()
+endif()
+
+add_library(caffe2::mkl INTERFACE IMPORTED)
+target_include_directories(caffe2::mkl INTERFACE ${MKL_INCLUDE_DIR})
+target_link_libraries(caffe2::mkl INTERFACE ${MKL_LIBRARIES})
+foreach(MKL_LIB IN LISTS MKL_LIBRARIES)
+  if(EXISTS "${MKL_LIB}")
+    get_filename_component(MKL_LINK_DIR "${MKL_LIB}" DIRECTORY)
+    if(IS_DIRECTORY "${MKL_LINK_DIR}")
+      target_link_directories(caffe2::mkl INTERFACE "${MKL_LINK_DIR}")
+    endif()
+  endif()
+endforeach()
+
+# TODO: This is a hack, it will not pick up architecture dependent
+# MKL libraries correctly; see https://github.com/pytorch/pytorch/issues/73008
+set_property(
+  TARGET caffe2::mkl PROPERTY INTERFACE_LINK_DIRECTORIES
+  ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64 ${MKL_ROOT}/lib/intel64_win ${MKL_ROOT}/lib/win-x64)
+
+if(UNIX)
+  if(USE_STATIC_MKL)
+    foreach(MKL_LIB_PATH IN LISTS MKL_LIBRARIES)
+      if(NOT EXISTS "${MKL_LIB_PATH}")
+        continue()
+      endif()
+
+      get_filename_component(MKL_LIB_NAME "${MKL_LIB_PATH}" NAME)
+
+      # Match archive libraries starting with "libmkl_"
+      if(MKL_LIB_NAME MATCHES "^libmkl_" AND MKL_LIB_NAME MATCHES ".a$")
+        target_link_options(caffe2::mkl INTERFACE "-Wl,--exclude-libs,${MKL_LIB_NAME}")
+      endif()
+    endforeach()
+  endif()
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..87935625f9bfb543d1cdc7f2b59f11e8d4a709e7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/mkldnn.cmake
@@ -0,0 +1,18 @@
+set(MKLDNN_USE_NATIVE_ARCH ${USE_NATIVE_ARCH})
+
+if(CPU_AARCH64)
+  include(${CMAKE_CURRENT_LIST_DIR}/ComputeLibrary.cmake)
+endif()
+
+find_package(MKLDNN QUIET)
+
+if(NOT TARGET caffe2::mkldnn)
+  add_library(caffe2::mkldnn INTERFACE IMPORTED)
+endif()
+
+set_property(
+  TARGET caffe2::mkldnn PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+  ${MKLDNN_INCLUDE_DIR})
+set_property(
+  TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES
+  ${MKLDNN_LIBRARIES})
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..77ec3622b132dc7a7817716dd24ef986e6ac030d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/protobuf.cmake
@@ -0,0 +1,92 @@
+# ---[ Protobuf
+
+# We will try to use the config mode first, and then manual find.
+find_package(Protobuf CONFIG QUIET)
+if(NOT Protobuf_FOUND)
+  find_package(Protobuf MODULE QUIET)
+endif()
+
+if((TARGET protobuf::libprotobuf OR TARGET protobuf::libprotobuf-lite) AND TARGET protobuf::protoc)
+  # Hooray. This is the most ideal situation, meaning that you either have a
+  # Protobuf config file installed (like on Windows), or you are using a
+  # modern CMake that ships with a FindProtobuf.cmake file that produces
+  # modern targets.
+  message(STATUS "Caffe2: Found protobuf with new-style protobuf targets.")
+elseif(Protobuf_FOUND OR PROTOBUF_FOUND)
+  # If the modern targets are not present, we will generate them for you for
+  # backward compatibility. This is backported from CMake's new FindProtobuf.cmake
+  # content.
+  if((NOT PROTOBUF_LIBRARY) AND (NOT PROTOBUF_LITE_LIBRARY))
+    message(FATAL_ERROR
+        "Caffe2: Found protobuf with old style targets, but could not find targets."
+        " PROTOBUF_LIBRARY: " ${PROTOBUF_LIBRARY}
+        " PROTOBUF_LITE_LIBRARY: " ${PROTOBUF_LITE_LIBRARY}
+        " Protobuf_LIBRARY: " ${Protobuf_LIBRARY}
+        " Protobuf_LITE_LIBRARY: " ${Protobuf_LITE_LIBRARY})
+  endif()
+  message(STATUS "Caffe2: Found protobuf with old-style protobuf targets.")
+
+  if(PROTOBUF_LIBRARY)
+    if(NOT TARGET protobuf::libprotobuf)
+      add_library(protobuf::libprotobuf UNKNOWN IMPORTED)
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          INTERFACE_INCLUDE_DIRECTORIES "${PROTOBUF_INCLUDE_DIRS}")
+    endif()
+    if(EXISTS "${PROTOBUF_LIBRARY}")
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          IMPORTED_LOCATION "${PROTOBUF_LIBRARY}")
+    endif()
+    if(EXISTS "${PROTOBUF_LIBRARY_RELEASE}")
+      set_property(TARGET protobuf::libprotobuf APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS RELEASE)
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          IMPORTED_LOCATION_RELEASE "${PROTOBUF_LIBRARY_RELEASE}")
+    endif()
+    if(EXISTS "${PROTOBUF_LIBRARY_DEBUG}")
+      set_property(TARGET protobuf::libprotobuf APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS DEBUG)
+      set_target_properties(protobuf::libprotobuf PROPERTIES
+          IMPORTED_LOCATION_DEBUG "${PROTOBUF_LIBRARY_DEBUG}")
+    endif()
+  endif()
+
+  if(PROTOBUF_LITE_LIBRARY)
+    if(NOT TARGET protobuf::libprotobuf-lite)
+      add_library(protobuf::libprotobuf-lite UNKNOWN IMPORTED)
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          INTERFACE_INCLUDE_DIRECTORIES "${PROTOBUF_INCLUDE_DIRS}")
+    endif()
+    if(EXISTS "${PROTOBUF_LITE_LIBRARY}")
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          IMPORTED_LOCATION "${PROTOBUF_LITE_LIBRARY}")
+    endif()
+    if(EXISTS "${PROTOBUF_LITE_LIBRARY_RELEASE}")
+      set_property(TARGET protobuf::libprotobuf-lite APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS RELEASE)
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          IMPORTED_LOCATION_RELEASE "${PROTOBUF_LITE_LIBRARY_RELEASE}")
+    endif()
+    if(EXISTS "${PROTOBUF_LITE_LIBRARY_DEBUG}")
+      set_property(TARGET protobuf::libprotobuf-lite APPEND PROPERTY
+          IMPORTED_CONFIGURATIONS DEBUG)
+      set_target_properties(protobuf::libprotobuf-lite PROPERTIES
+          IMPORTED_LOCATION_DEBUG "${PROTOBUF_LITE_LIBRARY_DEBUG}")
+    endif()
+  endif()
+
+  if(PROTOBUF_PROTOC_EXECUTABLE)
+    if(NOT TARGET protobuf::protoc)
+      add_executable(protobuf::protoc IMPORTED)
+    endif()
+    set_property(TARGET protobuf::protoc PROPERTY
+        IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE})
+  endif()
+endif()
+
+# After above, we should have the protobuf related target now.
+if((NOT TARGET protobuf::libprotobuf) AND (NOT TARGET protobuf::libprotobuf-lite))
+  message(WARNING
+      "Protobuf cannot be found. Depending on whether you are building Caffe2 "
+      "or a Caffe2 dependent library, the next warning / error will give you "
+      "more info.")
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/utils.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/utils.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..3cdf5fb914b1ddaad115332079cb66a13ac2aea9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/utils.cmake
@@ -0,0 +1,552 @@
+################################################################################################
+# Exclude and prepend functionalities
+function(exclude OUTPUT INPUT)
+set(EXCLUDES ${ARGN})
+foreach(EXCLUDE ${EXCLUDES})
+        list(REMOVE_ITEM INPUT "${EXCLUDE}")
+endforeach()
+set(${OUTPUT} ${INPUT} PARENT_SCOPE)
+endfunction(exclude)
+
+function(prepend OUTPUT PREPEND)
+set(OUT "")
+foreach(ITEM ${ARGN})
+        list(APPEND OUT "${PREPEND}${ITEM}")
+endforeach()
+set(${OUTPUT} ${OUT} PARENT_SCOPE)
+endfunction(prepend)
+
+################################################################################################
+# Parses a version string that might have values beyond major, minor, and patch
+# and set version variables for the library.
+# Usage:
+#   caffe2_parse_version_str( )
+function(caffe2_parse_version_str LIBNAME VERSIONSTR)
+  string(REGEX REPLACE "^([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${VERSIONSTR}")
+  string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR  "${VERSIONSTR}")
+  string(REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${VERSIONSTR}")
+  set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE)
+  set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE)
+  set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE)
+  set(${LIBNAME}_VERSION "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE)
+endfunction()
+
+###
+# Removes common indentation from a block of text to produce code suitable for
+# setting to `python -c`, or using with pycmd. This allows multiline code to be
+# nested nicely in the surrounding code structure.
+#
+# This function respsects Python_EXECUTABLE if it defined, otherwise it uses
+# `python` and hopes for the best. An error will be thrown if it is not found.
+#
+# Args:
+#     outvar : variable that will hold the stdout of the python command
+#     text   : text to remove indentation from
+#
+function(dedent outvar text)
+  # Use Python_EXECUTABLE if it is defined, otherwise default to python
+  if("${Python_EXECUTABLE}" STREQUAL "")
+    set(_python_exe "python3")
+  else()
+    set(_python_exe "${Python_EXECUTABLE}")
+  endif()
+  set(_fixup_cmd "import sys; from textwrap import dedent; print(dedent(sys.stdin.read()))")
+  file(WRITE "${CMAKE_BINARY_DIR}/indented.txt" "${text}")
+  execute_process(
+    COMMAND "${_python_exe}" -c "${_fixup_cmd}"
+    INPUT_FILE "${CMAKE_BINARY_DIR}/indented.txt"
+    RESULT_VARIABLE _dedent_exitcode
+    OUTPUT_VARIABLE _dedent_text)
+  if(NOT _dedent_exitcode EQUAL 0)
+    message(ERROR " Failed to remove indentation from: \n\"\"\"\n${text}\n\"\"\"
+    Python dedent failed with error code: ${_dedent_exitcode}")
+    message(FATAL_ERROR " Python dedent failed with error code: ${_dedent_exitcode}")
+  endif()
+  # Remove supurflous newlines (artifacts of print)
+  string(STRIP "${_dedent_text}" _dedent_text)
+  set(${outvar} "${_dedent_text}" PARENT_SCOPE)
+endfunction()
+
+
+function(pycmd_no_exit outvar exitcode cmd)
+  # Use Python_EXECUTABLE if it is defined, otherwise default to python
+  if("${Python_EXECUTABLE}" STREQUAL "")
+    set(_python_exe "python")
+  else()
+    set(_python_exe "${Python_EXECUTABLE}")
+  endif()
+  # run the actual command
+  execute_process(
+    COMMAND "${_python_exe}" -c "${cmd}"
+    RESULT_VARIABLE _exitcode
+    OUTPUT_VARIABLE _output)
+  # Remove supurflous newlines (artifacts of print)
+  string(STRIP "${_output}" _output)
+  set(${outvar} "${_output}" PARENT_SCOPE)
+  set(${exitcode} "${_exitcode}" PARENT_SCOPE)
+endfunction()
+
+
+###
+# Helper function to run `python -c ""` and capture the results of stdout
+#
+# Runs a python command and populates an outvar with the result of stdout.
+# Common indentation in the text of `cmd` is removed before the command is
+# executed, so the caller does not need to worry about indentation issues.
+#
+# This function respsects Python_EXECUTABLE if it defined, otherwise it uses
+# `python` and hopes for the best. An error will be thrown if it is not found.
+#
+# Args:
+#     outvar : variable that will hold the stdout of the python command
+#     cmd    : text representing a (possibly multiline) block of python code
+#
+function(pycmd outvar cmd)
+  dedent(_dedent_cmd "${cmd}")
+  pycmd_no_exit(_output _exitcode "${_dedent_cmd}")
+
+  if(NOT _exitcode EQUAL 0)
+    message(ERROR " Failed when running python code: \"\"\"\n${_dedent_cmd}\n\"\"\"")
+    message(FATAL_ERROR " Python command failed with error code: ${_exitcode}")
+  endif()
+  # Remove supurflous newlines (artifacts of print)
+  string(STRIP "${_output}" _output)
+  set(${outvar} "${_output}" PARENT_SCOPE)
+endfunction()
+
+
+##############################################################################
+# Macro to update cached options.
+macro(caffe2_update_option variable value)
+  if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
+    get_property(__help_string CACHE ${variable} PROPERTY HELPSTRING)
+    set(${variable} ${value} CACHE BOOL ${__help_string} FORCE)
+  else()
+    set(${variable} ${value})
+  endif()
+endmacro()
+
+
+##############################################################################
+# Add an interface library definition that is dependent on the source.
+#
+# It's probably easiest to explain why this macro exists, by describing
+# what things would look like if we didn't have this macro.
+#
+# Let's suppose we want to statically link against torch.  We've defined
+# a library in cmake called torch, and we might think that we just
+# target_link_libraries(my-app PUBLIC torch).  This will result in a
+# linker argument 'libtorch.a' getting passed to the linker.
+#
+# Unfortunately, this link command is wrong!  We have static
+# initializers in libtorch.a that would get improperly pruned by
+# the default link settings.  What we actually need is for you
+# to do -Wl,--whole-archive,libtorch.a -Wl,--no-whole-archive to ensure
+# that we keep all symbols, even if they are (seemingly) not used.
+#
+# What caffe2_interface_library does is create an interface library
+# that indirectly depends on the real library, but sets up the link
+# arguments so that you get all of the extra link settings you need.
+# The result is not a "real" library, and so we have to manually
+# copy over necessary properties from the original target.
+#
+# (The discussion above is about static libraries, but a similar
+# situation occurs for dynamic libraries: if no symbols are used from
+# a dynamic library, it will be pruned unless you are --no-as-needed)
+macro(caffe2_interface_library SRC DST)
+  add_library(${DST} INTERFACE)
+  add_dependencies(${DST} ${SRC})
+  # Depending on the nature of the source library as well as the compiler,
+  # determine the needed compilation flags.
+  get_target_property(__src_target_type ${SRC} TYPE)
+  # Depending on the type of the source library, we will set up the
+  # link command for the specific SRC library.
+  if(${__src_target_type} STREQUAL "STATIC_LIBRARY")
+    # In the case of static library, we will need to add whole-static flags.
+    target_link_libraries(${DST} INTERFACE $)
+    # Link all interface link libraries of the src target as well.
+    # For static library, we need to explicitly depend on all the libraries
+    # that are the dependent library of the source library. Note that we cannot
+    # use the populated INTERFACE_LINK_LIBRARIES property, because if one of the
+    # dependent library is not a target, cmake creates a $ wrapper
+    # and then one is not able to find target "src". For more discussions, check
+    #   https://cmake.org/Bug/print_bug_page.php?bug_id=15415
+    #   https://cmake.org/pipermail/cmake-developers/2013-May/019019.html
+    # Specifically the following quote
+    #
+    # """
+    # For STATIC libraries we can define that the PUBLIC/PRIVATE/INTERFACE keys
+    # are ignored for linking and that it always populates both LINK_LIBRARIES
+    # LINK_INTERFACE_LIBRARIES.  Note that for STATIC libraries the
+    # LINK_LIBRARIES property will not be used for anything except build-order
+    # dependencies.
+    # """
+    target_link_libraries(${DST} INTERFACE
+        $)
+  elseif(${__src_target_type} STREQUAL "SHARED_LIBRARY")
+    if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU")
+      target_link_libraries(${DST} INTERFACE
+          "-Wl,--no-as-needed,\"$\" -Wl,--as-needed")
+    else()
+      target_link_libraries(${DST} INTERFACE ${SRC})
+    endif()
+    # Link all interface link libraries of the src target as well.
+    # For shared libraries, we can simply depend on the INTERFACE_LINK_LIBRARIES
+    # property of the target.
+    target_link_libraries(${DST} INTERFACE
+        $)
+  else()
+    message(FATAL_ERROR
+        "You made a CMake build file error: target " ${SRC}
+        " must be of type either STATIC_LIBRARY or SHARED_LIBRARY. However, "
+        "I got " ${__src_target_type} ".")
+  endif()
+  # For all other interface properties, manually inherit from the source target.
+  set_target_properties(${DST} PROPERTIES
+    INTERFACE_COMPILE_DEFINITIONS
+    $
+    INTERFACE_COMPILE_OPTIONS
+    $
+    INTERFACE_INCLUDE_DIRECTORIES
+    $
+    INTERFACE_SYSTEM_INCLUDE_DIRECTORIES
+    $)
+endmacro()
+
+
+##############################################################################
+# Creating a Caffe2 binary target with sources specified with relative path.
+# Usage:
+#   caffe2_binary_target(target_name_or_src  [] [] ...)
+# If only target_name_or_src is specified, this target is build with one single
+# source file and the target name is autogen from the filename. Otherwise, the
+# target name is given by the first argument and the rest are the source files
+# to build the target.
+function(caffe2_binary_target target_name_or_src)
+  # https://cmake.org/cmake/help/latest/command/function.html
+  # Checking that ARGC is greater than # is the only way to ensure
+  # that ARGV# was passed to the function as an extra argument.
+  if(ARGC GREATER 1)
+    set(__target ${target_name_or_src})
+    prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}")
+  else()
+    get_filename_component(__target ${target_name_or_src} NAME_WE)
+    prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${target_name_or_src}")
+  endif()
+  add_executable(${__target} ${__srcs})
+  target_link_libraries(${__target} torch_library)
+  # If we have Caffe2_MODULES defined, we will also link with the modules.
+  if(DEFINED Caffe2_MODULES)
+    target_link_libraries(${__target} ${Caffe2_MODULES})
+  endif()
+  install(TARGETS ${__target} DESTINATION bin)
+endfunction()
+
+function(caffe2_hip_binary_target target_name_or_src)
+  if(ARGC GREATER 1)
+    set(__target ${target_name_or_src})
+    prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${ARGN}")
+  else()
+    get_filename_component(__target ${target_name_or_src} NAME_WE)
+    prepend(__srcs "${CMAKE_CURRENT_SOURCE_DIR}/" "${target_name_or_src}")
+  endif()
+
+  caffe2_binary_target(${target_name_or_src})
+
+  target_compile_options(${__target} PRIVATE ${HIP_CXX_FLAGS})
+  target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDE})
+endfunction()
+
+
+##############################################################################
+# Multiplex between adding libraries for CUDA versus HIP (AMD Software Stack).
+# Usage:
+#   torch_cuda_based_add_library(cuda_target)
+#
+macro(torch_cuda_based_add_library cuda_target)
+  if(USE_ROCM)
+    hip_add_library(${cuda_target} ${ARGN})
+  elseif(USE_CUDA)
+    add_library(${cuda_target} ${ARGN})
+  else()
+  endif()
+endmacro()
+
+##############################################################################
+# Get the HIP arch flags specified by PYTORCH_ROCM_ARCH.
+# Usage:
+#   torch_hip_get_arch_list(variable_to_store_flags)
+#
+macro(torch_hip_get_arch_list store_var)
+  if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+    set(_TMP $ENV{PYTORCH_ROCM_ARCH})
+  else()
+    # Use arch of installed GPUs as default
+    execute_process(COMMAND "rocm_agent_enumerator" COMMAND bash "-c" "grep -v gfx000 | sort -u | xargs | tr -d '\n'"
+                    RESULT_VARIABLE ROCM_AGENT_ENUMERATOR_RESULT
+                    OUTPUT_VARIABLE ROCM_ARCH_INSTALLED)
+    if(NOT ROCM_AGENT_ENUMERATOR_RESULT EQUAL 0)
+      message(FATAL_ERROR " Could not detect ROCm arch for GPUs on machine. Result: '${ROCM_AGENT_ENUMERATOR_RESULT}'")
+    endif()
+    set(_TMP ${ROCM_ARCH_INSTALLED})
+  endif()
+  string(REPLACE " " ";" ${store_var} "${_TMP}")
+endmacro()
+
+##############################################################################
+# Get the XPU arch flags specified by TORCH_XPU_ARCH_LIST.
+# Usage:
+#   torch_xpu_get_arch_list(variable_to_store_flags)
+#
+macro(torch_xpu_get_arch_list store_var)
+  if(DEFINED ENV{TORCH_XPU_ARCH_LIST})
+    set(${store_var} $ENV{TORCH_XPU_ARCH_LIST})
+  endif()
+endmacro()
+
+##############################################################################
+# Get the NVCC arch flags specified by TORCH_CUDA_ARCH_LIST and CUDA_ARCH_NAME.
+# Usage:
+#   torch_cuda_get_nvcc_gencode_flag(variable_to_store_flags)
+#
+macro(torch_cuda_get_nvcc_gencode_flag store_var)
+  # setting nvcc arch flags
+  # We need to support the explicitly and conveniently defined TORCH_CUDA_ARCH_LIST
+  if((NOT DEFINED TORCH_CUDA_ARCH_LIST) AND (DEFINED ENV{TORCH_CUDA_ARCH_LIST}))
+    set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST})
+  endif()
+  if(DEFINED CUDA_ARCH_NAME)
+    message(WARNING
+        "CUDA_ARCH_NAME is no longer used. Use TORCH_CUDA_ARCH_LIST instead. "
+        "Right now, CUDA_ARCH_NAME is ${CUDA_ARCH_NAME} and "
+        "TORCH_CUDA_ARCH_LIST is ${TORCH_CUDA_ARCH_LIST}.")
+    if(NOT TORCH_CUDA_ARCH_LIST)
+      set(TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME})
+    else()
+      list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME})
+    endif()
+  endif()
+
+  # Invoke cuda_select_nvcc_arch_flags from proper cmake FindCUDA.
+  cuda_select_nvcc_arch_flags(${store_var} ${TORCH_CUDA_ARCH_LIST})
+endmacro()
+
+
+##############################################################################
+# Add standard compile options.
+# Usage:
+#   torch_compile_options(lib_name)
+function(torch_compile_options libname)
+  set_property(TARGET ${libname} PROPERTY CXX_STANDARD 17)
+
+  # until they can be unified, keep these lists synced with setup.py
+  if(MSVC)
+
+    if(MSVC_Z7_OVERRIDE)
+      set(MSVC_DEBINFO_OPTION "/Z7")
+    else()
+      set(MSVC_DEBINFO_OPTION "/Zi")
+    endif()
+
+    if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 142)
+      # Add /permissive- flag for conformance mode to the compiler.
+      # This will force more strict check to the code standard.
+      # 1. From MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/permissive-standards-conformance?view=msvc-170#remarks
+      #    By default, the /permissive- option is set in new projects created by Visual Studio 2017 version 15.5 and later versions.
+      #    We set the /permissive- flag from VS 2019 (MSVC_TOOLSET_VERSION 142) to avoid compiling issues for old toolkit.
+      # 2. For MSVC VERSION: https://cmake.org/cmake/help/latest/variable/MSVC_TOOLSET_VERSION.html
+      target_compile_options(${libname} PUBLIC $<$:/permissive->)
+    endif()
+    # This option enables a token-based preprocessor that conforms to C99 and C++11 and later standards.
+    # This option is available since VS 2017.
+    # For MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/zc-preprocessor
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor" PARENT_SCOPE)
+
+    target_compile_options(${libname} PUBLIC
+      $<$:
+        ${MSVC_RUNTIME_LIBRARY_OPTION}
+        $<$,$>:${MSVC_DEBINFO_OPTION}>
+        /EHsc
+        /bigobj>
+      )
+  else()
+    set(private_compile_options
+      -Wall
+      -Wextra
+      -Wdeprecated
+      -Wunused
+      -Wno-unused-parameter
+      -Wno-missing-field-initializers
+      -Wno-array-bounds
+      -Wno-unknown-pragmas
+      -Wno-strict-overflow
+      -Wno-strict-aliasing
+      )
+    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+      list(APPEND private_compile_options -Wredundant-move)
+      # -Wno-interference-size only exists in GCC 12+
+      if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12)
+        list(APPEND private_compile_options -Wno-interference-size)
+      endif()
+    endif()
+    if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+      list(APPEND private_compile_options -Wextra-semi -Wmove)
+    else()
+      list(APPEND private_compile_options
+        # Considered to be flaky.  See the discussion at
+        # https://github.com/pytorch/pytorch/pull/9608
+        -Wno-maybe-uninitialized)
+    endif()
+
+    if(WERROR)
+      list(APPEND private_compile_options
+        -Werror
+        -Werror=ignored-attributes
+        -Werror=inconsistent-missing-override
+        -Werror=inconsistent-missing-destructor-override
+        -Werror=pedantic
+        -Werror=unused
+        -Wno-error=unused-parameter
+      )
+      if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+        list(APPEND private_compile_options -Werror=unused-but-set-variable)
+      endif()
+    endif()
+  endif()
+
+
+  target_compile_options(${libname} PRIVATE
+      $<$:${private_compile_options}>)
+  if(USE_CUDA)
+    foreach(option IN LISTS private_compile_options)
+      if(CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "GNU")
+        if("${option}" STREQUAL "-Wextra-semi")
+          continue()
+        endif()
+        if("${option}" STREQUAL "-Wunused-private-field")
+          continue()
+        endif()
+      endif()
+      target_compile_options(${libname} PRIVATE $<$:-Xcompiler ${option}>)
+    endforeach()
+  endif()
+
+  if(NOT WIN32 AND NOT USE_ASAN)
+    # Enable hidden visibility by default to make it easier to debug issues with
+    # TORCH_API annotations. Hidden visibility with selective default visibility
+    # behaves close enough to Windows' dllimport/dllexport.
+    #
+    # Unfortunately, hidden visibility messes up some ubsan warnings because
+    # templated classes crossing library boundary get duplicated (but identical)
+    # definitions. It's easier to just disable it.
+    target_compile_options(${libname} PRIVATE
+        $<$: -fvisibility=hidden>)
+  endif()
+
+endfunction()
+
+##############################################################################
+# Set old-style FindCuda.cmake compile flags from modern CMake cuda flags.
+# Usage:
+#   torch_update_find_cuda_flags()
+function(torch_update_find_cuda_flags)
+  # Convert -O2 -Xcompiler="-O2 -Wall" to "-O2;-Xcompiler=-O2,-Wall"
+  if(USE_CUDA)
+    separate_arguments(FLAGS UNIX_COMMAND "${CMAKE_CUDA_FLAGS}")
+    string(REPLACE " " "," FLAGS "${FLAGS}")
+    set(CUDA_NVCC_FLAGS ${FLAGS} PARENT_SCOPE)
+
+    separate_arguments(FLAGS_DEBUG UNIX_COMMAND "${CMAKE_CUDA_FLAGS_DEBUG}")
+    string(REPLACE " " "," FLAGS_DEBUG "${FLAGS_DEBUG}")
+    set(CUDA_NVCC_FLAGS_DEBUG "${FLAGS_DEBUG}" PARENT_SCOPE)
+
+    separate_arguments(FLAGS_RELEASE UNIX_COMMAND "${CMAKE_CUDA_FLAGS_RELEASE}")
+    string(REPLACE " " "," FLAGS_RELEASE "${FLAGS_RELEASE}")
+    set(CUDA_NVCC_FLAGS_RELEASE "${FLAGS_RELEASE}" PARENT_SCOPE)
+
+    separate_arguments(FLAGS_MINSIZEREL UNIX_COMMAND "${CMAKE_CUDA_FLAGS_MINSIZEREL}")
+    string(REPLACE " " "," FLAGS_MINSIZEREL "${FLAGS_MINSIZEREL}")
+    set(CUDA_NVCC_FLAGS_MINSIZEREL "${FLAGS_MINSIZEREL}" PARENT_SCOPE)
+
+    separate_arguments(FLAGS_RELWITHDEBINFO UNIX_COMMAND "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
+    string(REPLACE " " "," FLAGS_RELWITHDEBINFO "${FLAGS_RELWITHDEBINFO}")
+    set(CUDA_NVCC_FLAGS_RELWITHDEBINFO "${FLAGS_RELWITHDEBINFO}" PARENT_SCOPE)
+
+    message(STATUS "Converting CMAKE_CUDA_FLAGS to CUDA_NVCC_FLAGS:\n"
+                    "    CUDA_NVCC_FLAGS                = ${FLAGS}\n"
+                    "    CUDA_NVCC_FLAGS_DEBUG          = ${FLAGS_DEBUG}\n"
+                    "    CUDA_NVCC_FLAGS_RELEASE        = ${FLAGS_RELEASE}\n"
+                    "    CUDA_NVCC_FLAGS_RELWITHDEBINFO = ${FLAGS_RELWITHDEBINFO}\n"
+                    "    CUDA_NVCC_FLAGS_MINSIZEREL     = ${FLAGS_MINSIZEREL}")
+  endif()
+endfunction()
+
+include(CheckCXXCompilerFlag)
+include(CheckCCompilerFlag)
+include(CheckLinkerFlag)
+
+##############################################################################
+# CHeck if given flag is supported and append it to provided outputvar
+# Also define HAS_UPPER_CASE_FLAG_NAME variable
+# Usage:
+#   append_cxx_flag_if_supported("-Werror" CMAKE_CXX_FLAGS)
+function(append_cxx_flag_if_supported flag outputvar)
+    string(TOUPPER "HAS${flag}" _FLAG_NAME)
+    string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")
+    # GCC silents unknown -Wno-XXX flags, so we detect the corresponding -WXXX.
+    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+      string(REGEX REPLACE "Wno-" "W" new_flag "${flag}")
+    else()
+      set(new_flag ${flag})
+    endif()
+    check_cxx_compiler_flag("${new_flag}" ${_FLAG_NAME})
+    if(${_FLAG_NAME})
+        string(APPEND ${outputvar} " ${flag}")
+        set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
+    endif()
+endfunction()
+
+function(append_c_flag_if_supported flag outputvar)
+    string(TOUPPER "HAS${flag}" _FLAG_NAME)
+    string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")
+
+    # GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX.
+    if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
+        string(REGEX REPLACE "^Wno-" "W" new_flag "${flag}")
+    else()
+        set(new_flag "${flag}")
+    endif()
+
+    check_c_compiler_flag("${new_flag}" ${_FLAG_NAME})
+    if(${_FLAG_NAME})
+        string(APPEND ${outputvar} " ${flag}")
+        set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
+    endif()
+endfunction()
+
+function(target_compile_options_if_supported target flag)
+  set(_compile_options "")
+  append_cxx_flag_if_supported("${flag}" _compile_options)
+  if(NOT "${_compile_options}" STREQUAL "")
+    target_compile_options(${target} PRIVATE ${flag})
+  endif()
+endfunction()
+
+# Check if a global link option is supported
+function(add_link_options_if_supported flag)
+  check_linker_flag(C "LINKER:${flag}" _supported)
+  if("${_supported}")
+    add_link_options("LINKER:${flag}")
+  else()
+    message(WARNING "Attempted to use unsupported link option : ${flag}.")
+  endif()
+endfunction()
+
+function(target_link_options_if_supported tgt flag)
+  check_linker_flag(C "LINKER:${flag}" _supported)
+  if("${_supported}")
+    target_link_options("${tgt}" PRIVATE "LINKER:${flag}")
+  else()
+    message(WARNING "Attempted to use unsupported link option : ${flag}.")
+  endif()
+endfunction()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..b39e31d0ade8aa52206784ae93f37238a3b7fd11
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Caffe2/public/xpu.cmake
@@ -0,0 +1,56 @@
+# ---[ xpu
+
+# Poor man's include guard
+if(TARGET torch::xpurt)
+  return()
+endif()
+
+set(XPU_HOST_CXX_FLAGS)
+
+# Find SYCL library.
+find_package(SYCLToolkit REQUIRED)
+if(NOT SYCL_FOUND)
+  set(PYTORCH_FOUND_XPU FALSE)
+  # Exit early to avoid populating XPU_HOST_CXX_FLAGS.
+  return()
+endif()
+set(PYTORCH_FOUND_XPU TRUE)
+
+# SYCL library interface
+add_library(torch::sycl INTERFACE IMPORTED)
+
+set_property(
+    TARGET torch::sycl PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+    ${SYCL_INCLUDE_DIR})
+set_property(
+    TARGET torch::sycl PROPERTY INTERFACE_LINK_LIBRARIES
+    ${SYCL_LIBRARY})
+
+# xpurt
+add_library(torch::xpurt INTERFACE IMPORTED)
+set_property(
+    TARGET torch::xpurt PROPERTY INTERFACE_LINK_LIBRARIES
+    torch::sycl)
+
+# setting xpu arch flags
+torch_xpu_get_arch_list(XPU_ARCH_FLAGS)
+# propagate to torch-xpu-ops
+set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS})
+
+# Ensure USE_XPU is enabled.
+string(APPEND XPU_HOST_CXX_FLAGS " -DUSE_XPU")
+string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}")
+
+if(DEFINED ENV{XPU_ENABLE_KINETO})
+  set(XPU_ENABLE_KINETO TRUE)
+else()
+  set(XPU_ENABLE_KINETO FALSE)
+endif()
+
+if(WIN32)
+  if(${SYCL_COMPILER_VERSION} GREATER_EQUAL 20250101)
+    set(XPU_ENABLE_KINETO TRUE)
+  endif()
+else()
+  set(XPU_ENABLE_KINETO TRUE)
+endif()
\ No newline at end of file
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..b59f8ceca10f56aaad16d71c32979919ea0537c1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake
@@ -0,0 +1,39 @@
+#----------------------------------------------------------------
+# Generated CMake target import file for configuration "Release".
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Import target "tensorpipe_uv" for configuration "Release"
+set_property(TARGET tensorpipe_uv APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(tensorpipe_uv PROPERTIES
+  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "C"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libtensorpipe_uv.a"
+  )
+
+list(APPEND _cmake_import_check_targets tensorpipe_uv )
+list(APPEND _cmake_import_check_files_for_tensorpipe_uv "${_IMPORT_PREFIX}/lib64/libtensorpipe_uv.a" )
+
+# Import target "tensorpipe" for configuration "Release"
+set_property(TARGET tensorpipe APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(tensorpipe PROPERTIES
+  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libtensorpipe.a"
+  )
+
+list(APPEND _cmake_import_check_targets tensorpipe )
+list(APPEND _cmake_import_check_files_for_tensorpipe "${_IMPORT_PREFIX}/lib64/libtensorpipe.a" )
+
+# Import target "tensorpipe_cuda" for configuration "Release"
+set_property(TARGET tensorpipe_cuda APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(tensorpipe_cuda PROPERTIES
+  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX"
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libtensorpipe_cuda.a"
+  )
+
+list(APPEND _cmake_import_check_targets tensorpipe_cuda )
+list(APPEND _cmake_import_check_files_for_tensorpipe_cuda "${_IMPORT_PREFIX}/lib64/libtensorpipe_cuda.a" )
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..26ba6741ec29a4a4940154884073da6fc469553d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Tensorpipe/TensorpipeTargets.cmake
@@ -0,0 +1,122 @@
+# Generated by CMake
+
+if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
+   message(FATAL_ERROR "CMake >= 2.8.12 required")
+endif()
+if(CMAKE_VERSION VERSION_LESS "2.8.12")
+   message(FATAL_ERROR "CMake >= 2.8.12 required")
+endif()
+cmake_policy(PUSH)
+cmake_policy(VERSION 2.8.12...4.0)
+#----------------------------------------------------------------
+# Generated CMake target import file.
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Protect against multiple inclusion, which would fail when already imported targets are added once more.
+set(_cmake_targets_defined "")
+set(_cmake_targets_not_defined "")
+set(_cmake_expected_targets "")
+foreach(_cmake_expected_target IN ITEMS tensorpipe_uv tensorpipe tensorpipe_cuda)
+  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
+  if(TARGET "${_cmake_expected_target}")
+    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
+  else()
+    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
+  endif()
+endforeach()
+unset(_cmake_expected_target)
+if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
+  unset(_cmake_targets_defined)
+  unset(_cmake_targets_not_defined)
+  unset(_cmake_expected_targets)
+  unset(CMAKE_IMPORT_FILE_VERSION)
+  cmake_policy(POP)
+  return()
+endif()
+if(NOT _cmake_targets_defined STREQUAL "")
+  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
+  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
+  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
+endif()
+unset(_cmake_targets_defined)
+unset(_cmake_targets_not_defined)
+unset(_cmake_expected_targets)
+
+
+# Compute the installation prefix relative to this file.
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+if(_IMPORT_PREFIX STREQUAL "/")
+  set(_IMPORT_PREFIX "")
+endif()
+
+# Create imported target tensorpipe_uv
+add_library(tensorpipe_uv STATIC IMPORTED)
+
+set_target_properties(tensorpipe_uv PROPERTIES
+  INTERFACE_LINK_LIBRARIES "\$;\$;\$;\$"
+)
+
+# Create imported target tensorpipe
+add_library(tensorpipe STATIC IMPORTED)
+
+set_target_properties(tensorpipe PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+  INTERFACE_LINK_LIBRARIES "\$"
+)
+
+# Create imported target tensorpipe_cuda
+add_library(tensorpipe_cuda STATIC IMPORTED)
+
+set_target_properties(tensorpipe_cuda PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "/usr/local/cuda/include"
+  INTERFACE_LINK_LIBRARIES "tensorpipe;/usr/local/cuda/lib64/libcudart.so"
+)
+
+# Load information for each installed configuration.
+file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/TensorpipeTargets-*.cmake")
+foreach(_cmake_config_file IN LISTS _cmake_config_files)
+  include("${_cmake_config_file}")
+endforeach()
+unset(_cmake_config_file)
+unset(_cmake_config_files)
+
+# Cleanup temporary variables.
+set(_IMPORT_PREFIX)
+
+# Loop over all imported files and verify that they actually exist
+foreach(_cmake_target IN LISTS _cmake_import_check_targets)
+  if(CMAKE_VERSION VERSION_LESS "3.28"
+      OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
+      OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
+    foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
+      if(NOT EXISTS "${_cmake_file}")
+        message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
+   \"${_cmake_file}\"
+but this file does not exist.  Possible reasons include:
+* The file was deleted, renamed, or moved to another location.
+* An install or uninstall procedure did not complete successfully.
+* The installation package was faulty and contained
+   \"${CMAKE_CURRENT_LIST_FILE}\"
+but not all the files it references.
+")
+      endif()
+    endforeach()
+  endif()
+  unset(_cmake_file)
+  unset("_cmake_import_check_files_for_${_cmake_target}")
+endforeach()
+unset(_cmake_target)
+unset(_cmake_import_check_targets)
+
+# This file does not depend on other imported targets which have
+# been exported from the same project but in a separate export set.
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
+cmake_policy(POP)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfig.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfig.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..83dc0fd9eb073ff05285b2a3f7a41d745a123899
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfig.cmake
@@ -0,0 +1,170 @@
+# FindTorch
+# -------
+#
+# Finds the Torch library
+#
+# This will define the following variables:
+#
+#   TORCH_FOUND        -- True if the system has the Torch library
+#   TORCH_INCLUDE_DIRS -- The include directories for torch
+#   TORCH_LIBRARIES    -- Libraries to link against
+#   TORCH_CXX_FLAGS    -- Additional (required) compiler flags
+#
+# and the following imported targets:
+#
+#   torch
+macro(append_torchlib_if_found)
+  foreach (_arg ${ARGN})
+    find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    if(${_arg}_LIBRARY)
+      list(APPEND TORCH_LIBRARIES ${${_arg}_LIBRARY})
+    else()
+      message(WARNING "static library ${${_arg}_LIBRARY} not found.")
+    endif()
+  endforeach()
+endmacro()
+
+macro(append_wholearchive_lib_if_found)
+  foreach (_arg ${ARGN})
+    find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    if(${_arg}_LIBRARY)
+      if(APPLE)
+        list(APPEND TORCH_LIBRARIES "-Wl,-force_load,${${_arg}_LIBRARY}")
+      elseif(MSVC)
+        list(APPEND TORCH_LIBRARIES "-WHOLEARCHIVE:${${_arg}_LIBRARY}")
+      else()
+        # Linux
+        list(APPEND TORCH_LIBRARIES "-Wl,--whole-archive ${${_arg}_LIBRARY} -Wl,--no-whole-archive")
+      endif()
+    else()
+      message(WARNING "static library ${${_arg}_LIBRARY} not found.")
+    endif()
+  endforeach()
+endmacro()
+
+include(FindPackageHandleStandardArgs)
+
+if(DEFINED ENV{TORCH_INSTALL_PREFIX})
+  set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX})
+else()
+  # Assume we are in /share/cmake/Torch/TorchConfig.cmake
+  get_filename_component(CMAKE_CURRENT_LIST_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+  get_filename_component(TORCH_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
+endif()
+
+# Include directories.
+if(EXISTS "${TORCH_INSTALL_PREFIX}/include")
+  set(TORCH_INCLUDE_DIRS
+    ${TORCH_INSTALL_PREFIX}/include
+    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
+else()
+  set(TORCH_INCLUDE_DIRS
+    ${TORCH_INSTALL_PREFIX}/include
+    ${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
+endif()
+
+# Library dependencies.
+if(ON)
+  find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2)
+  set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS})
+  append_torchlib_if_found(c10)
+else()
+  add_library(torch STATIC IMPORTED) # set imported_location at the bottom
+  #library need whole archive
+  append_wholearchive_lib_if_found(torch torch_cpu)
+  if(ON)
+    append_wholearchive_lib_if_found(torch_cuda c10_cuda)
+  endif()
+  if(OFF)
+    append_wholearchive_lib_if_found(torch_xpu c10_xpu)
+  endif()
+
+  # We need manually add dependent libraries when they are not linked into the
+  # shared library.
+  # TODO: this list might be incomplete.
+  append_torchlib_if_found(c10)
+
+  if(ON)
+    append_torchlib_if_found(nnpack)
+  endif()
+
+  if(ON)
+    append_torchlib_if_found(pytorch_qnnpack)
+  endif()
+
+  if(ON)
+    append_torchlib_if_found(XNNPACK)
+    append_torchlib_if_found(microkernels-prod)
+  endif()
+
+  if(OFF)
+    append_torchlib_if_found(kleidiai)
+  endif()
+
+  append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc)
+  append_torchlib_if_found(onnx onnx_proto)
+
+  append_torchlib_if_found(fmt)
+  append_torchlib_if_found(cpuinfo clog)
+
+  append_torchlib_if_found(eigen_blas)
+  append_torchlib_if_found(pthreadpool)
+
+  if(ON)
+    append_torchlib_if_found(fbgemm)
+  endif()
+
+  if(ON)
+    append_torchlib_if_found(dnnl mkldnn)
+  endif()
+
+  append_torchlib_if_found(sleef asmjit)
+endif()
+
+if(1)
+  append_torchlib_if_found(kineto)
+endif()
+
+if(ON)
+  if(MSVC)
+    find_library(CAFFE2_NVRTC_LIBRARY caffe2_nvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    list(APPEND TORCH_CUDA_LIBRARIES ${CAFFE2_NVRTC_LIBRARY})
+  else()
+    set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
+  endif()
+  if(TARGET torch::nvtoolsext)
+    list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
+  endif()
+
+  if(ON)
+    find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY} ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
+  endif()
+  list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES})
+endif()
+
+if(OFF AND ON)
+    append_torchlib_if_found(c10_xpu torch_xpu)
+endif()
+
+find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
+# the statements below changes target properties on
+# - the imported target from Caffe2Targets.cmake in shared library mode (see the find_package above)
+#    - this is untested whether it is the correct (or desired) methodology in CMake
+# - the imported target created in this file in static library mode
+if(NOT ON)
+  # do not set this property on the shared library target, as it will cause confusion in some builds
+  # as the configuration specific property is set in the Caffe2Targets.cmake file
+  set_target_properties(torch PROPERTIES
+      IMPORTED_LOCATION "${TORCH_LIBRARY}"
+  )
+endif()
+set_target_properties(torch PROPERTIES
+    INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
+    CXX_STANDARD 17
+)
+if(TORCH_CXX_FLAGS)
+  set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}")
+endif()
+
+find_package_handle_standard_args(Torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..c7379319b36ec11b13d940841cde5ff9d17025ce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/share/cmake/Torch/TorchConfigVersion.cmake
@@ -0,0 +1,11 @@
+set(PACKAGE_VERSION "2.10.0")
+
+# Check whether the requested PACKAGE_FIND_VERSION is compatible
+if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
+  set(PACKAGE_VERSION_COMPATIBLE FALSE)
+else()
+  set(PACKAGE_VERSION_COMPATIBLE TRUE)
+  if("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
+    set(PACKAGE_VERSION_EXACT TRUE)
+  endif()
+endif()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6a08ca95bf1d3a3afeb05e3b70111a4a1d82e06
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6749a92c6fc1525ea95c7d4d1e398229ab10b7a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__init__.py
@@ -0,0 +1,28 @@
+from .windows import (
+    bartlett,
+    blackman,
+    cosine,
+    exponential,
+    gaussian,
+    general_cosine,
+    general_hamming,
+    hamming,
+    hann,
+    kaiser,
+    nuttall,
+)
+
+
+__all__ = [
+    "bartlett",
+    "blackman",
+    "cosine",
+    "exponential",
+    "gaussian",
+    "general_cosine",
+    "general_hamming",
+    "hamming",
+    "hann",
+    "kaiser",
+    "nuttall",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1787398f57b81107aaceecbece7f30caef9ecf15
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..817f131fafb7b83406355e9d7499c2ef762e6668
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/windows.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/windows.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda60aadfe1d6208354b045a86700e858cc946f0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/signal/windows/windows.py
@@ -0,0 +1,883 @@
+# mypy: allow-untyped-defs
+from collections.abc import Callable, Iterable
+from math import sqrt
+from typing import TypeVar
+
+import torch
+from torch import Tensor
+from torch._torch_docs import factory_common_args, merge_dicts, parse_kwargs
+
+
+__all__ = [
+    "bartlett",
+    "blackman",
+    "cosine",
+    "exponential",
+    "gaussian",
+    "general_cosine",
+    "general_hamming",
+    "hamming",
+    "hann",
+    "kaiser",
+    "nuttall",
+]
+
+_T = TypeVar("_T")
+
+window_common_args = merge_dicts(
+    parse_kwargs(
+        """
+    M (int): the length of the window.
+        In other words, the number of points of the returned window.
+    sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
+        If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
+"""
+    ),
+    factory_common_args,
+    {
+        "normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if "
+        ":attr:`M` is even and :attr:`sym` is `True`.",
+    },
+)
+
+
+def _add_docstr(*args: str) -> Callable[[_T], _T]:
+    r"""Adds docstrings to a given decorated function.
+
+    Specially useful when then docstrings needs string interpolation, e.g., with
+    str.format().
+    REMARK: Do not use this function if the docstring doesn't need string
+    interpolation, just write a conventional docstring.
+
+    Args:
+        args (str):
+    """
+
+    def decorator(o: _T) -> _T:
+        o.__doc__ = "".join(args)
+        return o
+
+    return decorator
+
+
+def _window_function_checks(
+    function_name: str, M: int, dtype: torch.dtype, layout: torch.layout
+) -> None:
+    r"""Performs common checks for all the defined windows.
+    This function should be called before computing any window.
+
+    Args:
+        function_name (str): name of the window function.
+        M (int): length of the window.
+        dtype (:class:`torch.dtype`): the desired data type of returned tensor.
+        layout (:class:`torch.layout`): the desired layout of returned tensor.
+    """
+    if M < 0:
+        raise ValueError(
+            f"{function_name} requires non-negative window length, got M={M}"
+        )
+    if layout is not torch.strided:
+        raise ValueError(
+            f"{function_name} is implemented for strided tensors only, got: {layout}"
+        )
+    if dtype not in [torch.float32, torch.float64]:
+        raise ValueError(
+            f"{function_name} expects float32 or float64 dtypes, got: {dtype}"
+        )
+
+
+@_add_docstr(
+    r"""
+Computes a window with an exponential waveform.
+Also known as Poisson window.
+
+The exponential window is defined as follows:
+
+.. math::
+    w_n = \exp{\left(-\frac{|n - c|}{\tau}\right)}
+
+where `c` is the ``center`` of the window.
+    """,
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    center (float, optional): where the center of the window will be located.
+        Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
+    tau (float, optional): the decay value.
+        Tau is generally associated with a percentage, that means, that the value should
+        vary within the interval (0, 100]. If tau is 100, it is considered the uniform window.
+        Default: 1.0.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric exponential window of size 10 and with a decay value of 1.0.
+    >>> # The center will be at (M - 1) / 2, where M is 10.
+    >>> torch.signal.windows.exponential(10)
+    tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111])
+
+    >>> # Generates a periodic exponential window and decay factor equal to .5
+    >>> torch.signal.windows.exponential(10, sym=False,tau=.5)
+    tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
+    """.format(**window_common_args),
+)
+def exponential(
+    M: int,
+    *,
+    center: float | None = None,
+    tau: float = 1.0,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("exponential", M, dtype, layout)
+
+    if tau <= 0:
+        raise ValueError(f"Tau must be positive, got: {tau} instead.")
+
+    if sym and center is not None:
+        raise ValueError("Center must be None for symmetric windows")
+
+    if M == 0:
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if center is None:
+        center = (M if not sym and M > 1 else M - 1) / 2.0
+
+    constant = 1 / tau
+
+    k = torch.linspace(
+        start=-center * constant,
+        end=(-center + (M - 1)) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.exp(-torch.abs(k))
+
+
+@_add_docstr(
+    r"""
+Computes a window with a simple cosine waveform, following the same implementation as SciPy.
+This window is also known as the sine window.
+
+The cosine window is defined as follows:
+
+.. math::
+    w_n = \sin\left(\frac{\pi (n + 0.5)}{M}\right)
+
+This formula differs from the typical cosine window formula by incorporating a 0.5 term in the numerator,
+which shifts the sample positions. This adjustment results in a window that starts and ends with non-zero values.
+
+""",
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric cosine window.
+    >>> torch.signal.windows.cosine(10)
+    tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564])
+
+    >>> # Generates a periodic cosine window.
+    >>> torch.signal.windows.cosine(10, sym=False)
+    tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154])
+""".format(
+        **window_common_args,
+    ),
+)
+def cosine(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("cosine", M, dtype, layout)
+
+    if M == 0:
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    start = 0.5
+    constant = torch.pi / (M + 1 if not sym and M > 1 else M)
+
+    k = torch.linspace(
+        start=start * constant,
+        end=(start + (M - 1)) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.sin(k)
+
+
+@_add_docstr(
+    r"""
+Computes a window with a gaussian waveform.
+
+The gaussian window is defined as follows:
+
+.. math::
+    w_n = \exp{\left(-\left(\frac{n}{2\sigma}\right)^2\right)}
+    """,
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
+        Default: 1.0.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric gaussian window with a standard deviation of 1.0.
+    >>> torch.signal.windows.gaussian(10)
+    tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
+
+    >>> # Generates a periodic gaussian window and standard deviation equal to 0.9.
+    >>> torch.signal.windows.gaussian(10, sym=False,std=0.9)
+    tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
+""".format(
+        **window_common_args,
+    ),
+)
+def gaussian(
+    M: int,
+    *,
+    std: float = 1.0,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("gaussian", M, dtype, layout)
+
+    if std <= 0:
+        raise ValueError(f"Standard deviation must be positive, got: {std} instead.")
+
+    if M == 0:
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    start = -(M if not sym and M > 1 else M - 1) / 2.0
+
+    constant = 1 / (std * sqrt(2))
+
+    k = torch.linspace(
+        start=start * constant,
+        end=(start + (M - 1)) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.exp(-(k**2))  # pyrefly: ignore [unsupported-operation]
+
+
+@_add_docstr(
+    r"""
+Computes the Kaiser window.
+
+The Kaiser window is defined as follows:
+
+.. math::
+    w_n = I_0 \left( \beta \sqrt{1 - \left( {\frac{n - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta )
+
+where ``I_0`` is the zeroth order modified Bessel function of the first kind (see :func:`torch.special.i0`), and
+``N = M - 1 if sym else M``.
+    """,
+    r"""
+
+{normalization}
+
+Args:
+    {M}
+
+Keyword args:
+    beta (float, optional): shape parameter for the window. Must be non-negative. Default: 12.0
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric gaussian window with a standard deviation of 1.0.
+    >>> torch.signal.windows.kaiser(5)
+    tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
+    >>> # Generates a periodic gaussian window and standard deviation equal to 0.9.
+    >>> torch.signal.windows.kaiser(5, sym=False,std=0.9)
+    tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
+""".format(
+        **window_common_args,
+    ),
+)
+def kaiser(
+    M: int,
+    *,
+    beta: float = 12.0,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("kaiser", M, dtype, layout)
+
+    if beta < 0:
+        raise ValueError(f"beta must be non-negative, got: {beta} instead.")
+
+    if M == 0:
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if M == 1:
+        return torch.ones(
+            (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    # Avoid NaNs by casting `beta` to the appropriate dtype.
+    # pyrefly: ignore [bad-assignment]
+    beta = torch.tensor(beta, dtype=dtype, device=device)
+
+    start = -beta
+    constant = 2.0 * beta / (M if not sym else M - 1)
+    end = torch.minimum(
+        # pyrefly: ignore [bad-argument-type]
+        beta,
+        # pyrefly: ignore [bad-argument-type]
+        start + (M - 1) * constant,
+    )
+
+    k = torch.linspace(
+        start=start,
+        end=end,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
+        # pyrefly: ignore [bad-argument-type]
+        beta
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Hamming window.
+
+The Hamming window is defined as follows:
+
+.. math::
+    w_n = \alpha - \beta\ \cos \left( \frac{2 \pi n}{M - 1} \right)
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    alpha (float, optional): The coefficient :math:`\alpha` in the equation above.
+    beta (float, optional): The coefficient :math:`\beta` in the equation above.
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Hamming window.
+    >>> torch.signal.windows.hamming(10)
+    tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800])
+
+    >>> # Generates a periodic Hamming window.
+    >>> torch.signal.windows.hamming(10, sym=False)
+    tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679])
+""".format(**window_common_args),
+)
+def hamming(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_hamming(
+        M,
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Hann window.
+
+The Hann window is defined as follows:
+
+.. math::
+    w_n = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{M - 1} \right)\right] =
+    \sin^2 \left( \frac{\pi n}{M - 1} \right)
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Hann window.
+    >>> torch.signal.windows.hann(10)
+    tensor([0.0000, 0.1170, 0.4132, 0.7500, 0.9698, 0.9698, 0.7500, 0.4132, 0.1170, 0.0000])
+
+    >>> # Generates a periodic Hann window.
+    >>> torch.signal.windows.hann(10, sym=False)
+    tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
+""".format(**window_common_args),
+)
+def hann(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_hamming(
+        M,
+        alpha=0.5,
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Blackman window.
+
+The Blackman window is defined as follows:
+
+.. math::
+    w_n = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{M - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{M - 1} \right)
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Blackman window.
+    >>> torch.signal.windows.blackman(5)
+    tensor([-1.4901e-08,  3.4000e-01,  1.0000e+00,  3.4000e-01, -1.4901e-08])
+
+    >>> # Generates a periodic Blackman window.
+    >>> torch.signal.windows.blackman(5, sym=False)
+    tensor([-1.4901e-08,  2.0077e-01,  8.4923e-01,  8.4923e-01,  2.0077e-01])
+""".format(**window_common_args),
+)
+def blackman(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("blackman", M, dtype, layout)
+
+    return general_cosine(
+        M,
+        a=[0.42, 0.5, 0.08],
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the Bartlett window.
+
+The Bartlett window is defined as follows:
+
+.. math::
+    w_n = 1 - \left| \frac{2n}{M - 1} - 1 \right| = \begin{cases}
+        \frac{2n}{M - 1} & \text{if } 0 \leq n \leq \frac{M - 1}{2} \\
+        2 - \frac{2n}{M - 1} & \text{if } \frac{M - 1}{2} < n < M \\ \end{cases}
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Bartlett window.
+    >>> torch.signal.windows.bartlett(10)
+    tensor([0.0000, 0.2222, 0.4444, 0.6667, 0.8889, 0.8889, 0.6667, 0.4444, 0.2222, 0.0000])
+
+    >>> # Generates a periodic Bartlett window.
+    >>> torch.signal.windows.bartlett(10, sym=False)
+    tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
+""".format(**window_common_args),
+)
+def bartlett(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("bartlett", M, dtype, layout)
+
+    if M == 0:
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if M == 1:
+        return torch.ones(
+            (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    start = -1
+    constant = 2 / (M if not sym else M - 1)
+
+    k = torch.linspace(
+        start=start,
+        end=start + (M - 1) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    return 1 - torch.abs(k)
+
+
+@_add_docstr(
+    r"""
+Computes the general cosine window.
+
+The general cosine window is defined as follows:
+
+.. math::
+    w_n = \sum^{M-1}_{i=0} (-1)^i a_i \cos{ \left( \frac{2 \pi i n}{M - 1}\right)}
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    a (Iterable): the coefficients associated to each of the cosine functions.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric general cosine window with 3 coefficients.
+    >>> torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31], sym=True)
+    tensor([0.5400, 0.3376, 0.1288, 0.4200, 0.9136, 0.9136, 0.4200, 0.1288, 0.3376, 0.5400])
+
+    >>> # Generates a periodic general cosine window with 2 coefficients.
+    >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False)
+    tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
+""".format(**window_common_args),
+)
+def general_cosine(
+    M,
+    *,
+    a: Iterable,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    _window_function_checks("general_cosine", M, dtype, layout)
+
+    if M == 0:
+        return torch.empty(
+            (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if M == 1:
+        return torch.ones(
+            (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
+        )
+
+    if not isinstance(a, Iterable):
+        raise TypeError("Coefficients must be a list/tuple")
+
+    if not a:
+        raise ValueError("Coefficients cannot be empty")
+
+    constant = 2 * torch.pi / (M if not sym else M - 1)
+
+    k = torch.linspace(
+        start=0,
+        end=(M - 1) * constant,
+        steps=M,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    a_i = torch.tensor(
+        [(-1) ** i * w for i, w in enumerate(a)],
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    i = torch.arange(
+        a_i.shape[0],
+        dtype=a_i.dtype,
+        device=a_i.device,
+        requires_grad=a_i.requires_grad,
+    )
+    return (a_i.unsqueeze(-1) * torch.cos(i.unsqueeze(-1) * k)).sum(0)
+
+
+@_add_docstr(
+    r"""
+Computes the general Hamming window.
+
+The general Hamming window is defined as follows:
+
+.. math::
+    w_n = \alpha - (1 - \alpha) \cos{ \left( \frac{2 \pi n}{M-1} \right)}
+    """,
+    r"""
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    alpha (float, optional): the window coefficient. Default: 0.54.
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Examples::
+
+    >>> # Generates a symmetric Hamming window with the general Hamming window.
+    >>> torch.signal.windows.general_hamming(10, sym=True)
+    tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800])
+
+    >>> # Generates a periodic Hann window with the general Hamming window.
+    >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False)
+    tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
+""".format(**window_common_args),
+)
+def general_hamming(
+    M,
+    *,
+    alpha: float = 0.54,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_cosine(
+        M,
+        a=[alpha, 1.0 - alpha],
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+@_add_docstr(
+    r"""
+Computes the minimum 4-term Blackman-Harris window according to Nuttall.
+
+.. math::
+    w_n = 1 - 0.36358 \cos{(z_n)} + 0.48917 \cos{(2z_n)} - 0.13659 \cos{(3z_n)} + 0.01064 \cos{(4z_n)}
+
+where :math:`z_n = \frac{2 \pi n}{M}`.
+    """,
+    """
+
+{normalization}
+
+Arguments:
+    {M}
+
+Keyword args:
+    {sym}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+References::
+
+    - A. Nuttall, "Some windows with very good sidelobe behavior,"
+      IEEE Transactions on Acoustics, Speech, and Signal Processing, vol. 29, no. 1, pp. 84-91,
+      Feb 1981. https://doi.org/10.1109/TASSP.1981.1163506
+
+    - Heinzel G. et al., "Spectrum and spectral density estimation by the Discrete Fourier transform (DFT),
+      including a comprehensive list of window functions and some new flat-top windows",
+      February 15, 2002 https://holometer.fnal.gov/GH_FFT.pdf
+
+Examples::
+
+    >>> # Generates a symmetric Nutall window.
+    >>> torch.signal.windows.general_hamming(5, sym=True)
+    tensor([3.6280e-04, 2.2698e-01, 1.0000e+00, 2.2698e-01, 3.6280e-04])
+
+    >>> # Generates a periodic Nuttall window.
+    >>> torch.signal.windows.general_hamming(5, sym=False)
+    tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01])
+""".format(**window_common_args),
+)
+def nuttall(
+    M: int,
+    *,
+    sym: bool = True,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout = torch.strided,
+    device: torch.device | None = None,
+    requires_grad: bool = False,
+) -> Tensor:
+    return general_cosine(
+        M,
+        a=[0.3635819, 0.4891775, 0.1365995, 0.0106411],
+        sym=sym,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a77fbe4ad7e31a107a3beb464ddb0801ab75954b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42217a3d9ae3300f096a484ca2a630a5e4759035
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..103807a49db53f9cf42989c0ff175ad4865396eb
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22efc2de8735759002e23bf661c2cfd25174b88c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/_triton_ops.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93afaa8b1f4dba02bca37023a0ee9374a907c447
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3616fede6ce67b70f244419236864a03ffcbb35
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py
@@ -0,0 +1,472 @@
+# mypy: ignore-errors
+
+import collections
+
+import torch
+from torch.testing._internal.common_utils import TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase
+
+
+class AutocastTestLists:
+    def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype):
+        input = (torch.randn((n, n), device=dev, dtype=torch.float32),)
+
+        hx = ((torch.randn((n, n), device=dev, dtype=torch.float32),
+               torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else
+              torch.randn((n, n), device=dev, dtype=torch.float32),)
+
+        weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_ih
+                   torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_hh
+                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32),  # bias_ih
+                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32))  # bias_hh
+
+        # returns args as a tuple
+        return input + hx + weights
+
+    # Supplies ops and arguments for test_autocast_* in test/test_cuda.py
+    def __init__(self, dev):
+        super().__init__()
+        n = 8
+        # Utility arguments, created as one-element tuples
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+        mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+        mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+
+        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
+        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                           torch.randn(dimset, dtype=torch.float32, device=dev))
+                          for dimset in dimsets]
+        bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
+        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+        # The lists below organize ops that autocast needs to test.
+        # self.list_name corresponds to test_autocast_list_name in test/test_cuda.py.
+        # Each op is associated with a tuple of valid arguments.
+        # In addition, cudnn conv ops are not supported on ROCm and hence will
+        # be skipped by passing TEST_WITH_ROCM flag to those ops in self.torch_fp16 list.
+
+        # Some ops implement built-in type promotion.  These don't need autocasting,
+        # but autocasting relies on their promotion, so we include tests to double-check.
+        self.torch_expect_builtin_promote = [
+            ("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
+            ("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
+        ]
+        self.methods_expect_builtin_promote = [
+            ("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+
+        # The remaining lists organize ops that autocast treats explicitly.
+        self.torch_fp16 = [
+            # deprecated _convolution
+            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
+                                                              (0, 0), 1, False, True, True)),
+            # the current  _convolution
+            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
+                                                              (0, 0), 1, False, True, True, True)),
+            ("conv1d", conv_args_fp32[0]),
+            ("conv2d", conv_args_fp32[1]),
+            ("conv3d", conv_args_fp32[2]),
+            ("conv_tbc", conv_args_fp32[0] + bias_fp32),
+            ("conv_transpose1d", conv_args_fp32[0]),
+            ("conv_transpose2d", conv_args_fp32[1]),
+            ("conv_transpose3d", conv_args_fp32[2]),
+            ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
+            ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
+            ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
+                                                                 (1, 1), 1, False, True, True), TEST_WITH_ROCM),
+            ("prelu", pointwise0_fp32 + element0_fp32),
+            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+            ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),
+            ("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            ("matmul", mat0_fp32 + mat1_fp32),
+            ("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32),
+            ("mm", mat0_fp32 + mat1_fp32),
+            ("mv", mat0_fp32 + pointwise0_fp32),
+            ("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
+            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            # _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell.
+            # ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            # ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            ("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)),
+            ("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)),
+            ("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
+            ("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
+        ]
+        self.torch_fp32 = [
+            ("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
+            ("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
+            ("cosh", pointwise0_fp16),
+            ("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)),
+            ("exp", pointwise0_fp16),
+            ("expm1", pointwise0_fp16),
+            ("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)),
+            ("reciprocal", pointwise0_fp16),
+            ("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)),
+            ("sinh", pointwise0_fp16),
+            ("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)),
+            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16),
+            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
+            # ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API.
+            ("softmax", pointwise0_fp16 + (0,)),
+            ("log_softmax", pointwise0_fp16 + (0,)),
+            ("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)),
+            ("group_norm", mat0_fp16 + (1,)),
+            ("norm", pointwise0_fp16),
+            ("norm", pointwise0_fp16, {"dim": 0}),
+            # these need magma
+            # ("norm", mat0_fp16, {"p": "nuc"}),
+            # ("norm", mat0_fp16, {"p": "nuc", "dim": 0}),
+            ("norm", pointwise0_fp16, {"p": 1}),
+            ("norm", pointwise0_fp16, {"p": 1, "dim": 0}),
+            ("cosine_similarity", mat0_fp16 + mat1_fp16),
+            ("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
+            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16),
+                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16),
+                                       torch.tensor([1], device=dev, dtype=torch.int))),
+            ("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)),
+            ("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
+            ("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)),
+            ("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16),
+            ("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
+            ("cumprod", pointwise0_fp16 + (0,)),
+            ("cumsum", pointwise0_fp16 + (0,)),
+            ("dist", pointwise0_fp16 + pointwise1_fp16),
+            ("pdist", mat0_fp16),
+            ("cdist", mat0_fp16 + mat1_fp16),
+            ("prod", pointwise0_fp16),
+            ("prod", pointwise0_fp16 + (0,)),
+            ("renorm", mat0_fp16 + (2, 0, 1.0)),
+            ("sum", pointwise0_fp16),
+            ("sum", mat0_fp16 + (1,)),
+            ("logsumexp", mat0_fp16 + (1,)),
+        ]
+        self.torch_need_autocast_promote = [
+            ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)),
+            ("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16),
+            ("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)),
+            ("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev),
+                          torch.randn((1, 2), dtype=torch.float32, device=dev),
+                          torch.randn((1, 2, 2), dtype=torch.float16, device=dev),
+                          torch.randn((1,), dtype=torch.float32, device=dev))),
+            ("cross", (torch.randn(3, dtype=torch.float32, device=dev),
+                       torch.randn(3, dtype=torch.float16, device=dev))),
+            ("dot", pointwise0_fp16 + pointwise1_fp32),
+            ("vdot", pointwise0_fp16 + pointwise1_fp32),
+            ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev),
+                              torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev),
+                              0, 0, False)),
+            ("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),),
+                                             torch.randn(1, device=dev, dtype=torch.float16))),
+            ("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),),
+                                             torch.randn(1, device=dev, dtype=torch.float32))),
+            ("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
+                           torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
+            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
+                             0,
+                             torch.randint(0, 2, (2, 2, 2), device=dev),
+                             torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
+            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
+                             0,
+                             torch.randint(0, 2, (2, 2, 2), device=dev),
+                             torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
+        ]
+        self.nn_fp16 = [
+            ("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
+        ]
+        self.nn_fp32 = [
+            ("softplus", pointwise0_fp16),
+            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
+                          torch.zeros((n,), device=dev, dtype=torch.long))),
+            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half),
+                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
+            ("l1_loss", mat0_fp16 + mat1_fp16),
+            ("smooth_l1_loss", mat0_fp16 + mat1_fp16),
+            ("mse_loss", mat0_fp16 + mat1_fp16),
+            ("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
+        ]
+        self.linalg_fp16 = [
+            ("linalg_vecdot", mat0_fp32 + mat0_fp32),
+            ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
+        ]
+        self.methods_fp16 = [
+            ("__matmul__", mat0_fp32 + mat1_fp32)
+        ]
+        self.methods_fp32 = [
+            ("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)),
+        ]
+        self.banned = [
+            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
+                                      torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
+        ]
+
+
+class AutocastCPUTestLists:
+    # Supplies ops and arguments for test_autocast_* in test/test_cpu.py
+    def __init__(self, dev):
+        super().__init__()
+        n = 8
+        # Utility arguments, created as one-element tuples
+        pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
+        pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
+        mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+
+        dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
+
+        dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
+                      for dimset in dummy_dimsets]
+
+        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
+        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                           torch.randn(dimset, dtype=torch.float32, device=dev))
+                          for dimset in dimsets]
+
+        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+        dummy_fp32 = [  # noqa: F841
+            (torch.randn(dimset, dtype=torch.float32, device=dev),)
+            for dimset in dummy_dimsets
+        ]
+        # The lists below organize ops that autocast needs to test.
+        # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
+        # Each op is associated with a tuple of valid arguments.
+
+        # Some ops implement built-in type promotion.  These don't need autocasting,
+        # but autocasting relies on their promotion, so we include tests to double-check.
+        self.torch_expect_builtin_promote = [
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+
+        self.methods_expect_builtin_promote = [
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+        # The remaining lists organize ops that autocast treats explicitly.
+        self.torch_16 = [
+            ("conv1d", conv_args_fp32[0]),
+            ("conv2d", conv_args_fp32[1]),
+            ("conv3d", conv_args_fp32[2]),
+            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("mm", mat0_fp32 + mat1_fp32),
+            ("matmul", mat0_fp32 + mat1_fp32),
+            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+            ("_addmm_activation", mat1_fp32 + mat2_fp32 + mat3_fp32, {"beta": 1, "alpha": 1, "use_gelu": True}),
+            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
+                          torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
+                          torch.randn(5, device=dev, dtype=torch.float32),
+                          0)),
+            ("conv_transpose1d", conv_args_fp32[0]),
+            ("conv_transpose2d", conv_args_fp32[1]),
+            ("conv_transpose3d", conv_args_fp32[2]),
+            ("prelu", pointwise0_fp32 + element0_fp32),
+            ("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((3 * n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n), device=dev, dtype=torch.float32))),
+        ]
+        self.torch_fp32 = [
+            ("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
+            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16),
+                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16),
+                                       torch.tensor([1], device=dev, dtype=torch.int))),
+            ("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
+            ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)),
+            ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
+            ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
+        ]
+        self.nn_16 = [
+            ("linear", mat0_fp32 + mat1_fp32, {}),
+        ]
+        self.nn_fp32 = [
+            ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
+            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
+                                     (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
+            ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
+            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),
+                          torch.zeros((n,), device=dev, dtype=torch.long))),
+            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16),
+                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
+            ("l1_loss", mat0_bf16 + mat1_bf16),
+            ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
+            ("mse_loss", mat0_bf16 + mat1_bf16),
+            ("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
+            ("huber_loss", mat0_bf16 + mat1_bf16),
+        ]
+        self.torch_need_autocast_promote = [
+            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+        ]
+
+
+class TestAutocast(TestCase):
+    def args_maybe_kwargs(self, op_with_args):
+        if len(op_with_args) == 2:
+            return op_with_args[0], op_with_args[1], {}
+        else:
+            return op_with_args[0], op_with_args[1], op_with_args[2]
+
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        device,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
+        # helper to cast args
+        def cast(val, to_type):
+            if isinstance(val, torch.Tensor):
+                return val.to(to_type) if val.is_floating_point() else val
+            elif isinstance(val, collections.abc.Iterable):
+                return type(val)(cast(v, to_type) for v in val)
+            else:
+                return val
+
+        if add_kwargs is None:
+            add_kwargs = {}
+
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
+        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+
+            out_type = out_type if out_type is not None else run_as_type
+            output = output_method = None
+
+            # Try module.* variant, if requested:
+            if module is not None and hasattr(module, op):
+                output = getattr(module, op)(*args, **add_kwargs)
+                if isinstance(output, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output.dtype,
+                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+                    )
+            # Try Tensor.* variant:
+            if hasattr(torch.Tensor, op):
+                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
+                if isinstance(output_method, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output_method.dtype,
+                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+                    )
+
+            self.assertTrue(
+                (output is not None) or (output_method is not None),
+                f"{op} not found as an attribute on either Tensor or the requested module {module}",
+            )
+
+            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
+            # For example, lstm_cell returns a tuple and equal returns bool.
+            def compare(first, second):
+                if isinstance(first, torch.Tensor):
+                    return torch.equal(first, second)
+                elif isinstance(first, collections.abc.Iterable):
+                    return all(compare(f, s) for f, s in zip(first, second, strict=False))
+                else:
+                    return first == second
+
+            # If both torch.* and Tensor.* variants were found, check outputs are identical
+            if (output is not None) and (output_method is not None):
+                self.assertTrue(type(output) is type(output_method))
+                comparison = compare(output, output_method)
+                self.assertTrue(
+                    comparison, f"torch.{op} result did not match Tensor.{op} result"
+                )
+
+            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
+            # as the C++-side autocasting, and should be bitwise accurate.
+            output_to_compare = output if output is not None else output_method
+            with torch.amp.autocast(device_type=device, enabled=False):
+                self.assertFalse(
+                    torch.is_autocast_enabled(device_type=device)
+                )
+
+                if module is not None and hasattr(module, op):
+                    control = getattr(module, op)(
+                        *cast(args, run_as_type), **add_kwargs
+                    )
+                else:
+                    control = getattr(args[0].to(run_as_type), op)(
+                        *cast(args[1:], run_as_type), **add_kwargs
+                    )
+                self.assertTrue(type(output_to_compare) is type(control))
+                comparison = compare(output_to_compare, control)
+                self.assertTrue(comparison, f"torch.{op} result did not match control")
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2219ef4ea56aa306dfdd3af18b7403af8384c78
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py
@@ -0,0 +1,164 @@
+# mypy: ignore-errors
+
+import os
+import re
+import sys
+
+__all__ = [
+    "check_code_for_cuda_kernel_launches",
+    "check_cuda_kernel_launches",
+]
+
+# FILES TO EXCLUDE (match is done with suffix using `endswith`)
+# You wouldn't drive without a seatbelt, though, so why would you
+# launch a kernel without some safety? Use this as a quick workaround
+# for a problem with the checker, fix the checker, then de-exclude
+# the files in question.
+exclude_files: list[str] = []
+
+# Without using a C++ AST we can't 100% detect kernel launches, so we
+# model them as having the pattern "<<>>(arguments);"
+# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
+# the next statement.
+#
+# We model the next statement as ending at the next `}` or `;`.
+# If we see `}` then a clause ended (bad) if we see a semi-colon then
+# we expect the launch check just before it.
+#
+# Since the kernel launch can include lambda statements, it's important
+# to find the correct end-paren of the kernel launch. Doing this with
+# pure regex requires recursive regex, which aren't part of the Python
+# standard library. To avoid an additional dependency, we build a prefix
+# regex that finds the start of a kernel launch, use a paren-matching
+# algorithm to find the end of the launch, and then another regex to
+# determine if a launch check is present.
+
+# Finds potential starts of kernel launches
+kernel_launch_start = re.compile(
+    r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
+)
+
+# This pattern should start at the character after the final paren of the
+# kernel launch. It returns a match if the launch check is not the next statement
+has_check = re.compile(
+    r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
+)
+
+def find_matching_paren(s: str, startpos: int) -> int:
+    """Given a string "prefix (unknown number of characters) suffix"
+    and the position of the first `(` returns the index of the character
+    1 past the `)`, accounting for paren nesting
+    """
+    opening = 0
+    for i, c in enumerate(s[startpos:]):
+        if c == '(':
+            opening += 1
+        elif c == ')':
+            opening -= 1
+            if opening == 0:
+                return startpos + i + 1
+
+    raise IndexError("Closing parens not found!")
+
+
+def should_exclude_file(filename) -> bool:
+    for exclude_suffix in exclude_files:
+        if filename.endswith(exclude_suffix):
+            return True
+    return False
+
+
+def check_code_for_cuda_kernel_launches(code, filename=None):
+    """Checks code for CUDA kernel launches without cuda error checks.
+
+    Args:
+        filename - Filename of file containing the code. Used only for display
+                   purposes, so you can put anything here.
+        code     - The code to check
+
+    Returns:
+        The number of unsafe kernel launches in the code
+    """
+    if filename is None:
+        filename = "##Python Function Call##"
+
+    # We break the code apart and put it back together to add
+    # helpful line numberings for identifying problem areas
+    code = enumerate(code.split("\n"))                             # Split by line breaks
+    code = [f"{lineno}: {linecode}" for lineno, linecode in code]  # Number the lines
+    code = '\n'.join(code)                                         # Put it back together
+
+    num_launches_without_checks = 0
+    for m in kernel_launch_start.finditer(code):
+        end_paren = find_matching_paren(code, m.end() - 1)
+        if has_check.match(code, end_paren):
+            num_launches_without_checks += 1
+            context = code[m.start():end_paren + 1]
+            print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
+
+    return num_launches_without_checks
+
+
+def check_file(filename):
+    """Checks a file for CUDA kernel launches without cuda error checks
+
+    Args:
+        filename - File to check
+
+    Returns:
+        The number of unsafe kernel launches in the file
+    """
+    if not (filename.endswith((".cu", ".cuh"))):
+        return 0
+    if should_exclude_file(filename):
+        return 0
+    with open(filename) as f:
+        contents = f.read()
+        unsafeCount = check_code_for_cuda_kernel_launches(contents, filename)
+    return unsafeCount
+
+
+def check_cuda_kernel_launches():
+    """Checks all pytorch code for CUDA kernel launches without cuda error checks
+
+    Returns:
+        The number of unsafe kernel launches in the codebase
+    """
+    torch_dir = os.path.dirname(os.path.realpath(__file__))
+    torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
+    torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2
+
+    kernels_without_checks = 0
+    files_without_checks = []
+    for root, dirnames, filenames in os.walk(torch_dir):
+        # `$BASE/build` and `$BASE/torch/include` are generated
+        # so we don't want to flag their contents
+        if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"):
+            # Curtail search by modifying dirnames and filenames in place
+            # Yes, this is the way to do this, see `help(os.walk)`
+            dirnames[:] = []
+            continue
+
+        for x in filenames:
+            filename = os.path.join(root, x)
+            file_result = check_file(filename)
+            if file_result > 0:
+                kernels_without_checks += file_result
+                files_without_checks.append(filename)
+
+    if kernels_without_checks > 0:
+        count_str = f"Found {kernels_without_checks} instances in " \
+                    f"{len(files_without_checks)} files where kernel " \
+                    "launches didn't have checks."
+        print(count_str, file=sys.stderr)
+        print("Files without checks:", file=sys.stderr)
+        for x in files_without_checks:
+            print(f"\t{x}", file=sys.stderr)
+        print(count_str, file=sys.stderr)
+
+    return kernels_without_checks
+
+
+if __name__ == "__main__":
+    unsafe_launches = check_cuda_kernel_launches()
+    sys.exit(0 if unsafe_launches == 0 else 1)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6797cb9a2cfee235d6dce9c7e1bb11ad49ca1d9d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/codegen/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd14b85a21915ddf8ab415f3bf5dc6e79db14dfc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py
@@ -0,0 +1,112 @@
+# mypy: ignore-errors
+
+# Owner(s): ["oncall: distributed"]
+
+
+import torch
+import torch.nn as nn
+
+
+class UnitModule(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l1 = nn.Linear(100, 100, device=device)
+        self.seq = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(100, 100, device=device),
+            nn.ReLU(),
+        )
+        self.l2 = nn.Linear(100, 100, device=device)
+
+    def forward(self, x):
+        return self.l2(self.seq(self.l1(x)))
+
+
+class CompositeModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l1 = nn.Linear(100, 100, device=device)
+        self.u1 = UnitModule(device)
+        self.u2 = UnitModule(device)
+        self.l2 = nn.Linear(100, 100, device=device)
+
+    def forward(self, x):
+        return self.l2(self.u2(self.u1(self.l1(x))))
+
+
+class UnitParamModule(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l = nn.Linear(100, 100, device=device)
+        self.seq = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(100, 100, device=device),
+            nn.ReLU(),
+        )
+        self.p = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x):
+        return torch.mm(self.seq(self.l(x)), self.p)
+
+
+class CompositeParamModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l = nn.Linear(100, 100, device=device)
+        self.u1 = UnitModule(device)
+        self.u2 = UnitModule(device)
+        self.p = nn.Parameter(torch.randn((100, 100), device=device))
+        self.register_buffer(
+            "buffer", torch.randn((100, 100), device=device), persistent=True
+        )
+
+    def forward(self, x):
+        a = self.u2(self.u1(self.l(x)))
+        b = self.p
+        return torch.mm(a, b)
+
+
+class FakeSequential(nn.Module):
+    # Define this class to achieve a desired nested wrapping using the module
+    # wrap policy with `nn.Sequential`
+    def __init__(self, *modules: tuple[nn.Module, ...]) -> None:
+        super().__init__()
+        self._module_sequence = list(modules)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for module in self._module_sequence:
+            x = module(x)
+        return x
+
+
+class NestedSequentialModel(nn.Module):
+    def __init__(self, device: torch.device) -> None:
+        super().__init__()
+        # This nested structure exercises traversal order to catch differences
+        # between valid traversals (e.g. BFS and DFS variations).
+        self.seq1 = nn.Sequential(
+            nn.Linear(1, 1, device=device),
+            FakeSequential(
+                nn.Linear(1, 1, device=device),
+                nn.ReLU(),
+                FakeSequential(
+                    nn.Linear(1, 1, device=device),
+                ),
+                nn.ReLU(),
+            ),
+            nn.Linear(1, 2, device=device),
+        )
+        self.lin = nn.Linear(2, 2, device=device)
+        self.seq2 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(2, 3, device=device),
+            FakeSequential(
+                nn.Linear(3, 2, bias=False, device=device),
+                nn.Linear(2, 4, bias=False, device=device),
+            ),
+        )
+
+        # FIXME(rec): forward() is not a method, it's a local function inside __init__
+        # that is never used. It should probabkly be outdented by four spaces, or removed.
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.seq2(self.lin(self.seq1(x)))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df79fa00f81b92492fcd6f23a99f595695b8421
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py
@@ -0,0 +1,1958 @@
+# mypy: ignore-errors
+
+import faulthandler
+import functools
+import itertools
+import logging
+import multiprocessing
+import operator
+import os
+import queue
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+from collections.abc import Callable
+from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from enum import Enum
+from functools import partial, reduce, wraps
+from io import StringIO
+from typing import Any, NamedTuple, Optional, Union
+from unittest.mock import patch
+
+import torch
+import torch._dynamo.test_case
+import torch.cuda.nccl
+import torch.distributed as c10d
+import torch.nn as nn
+from torch._C._autograd import DeviceType
+from torch._C._distributed_c10d import _SymmetricMemory
+from torch._logging._internal import trace_log
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    find_free_port,
+    IS_SANDCASTLE,
+    LazyVal,
+    retry_on_connect_failures,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_WITH_ROCM,
+    TEST_WITH_TSAN,
+    TEST_XPU,
+    TestCase,
+)
+from torch.testing._internal.distributed.multi_threaded_pg import (
+    _install_threaded_pg,
+    _uninstall_threaded_pg,
+    ProcessLocalGroup,
+)
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
+DDP_RANK_DEVICES = ["cuda", "xpu"]
+HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
+
+
+class TestSkip(NamedTuple):
+    exit_code: int
+    message: str
+
+
+TEST_SKIPS = {
+    "backend_unavailable": TestSkip(
+        72, "Skipped because distributed backend is not available."
+    ),
+    "small_worldsize": TestSkip(73, "Skipped due to small world size."),
+    "odd_worldsize": TestSkip(87, "Skipped due to odd world size."),
+    "no_cuda": TestSkip(74, "CUDA is not available."),
+    "multi-gpu-1": TestSkip(75, "Need at least 1 CUDA device"),
+    "multi-gpu-2": TestSkip(77, "Need at least 2 CUDA devices"),
+    "multi-gpu-3": TestSkip(80, "Need at least 3 CUDA devices"),
+    "multi-gpu-4": TestSkip(81, "Need at least 4 CUDA devices"),
+    "multi-gpu-5": TestSkip(82, "Need at least 5 CUDA devices"),
+    "multi-gpu-6": TestSkip(83, "Need at least 6 CUDA devices"),
+    "multi-gpu-7": TestSkip(84, "Need at least 7 CUDA devices"),
+    "multi-gpu-8": TestSkip(85, "Need at least 8 CUDA devices"),
+    "nccl": TestSkip(76, "c10d not compiled with NCCL support"),
+    "skipIfRocm": TestSkip(78, "Test skipped for ROCm"),
+    "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"),
+    "generic": TestSkip(
+        86, "Test skipped at subprocess level, look at subprocess log for skip reason"
+    ),
+    "importerror": TestSkip(88, "Test skipped due to missing import"),
+    "no_accelerator": TestSkip(89, "accelerator is not available."),
+}
+
+
+@dataclass
+class DistTestCases:
+    # Backends that do not support a specific collective
+    skip_collective = {}
+    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"}
+    skip_collective["reduce"] = set()
+    skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"}
+    skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"}
+
+    # Sets showing that something is implemented
+    backend_feature = {}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
+    backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
+    backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
+    backend_feature["plugin"] = set()
+    if TEST_HPU:
+        backend_feature["hpu"] = {"hccl"}
+    if TEST_XPU:
+        backend_feature["xpu"] = {"xccl"}
+
+
+def requires_ddp_rank(device):
+    return device in DDP_RANK_DEVICES
+
+
+def skip_if_no_gpu(func):
+    """Skips if the world size exceeds the number of GPUs, ensuring that if the
+    test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not (TEST_CUDA or TEST_HPU or TEST_XPU):
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+        world_size = int(os.environ["WORLD_SIZE"])
+        if TEST_CUDA and torch.cuda.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_HPU and torch.hpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_XPU and torch.xpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+# TODO (kwen2501): what is the purpose of this decorator?  Tests with this
+# decorator were always skipped. So they may be outdated already.
+# Oct 2024: bumping the small-world criteria to < 8, as we are increasing the
+# number of GPUs in CI from 2 to 4, and we need to continue skipping those tests
+# to keep CI green. But this is just a temporary solution. We should clean up
+# those tests somehow.
+def skip_if_small_worldsize(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) < 8:
+            sys.exit(TEST_SKIPS["small_worldsize"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def skip_if_odd_worldsize(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) % 2 == 1:
+            sys.exit(TEST_SKIPS["odd_worldsize"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_n_gpus_for_nccl_backend(n, backend):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend == "nccl" and torch.cuda.device_count() < n:
+                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def import_transformers_or_skip():
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
+
+                return func(*args, **kwargs)
+            except ImportError:
+                sys.exit(TEST_SKIPS["importerror"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def at_least_x_gpu(x):
+    if TEST_CUDA and torch.cuda.device_count() >= x:
+        return True
+    if TEST_HPU and torch.hpu.device_count() >= x:
+        return True
+    if TEST_XPU and torch.xpu.device_count() >= x:
+        return True
+    return False
+
+
+def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
+    _handle_test_skip = getattr(args[0], "_handle_test_skip", None)
+    if len(args) == 0 or _handle_test_skip is None:
+        return False
+    _handle_test_skip(msg)
+    return True
+
+
+def skip_if_lt_x_gpu(x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_HPU and torch.hpu.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_XPU and torch.xpu.device_count() >= x:
+                return func(*args, **kwargs)
+            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
+            if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
+                sys.exit(test_skip.exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def requires_world_size(n: int):
+    """
+    Decorator to request a specific world size for a test. The test harness can
+    read this attribute to set the number of ranks to spawn. If there are fewer
+    than `n` CUDA devices available, the test should be skipped by the harness.
+
+    Usage:
+        @require_world_size(3)
+        def test_something(self):
+            ...
+    """
+
+    def decorator(func):
+        func._required_world_size = n
+        available = torch.cuda.device_count()
+        return unittest.skipUnless(
+            available >= n, f"requires {n} GPUs, found {available}"
+        )(func)
+
+    return decorator
+
+
+def get_required_world_size(obj: Any, default: int) -> int:
+    """
+    Returns the requested world size for the currently running unittest method on `obj`
+    if annotated via `@require_world_size(n)`, else returns `default`.
+    """
+    try:
+        # Try MultiProcessTestCase helper first, then unittest fallback
+        test_name = (
+            obj._current_test_name()  # type: ignore[attr-defined]
+            if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
+            else obj._testMethodName
+        )
+        fn = getattr(obj, test_name)
+        value = fn._required_world_size
+        return int(value)
+    except Exception:
+        return default
+
+
+# This decorator helps avoiding initializing cuda while testing other backends
+def nccl_skip_if_lt_x_gpu(backend, x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend != "nccl":
+                return func(*args, **kwargs)
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
+            if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
+                sys.exit(test_skip.exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def verify_ddp_error_logged(model_DDP, err_substr):
+    # Verify error was logged in ddp_logging_data.
+    ddp_logging_data = model_DDP._get_ddp_logging_data()
+    assert "iteration" in ddp_logging_data
+    assert "has_error" in ddp_logging_data
+    assert "error" in ddp_logging_data
+    logging_err = ddp_logging_data["error"]
+    # Remove C++ stacktrace if needed.
+    actual = (
+        err_substr
+        if err_substr.find("\nException raised from ") == -1
+        else err_substr.split("\nException raised from ")[0]
+    )
+    assert actual in logging_err, (
+        f"Did not find expected {actual} in ddp logging data error: {logging_err}"
+    )
+
+
+def with_nccl_blocking_wait(func):
+    """
+    Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of
+    this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for
+    the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and
+    TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING
+        try:
+            cached_nccl_async_error_handling: Union[str, None] = os.environ[
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+            ]
+            del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
+        except KeyError:
+            # TORCH_NCCL_ASYNC_ERROR_HANDLING was unset
+            cached_nccl_async_error_handling = None
+
+        # Save val of TORCH_NCCL_BLOCKING_WAIT and set it.
+        try:
+            cached_nccl_blocking_wait: Union[str, None] = os.environ[
+                "TORCH_NCCL_BLOCKING_WAIT"
+            ]
+        except KeyError:
+            cached_nccl_blocking_wait = None
+        finally:
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
+
+        try:
+            ret = func(*args, **kwargs)
+            return ret
+        finally:
+            # restore old values.
+            if cached_nccl_async_error_handling is not None:
+                os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                    cached_nccl_async_error_handling
+                )
+
+            if cached_nccl_blocking_wait is not None:
+                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
+
+    return wrapper
+
+
+def with_dist_debug_levels(levels):
+    """
+    Runs a test for each distributed debug level specified in levels.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            old_level = os.environ.get("TORCH_DISTRIBUTED_DEBUG", None)
+            for level in levels:
+                os.environ["TORCH_DISTRIBUTED_DEBUG"] = level
+                c10d.set_debug_level_from_env()
+                ret = func(*args, **kwargs)
+                c10d.barrier()
+                if old_level is not None:
+                    os.environ["TORCH_DISTRIBUTED_DEBUG"] = old_level
+            # Only returns test return for last test, but since these are
+            # unittests the return value is not really used and earlier tests
+            # would've raised had they failed.
+            return ret
+
+        return wrapper
+
+    return decorator
+
+
+def requires_gloo():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_gloo_available(),
+        "c10d was not compiled with the Gloo backend",
+    )
+
+
+def requires_nccl_version(version, msg):
+    if not TEST_CUDA:
+        return lambda f: f
+    if not c10d.is_nccl_available():
+        return skip_but_pass_in_sandcastle(
+            "c10d was not compiled with the NCCL backend",
+        )
+    else:
+        return skip_but_pass_in_sandcastle_if(
+            torch.cuda.nccl.version() < version,
+            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
+        )
+
+
+def requires_nccl_shrink():
+    """
+    Require NCCL shrink support (NCCL available and version >= 2.27).
+    """
+    return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
+
+
+def requires_nccl():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_nccl_available(),
+        "c10d was not compiled with the NCCL backend",
+    )
+
+
+def requires_ucc():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_ucc_available(),
+        "c10d was not compiled with the UCC backend",
+    )
+
+
+def requires_mpi():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_mpi_available(),
+        "c10d was not compiled with the MPI backend",
+    )
+
+
+def requires_accelerator_dist_backend(backends=None):
+    """
+    Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
+
+    Args:
+        backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
+                                       If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
+
+    Returns:
+        callable: A decorator that skips the test if no specified accelerator backend is available.
+    """
+    if backends is None:
+        backends = ACCELERATOR_DIST_BACKENDS
+
+    backend_available = any(
+        {
+            "nccl": c10d.is_nccl_available,
+            "xccl": c10d.is_xccl_available,
+            "hccl": lambda: TEST_HPU,
+        }.get(backend, lambda: False)()
+        for backend in backends
+    )
+
+    return skip_but_pass_in_sandcastle_if(
+        not backend_available,
+        f"No accelerator communication backend available among {backends}",
+    )
+
+
+def requires_multicast_support():
+    has_multicast_support = (
+        torch.cuda.is_available()
+        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
+    )
+    return skip_but_pass_in_sandcastle_if(
+        not has_multicast_support,
+        "multicast support is not available",
+    )
+
+
+def evaluate_platform_supports_symm_mem():
+    if TEST_CUDA:
+        if TEST_WITH_ROCM:
+            arch_list = ["gfx942", "gfx950"]
+            for arch in arch_list:
+                if arch in torch.cuda.get_device_properties(0).gcnArchName:
+                    return True
+            return False
+        else:
+            return True
+    else:
+        return False
+
+
+PLATFORM_SUPPORTS_SYMM_MEM: bool = LazyVal(
+    lambda: evaluate_platform_supports_symm_mem()
+)
+
+
+def skip_if_rocm_multiprocess(func):
+    """Skips a test for ROCm multiprocess UTs"""
+    return unittest.skipIf(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func)
+
+
+def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
+    """Skips a test for given ROCm archs - multiprocess UTs"""
+
+    def decorator(func):
+        reason = None
+        if TEST_WITH_ROCM:
+            prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+            if prop in arch:
+                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
+
+        return unittest.skipIf(reason is not None, reason)(func)
+
+    return decorator
+
+
+def skip_if_rocm_ver_lessthan_multiprocess(version=None):
+    """Skips a test for ROCm based on ROCm ver - multiprocess UTs"""
+
+    def decorator(func):
+        reason = None
+        if TEST_WITH_ROCM:
+            rocm_version = str(torch.version.hip)
+            rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
+            rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+            if (
+                rocm_version_tuple is None
+                or version is None
+                or rocm_version_tuple < tuple(version)
+            ):
+                reason = f"skip_if_rocm_ver_lessthan_multiprocess: ROCm {rocm_version_tuple} is available but {version} required"
+
+        return unittest.skipIf(reason is not None, reason)(func)
+
+    return decorator
+
+
+def skip_if_win32():
+    return skip_but_pass_in_sandcastle_if(
+        sys.platform == "win32",
+        "This unit test case is not supported on Windows platform",
+    )
+
+
+def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
+    """
+    Returns True if the device's compute capability is (major, minor) or higher.
+    Error out if the device is not a CUDA device.
+    Returns False if device is a RoCM device.
+    Returns True if device is a non-CUDA device.
+    """
+    if device.type != "cuda":
+        return True
+
+    if torch.version.hip is not None:
+        # ROCm devices may have different compute capability codes
+        return False
+
+    return torch.cuda.get_device_capability(device) >= (major, minor)
+
+
+@retry_on_connect_failures
+def create_tcp_store(
+    addr="localhost",
+    world_size=1,
+    is_master=True,
+    timeout=timedelta(minutes=5),
+    wait_for_workers=True,
+    jit_class=False,
+    use_libuv=True,
+):
+    """
+    Creates a TCP store. Retries if the chosen port is already in use.
+    """
+    port = find_free_port()
+    if jit_class:
+        timeout_millisecond = int(timeout / timedelta(milliseconds=1))
+        return torch.classes.dist_c10d.TCPStore(
+            addr, port, world_size, is_master, timeout_millisecond
+        )
+    else:
+        return c10d.TCPStore(
+            addr,
+            port,
+            world_size,
+            is_master,
+            wait_for_workers=wait_for_workers,
+            use_libuv=use_libuv,
+        )
+
+
+if TEST_WITH_TSAN:
+    # TSAN runs much slower.
+    TIMEOUT_DEFAULT = 500
+else:
+    TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
+TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}
+
+
+# https://github.com/pytorch/pytorch/issues/75665
+if TEST_WITH_ROCM:
+    TIMEOUT_OVERRIDE["test_join_kwargs"] = 200
+
+
+def create_device(interface=None, lazy_init: bool = False):
+    if sys.platform == "win32" or interface is None:
+        return c10d.ProcessGroupGloo.create_device(
+            hostname="127.0.0.1", lazy_init=lazy_init
+        )
+    else:
+        return c10d.ProcessGroupGloo.create_device(
+            interface=interface, lazy_init=lazy_init
+        )
+
+
+def get_timeout(test_id) -> int:
+    return TIMEOUT_OVERRIDE.get(test_id.split(".")[-1], TIMEOUT_DEFAULT)
+
+
+@contextmanager
+def captured_output():
+    new_out, new_err = StringIO(), StringIO()
+    old_out, old_err = sys.stdout, sys.stderr
+    try:
+        sys.stdout, sys.stderr = new_out, new_err
+        yield sys.stdout, sys.stderr
+    finally:
+        sys.stdout, sys.stderr = old_out, old_err
+
+
+def simple_sparse_reduce_tests(rank: int, world_size: int, num_inputs: int = 1):
+    """
+    Generate a number of basic test cases for sparse reduction.
+    These cover tensors with a varying number of sparse dimensions and a varying
+    number of dense dimensions. The only reduction operation we support is sum.
+    """
+
+    def generate(rank: int, world_size: int, sparse_dims: int = 1, dense_dims: int = 0):
+        # First sparse dimension is [0..rank].
+        # Subsequent dimensions are always 0, so we know there is
+        # a non-empty intersection between any two sparse tensors.
+        indices = torch.reshape(torch.arange(rank + 1), (1, rank + 1))
+        shape = [world_size] + [2 for _ in range(dense_dims)]
+        for _ in range(sparse_dims - 1):
+            indices = torch.cat((indices, torch.zeros(1, rank + 1)))
+            shape.append(world_size)
+        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
+        return torch.sparse_coo_tensor(indices, values, shape)
+
+    def compute_sum(fn, world_size: int):
+        return reduce(
+            operator.add, [fn(rank, world_size) for rank in range(world_size)]
+        )
+
+    return [
+        (
+            [
+                fn(num_inputs * rank + i, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+            [compute_sum(fn, num_inputs * world_size) for i in range(num_inputs)],
+        )
+        for fn in [
+            partial(generate, sparse_dims=1),
+            partial(generate, sparse_dims=2),
+            partial(generate, sparse_dims=3),
+            partial(generate, dense_dims=1),
+            partial(generate, dense_dims=2),
+            partial(generate, dense_dims=3),
+        ]
+    ]
+
+
+# HELPER FOR MULTIGPU TESTS
+def init_multigpu_helper(world_size: int, backend: str):
+    """Multigpu tests are designed to simulate the multi nodes with multi
+    GPUs on each node. Nccl backend requires equal #GPUs in each process.
+    On a single node, all visible GPUs are evenly
+    divided to subsets, each process only uses a subset.
+    """
+    nGPUs = torch.cuda.device_count()
+    if TEST_HPU:
+        nGPUs = torch.hpu.device_count()
+    if TEST_XPU:
+        nGPUs = torch.xpu.device_count()
+    visible_devices = range(nGPUs)
+
+    # If rank is less than or equal to number of available GPU's
+    # then each rank can be mapped to corresponding GPU.
+    nGPUs_per_process = 1
+    if world_size > nGPUs:
+        nGPUs_per_process = nGPUs // world_size
+    rank_to_GPU = {
+        i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
+        for i in range(world_size)
+    }
+    return rank_to_GPU
+
+
+tmp_dir: Optional[tempfile.TemporaryDirectory] = None
+
+
+def initialize_temp_directories(init_method: Optional[str] = None) -> None:
+    global tmp_dir
+    tmp_dir = tempfile.TemporaryDirectory()
+    os.environ["TEMP_DIR"] = tmp_dir.name
+    os.mkdir(os.path.join(tmp_dir.name, "barrier"))
+    os.mkdir(os.path.join(tmp_dir.name, "test_dir"))
+    init_dir_path = os.path.join(tmp_dir.name, "init_dir")
+    os.mkdir(init_dir_path)
+    # Set init method if specified.
+    if init_method is not None:
+        os.environ["INIT_METHOD"] = init_method
+    else:
+        os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join(
+            init_dir_path, "shared_init_file"
+        )
+
+
+def cleanup_temp_dir() -> None:
+    if tmp_dir is not None:
+        tmp_dir.cleanup()
+
+
+# Most tests operate with this worldsize
+DEFAULT_WORLD_SIZE = 4
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
+# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
+# example which inherits from this class. Its `Setup()` methods calls into
+# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
+# subprocesses. During the spawn, the main process passes the test name to
+# subprocesses, and the name is acquired from self.id(). The subprocesses
+# then use the provided test function name to retrieve the function attribute
+# from the test instance and run it. The main process simply waits for all
+# subprocesses to join.
+
+
+class MultiProcessTestCase(TestCase):
+    MAIN_PROCESS_RANK = -1
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. There are certain tests that might use sys.exit() to
+    # simulate failures and in those cases, we can't have an exit code of 0,
+    # but we still want to ensure we didn't run into any other errors.
+    TEST_ERROR_EXIT_CODE = 10
+
+    # do not early terminate for distributed tests.
+    def _should_stop_test_suite(self) -> bool:
+        return False
+
+    # Many test cases init a process group but do not destroy it.  This property
+    # determines whether this base test class should call
+    # `destroy_process_group` on behalf of the test. Its value is customizable
+    # by derived TestCase's but it is a pan-TestCase value (cannot be customized
+    # for each test).
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        return True
+
+    @property
+    def world_size(self) -> int:
+        return DEFAULT_WORLD_SIZE
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        # Used for tests that are expected to return a non-0 exit code, such as
+        # SIGABRT thrown by watchdog.
+        self.special_return_code_checks: dict = {}
+
+        # Used for tests that may return any exit code, which makes it hard to
+        # check. This is rare, use with caution.
+        self.skip_return_code_checks: list = []
+
+        self.processes = []  # type: ignore[var-annotated]
+        self.rank = self.MAIN_PROCESS_RANK
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            self.file_name = f.name
+        # pid to pipe consisting of error message from process.
+        self.pid_to_pipe = {}  # type: ignore[var-annotated]
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        for p in self.processes:
+            p.terminate()
+        # Each Process instance holds a few open file descriptors. The unittest
+        # runner creates a new TestCase instance for each test method and keeps
+        # it alive until the end of the entire suite. We must thus reset the
+        # processes to prevent an effective file descriptor leak.
+        self.processes = []
+
+    def _current_test_name(self) -> str:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def _start_processes(self, proc) -> None:
+        self.processes = []
+        for rank in range(int(self.world_size)):
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            process = proc(
+                target=self.__class__._run,
+                name="process " + str(rank),
+                args=(
+                    rank,
+                    self._current_test_name(),
+                    self.file_name,
+                    child_conn,
+                ),
+                kwargs={
+                    "fake_pg": getattr(self, "fake_pg", False),
+                },
+            )
+            process.start()
+            logger.info("Started process %s with pid %s", rank, process.pid)
+            self.pid_to_pipe[process.pid] = parent_conn
+            self.processes.append(process)
+
+    def _spawn_processes(self) -> None:
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            pass
+
+        proc = torch.multiprocessing.get_context("spawn").Process
+        self._start_processes(proc)
+
+    class Event(Enum):
+        GET_TRACEBACK = 1
+
+    @staticmethod
+    def _event_listener(parent_pipe, signal_pipe, rank: int):
+        logger.debug("Starting event listener thread for rank %s", rank)
+        while True:
+            ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
+
+            if parent_pipe in ready_pipes:
+                if parent_pipe.closed:
+                    logger.debug(
+                        "Pipe closed for process %s, stopping event listener thread",
+                        rank,
+                    )
+                    return
+
+                event = parent_pipe.recv()
+                logger.info("Received event %s on process %s", event, rank)
+
+                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
+                    # Return traceback to the parent process.
+                    with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
+                        faulthandler.dump_traceback(tmp_file)
+                        # Flush buffers and seek to read from the beginning
+                        tmp_file.flush()
+                        tmp_file.seek(0)
+                        parent_pipe.send(tmp_file.read())
+
+                        logger.info("Process %s sent traceback", rank)
+
+            if signal_pipe in ready_pipes:
+                return
+
+    @classmethod
+    def _run(
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+    ) -> None:
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+        if sys.platform != "win32" and sys.platform != "darwin":
+            # Register signal handler to dump stack traces on FATALs.
+            # Windows and MacOS do not support the signal handlers.
+            torch._C._set_print_stack_traces_on_fatal_signal(True)
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+        common_utils.set_rng_seed()
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        try:
+            getattr(self, test_name)()
+        except unittest.SkipTest as se:
+            logger.info(  # noqa: G200
+                "Process %s skipping test %s for following reason: %s",
+                self.rank,
+                test_name,
+                str(se),
+            )
+            sys.exit(TEST_SKIPS["generic"].exit_code)
+        except Exception:
+            logger.error(
+                "Caught exception: \n%s exiting process %s with exit code: %s",
+                traceback.format_exc(),
+                self.rank,
+                MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
+            )
+            # Send error to parent process.
+            parent_pipe.send(traceback.format_exc())
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+        if self.destroy_pg_upon_exit:
+            try:
+                # Some tests do destroy the pgs, and destroy can't be called twice.
+                # This avoids spewing warnings about improperly shutting down.
+                c10d.destroy_process_group()
+            except (AssertionError, ValueError):
+                pass
+
+    def _get_timedout_process_traceback(self) -> None:
+        pipes = []
+        for i, process in enumerate(self.processes):
+            if process.exitcode is None:
+                pipe = self.pid_to_pipe[process.pid]
+                try:
+                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
+                    pipes.append((i, pipe))
+                except ConnectionError:
+                    logger.exception(
+                        "Encountered error while trying to get traceback for process %s",
+                        i,
+                    )
+
+        # Wait for results.
+        for rank, pipe in pipes:
+            try:
+                # Wait for traceback
+                if pipe.poll(5):
+                    if pipe.closed:
+                        logger.info(
+                            "Pipe closed for process %s, cannot retrieve traceback",
+                            rank,
+                        )
+                        continue
+
+                    traceback = pipe.recv()
+                    logger.error(
+                        "Process %s timed out with traceback: \n\n%s", rank, traceback
+                    )
+                else:
+                    logger.error(
+                        "Could not retrieve traceback for timed out process: %s", rank
+                    )
+            except ConnectionError:
+                logger.exception(
+                    "Encountered error while trying to get traceback for process %s",
+                    rank,
+                )
+
+    def _join_processes(self, fn) -> None:
+        timeout = get_timeout(self.id())
+        start_time = time.time()
+        subprocess_error = False
+        try:
+            while True:
+                # check to see if any subprocess exited with an error early.
+                for i, p in enumerate(self.processes):
+                    # This is the exit code processes exit with if they
+                    # encountered an exception.
+                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
+                        print(
+                            f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes."
+                        )
+                        active_children = torch.multiprocessing.active_children()
+                        for ac in active_children:
+                            ac.terminate()
+                        subprocess_error = True
+                        break
+                if subprocess_error:
+                    break
+                # All processes have joined cleanly if they all a valid exitcode
+                if all(p.exitcode is not None for p in self.processes):
+                    break
+                # Check if we should time out the test. If so, we terminate each process.
+                elapsed = time.time() - start_time
+                if elapsed > timeout:
+                    self._get_timedout_process_traceback()
+                    print(
+                        f"Timing out after {timeout} seconds and killing subprocesses."
+                    )
+                    for p in self.processes:
+                        p.terminate()
+                    break
+                # Sleep to avoid excessive busy polling.
+                time.sleep(0.1)
+
+            elapsed_time = time.time() - start_time
+            self._check_return_codes(fn, elapsed_time)
+        finally:
+            # Close all pipes
+            for pipe in self.pid_to_pipe.values():
+                pipe.close()
+
+    def _check_return_codes(self, fn, elapsed_time) -> None:
+        """
+        Checks that the return codes of all spawned processes match, and skips
+        tests if they returned a return code indicating a skipping condition.
+        """
+        # If no processes are spawned, there is nothing to check.
+        if not self.processes:
+            logger.warning(
+                "Note: no subprocesses were spawned, test was likely skipped."
+            )
+            return
+
+        first_process = self.processes[0]
+        # first, we check if there are errors in actual processes
+        # (via TEST_ERROR_EXIT CODE), and raise an exception for those.
+        # the reason we do this is to attempt to raise a more helpful error
+        # message than "Process x terminated/timed out"
+        # TODO: we should pipe the exception of the failed subprocess here.
+        # Currently, the actual exception is displayed as a logging output.
+        errored_processes = [
+            (i, p)
+            for i, p in enumerate(self.processes)
+            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+        ]
+        if errored_processes:
+            error = ""
+            for i, process in errored_processes:
+                # Get error from pipe.
+                error_message = self.pid_to_pipe[process.pid].recv()
+                error += (
+                    f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} "
+                    f"and exception:\n{error_message}\n"
+                )
+
+            raise RuntimeError(error)
+        # If no process exited uncleanly, we check for timeouts, and then ensure
+        # each process exited cleanly.
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    f"Process {i} terminated or timed out after {elapsed_time} seconds"
+                )
+
+        # Skip the test return code check
+        if fn in self.skip_return_code_checks:
+            return
+
+        for skip in TEST_SKIPS.values():
+            if first_process.exitcode == skip.exit_code:
+                if IS_SANDCASTLE:
+                    # Don't use unittest.skip to skip the test on sandcastle
+                    # since it creates tasks for skipped tests assuming there
+                    # is some follow-up needed. Instead just "pass" the test
+                    # with an appropriate message.
+                    logger.info(
+                        "Skipping %s on sandcastle for the following reason: %s",
+                        self.id(),
+                        skip.message,
+                    )
+                    return
+                else:
+                    raise unittest.SkipTest(skip.message)
+
+        # In most cases, we expect test to return exit code 0, standing for success.
+        expected_return_code = 0
+        # In some negative tests, we expect test to return non-zero exit code,
+        # such as watchdog throwing SIGABRT.
+        if fn in self.special_return_code_checks:
+            expected_return_code = self.special_return_code_checks[fn]
+
+        self.assertEqual(
+            first_process.exitcode,
+            expected_return_code,
+            msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
+        )
+
+    @property
+    def is_master(self) -> bool:
+        return self.rank == 0
+
+
+# Utility base class for distributed Multi Process Test cases
+# This abstracts the PG creation and deletion, the backends are selected based
+# on device type. The tests functions can be instantiated per device type using
+# common_device_type.instantiate_device_type_tests
+# other backends can add entry in backend() function
+class DistributedTestBase(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        os.environ["WORLD_SIZE"] = str(self.world_size)
+        self._spawn_processes()
+
+    def tearDown(self):
+        try:
+            torch.distributed.destroy_process_group()
+        except AssertionError:
+            pass
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def backend(self, device) -> str:
+        if "cuda" in device:
+            return "nccl"
+        elif "hpu" in device:  # intel gaudi
+            return "hccl"
+        elif "xpu" in device:
+            return "xccl"
+        else:
+            return "gloo"
+
+    def create_pg(self, device, world_size=None):
+        if world_size is None:
+            world_size = self.world_size
+        num_visible_devices = torch.get_device_module(device).device_count()
+        store = torch.distributed.FileStore(self.file_name, num_visible_devices)
+        torch.distributed.init_process_group(
+            backend=self.backend(device),
+            world_size=world_size,
+            rank=self.rank,
+            store=store,
+        )
+        if "nccl" in self.backend(device) or "xccl" in self.backend(device):
+            torch.accelerator.set_device_index(self.rank)
+        return torch.distributed.distributed_c10d._get_default_group()
+
+    def rank_to_device(self, device):
+        num_visible_devices = torch.get_device_module(device).device_count()
+        return {i: [i % num_visible_devices] for i in range(self.world_size)}
+
+
+def run_subtests(
+    cls_inst,
+    subtest_config: dict[str, list[Any]],
+    test_fn: Callable,
+    *test_args,
+    **test_kwargs: Any,
+):
+    """
+    Runs a test function given by ``test_fn`` as a subtest according to the
+    configurations specified by ``subtest_config``. This amortizes the
+    costly setup overhead (including process spawn and initializing the
+    process group) over the subtests.
+
+    Args:
+        subtest_config (Dict[str, List[Any]]): A mapping from subtest
+            keyword argument name to a list of its possible values.
+        test_fn (Callable): A callable that runs the actual test.
+        test_args: Positional arguments to pass to ``test_fn``.
+        test_kwargs: Keyword arguments to pass to ``test_fn``.
+    """
+    # Convert the config mapping to a list to have a fixed order
+    subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
+    subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
+    subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
+    for values in itertools.product(*subtest_config_values):
+        # Map keyword to chosen value
+        subtest_kwargs = dict(zip(subtest_config_keys, values, strict=True))
+        with cls_inst.subTest(**subtest_kwargs):
+            torch._dynamo.reset()
+            test_fn(*test_args, **test_kwargs, **subtest_kwargs)
+            torch._dynamo.reset()
+        c10d.barrier()
+
+
+@functools.cache
+def has_efa() -> bool:
+    """
+    If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has
+    Libfabric EFA interfaces and EFA software components installed,
+    see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html.
+    """
+
+    try:
+        return (
+            subprocess.run(
+                ["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False
+            ).returncode
+            == 0
+        )
+    except FileNotFoundError:
+        pass
+    return False
+
+
+def tp_transports():
+    """
+    If the machine has Libfabric EFA interfaces and EFA software components installed it may cause
+    'RuntimeError: In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported' if tensorpipe
+    uses InfiniBand transport, so we exclude it from tensorpipe transports,
+    see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
+    """
+    return ["shm", "uv"] if has_efa() else None
+
+
+def spawn_threads_and_init_comms(
+    func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
+):
+    """
+    Wrapper to use with a test method
+    """
+    if func is None:
+        return partial(
+            spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
+        )
+
+    def _run_test_method_with_multi_threads(world_size, callback):
+        world = _install_threaded_pg()
+        global_store = c10d.HashStore()
+
+        def world_is_valid():
+            return world == c10d.distributed_c10d._world
+
+        def worker(rank, world_pg, store):
+            c10d.init_process_group(
+                backend="threaded", rank=rank, world_size=world_size, store=store
+            )
+            try:
+                callback()
+            except BaseException as ex:  # noqa: B036
+                # Exceptions are handled in MultiThreadedTestCase
+                MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
+                ProcessLocalGroup.exception_handle(
+                    ex
+                )  # trigger _terminate event and awaken worker threads
+            finally:
+                if world_is_valid():
+                    c10d.destroy_process_group()
+
+        threads = []
+        for rank in range(world_size):
+            t = threading.Thread(target=worker, args=(rank, world, global_store))
+            t.start()
+            threads.append(t)
+
+        return threads
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        # TODO: get test name from kwargs
+        torch._C._distributed_c10d._set_thread_isolation_mode(True)
+        try:
+            threads = _run_test_method_with_multi_threads(
+                world_size, lambda: func(self, *args, **kwargs)
+            )
+            # join and error handling
+            MultiThreadedTestCase._join_threads(threads, func)
+        finally:
+            torch._C._distributed_c10d._set_thread_isolation_mode(False)
+
+    return wrapper
+
+
+class MultiThreadedTestCase(TestCase):
+    """
+    Test runner that runs all tests with the in-proc process group using
+    multiple threads with the threaded process group.
+
+    Each test spawns world_size threads and run the test method in each thread.
+
+    Difference from regular MultiProcess test runner:
+    Must explicitly defines SetUp and call self._spawn_threads() to run the tests.
+    Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown)
+        to set up / tear down each thread when running each test.
+    No global state possible
+        How bad of a limitation is this?
+    """
+
+    exception_queue = queue.Queue()
+
+    MAIN_THREAD_RANK = -1
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_THREAD_RANK:
+                self._join_threads(self.threads, fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
+
+    def perThreadSetUp(self):
+        # super().setUp()  # TestCase.setUp() calls torch.manual_seed()
+        pass
+
+    def perThreadTearDown(self):
+        pass
+
+    def setUp(self) -> None:
+        """
+        setUp only set up things in the main thread, if you want to configure things
+        in the spawned threads, use perThreadSetUp
+        """
+        super().setUp()
+        self.rank = self.MAIN_THREAD_RANK
+        self.threads = []
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+
+    def tearDown(self):
+        """
+        tearDown only set up things in the main thread, if you want to configure things
+        in the spawned threads, use perThreadTearDown
+        """
+        super().tearDown()
+        self.threads = []
+
+    def _spawn_threads(self):
+        """
+        class method to spawn threads and run test, use this method in the SetUp of your TestCase
+        """
+        torch._C._distributed_c10d._set_thread_isolation_mode(True)
+        test_name = self._current_test_name
+        # for each test case, we need to create thread local world, and a global store
+        world = _install_threaded_pg()
+        self.__class__.global_store = c10d.HashStore()
+
+        def world_is_valid():
+            return world == c10d.distributed_c10d._world
+
+        if not world_is_valid():
+            raise RuntimeError("Invalid world")
+
+        for rank in range(self.world_size):
+            t = threading.Thread(
+                target=self.__class__._run, args=(test_name, rank, self.world_size)
+            )
+            t.start()
+            self.threads.append(t)
+
+    @classmethod
+    def _run(cls, test_name, rank, world_size, **kwargs):
+        self = cls(test_name)
+        self.rank = rank
+
+        # precision/rel_tol is a thread-local setting since it may be overridden per test, need to make
+        # every thread have the same value. This would be relevant when we use op db tests, where it
+        # needs those states to be set i.e. using instantiate_device_type_tests()
+        # TODO: figure out a better way to do this
+        if hasattr(self, "_tls"):
+            self._tls = threading.local()
+            self._tls.precision = TestCase._precision
+            self._tls.rel_tol = TestCase._rel_tol
+
+        self.run_test_with_threaded_pg(test_name, rank, world_size)
+
+    def run_test_with_threaded_pg(self, test_name, rank, world_size):
+        """
+        Run the current test associated with `test_name` using the threaded process group.
+        """
+        c10d.init_process_group(
+            backend="threaded",
+            rank=rank,
+            world_size=world_size,
+            store=self.__class__.global_store,
+        )
+        self.perThreadSetUp()
+
+        try:
+            getattr(self, test_name)()
+        except BaseException as ex:  # noqa: B036
+            self.exception_queue.put((rank, sys.exc_info()))
+            ProcessLocalGroup.exception_handle(
+                ex
+            )  # trigger _terminate event and awaken worker threads
+        finally:
+            c10d.destroy_process_group()
+            self.perThreadTearDown()
+
+    @classmethod
+    def _join_threads(cls, threads, fn):
+        timeout = TIMEOUT_DEFAULT
+        try:
+            for idx, thread in enumerate(threads):
+                thread.join(max(0, timeout))
+                if thread.is_alive():
+                    MultiThreadedTestCase.exception_queue.put(
+                        (
+                            idx,
+                            (
+                                TimeoutError,
+                                TimeoutError(
+                                    f"Rank failed to join in under {timeout} seconds"
+                                ),
+                                None,
+                            ),
+                        )
+                    )
+            ProcessLocalGroup.reset()
+            failed_ranks = []
+            while not cls.exception_queue.empty():
+                failure = cls.exception_queue.get()
+                failed_ranks.append(failure)
+        finally:
+            _uninstall_threaded_pg()
+            torch._C._distributed_c10d._set_thread_isolation_mode(False)
+
+        cls._check_return_codes(failed_ranks, timeout, fn)
+
+    @classmethod
+    def _check_return_codes(cls, failed_ranks, timeout, fn):
+        # Print based on exceptions raised from threads
+        #   SkipTest: print info for each thread
+        #   TimeoutError: raise RuntimeError for any timed out thread
+        #   Normal Exception: print error for each thread that raises exception
+        #   and raise a RuntimeError
+        error_msg = ""
+        skip_code = -1
+        for rank, exc_info in failed_ranks:
+            exc = exc_info[1]
+            if isinstance(exc, unittest.SkipTest):
+                logger.info(
+                    "Thread %s skipping test %s for following reason: %s",
+                    rank,
+                    fn,
+                    str(exc),
+                )
+                if skip_code < 0:
+                    skip_code = TEST_SKIPS["generic"].exit_code
+            elif isinstance(exc, TimeoutError):
+                msg = f"Thread {rank} terminated or timed out after {timeout} seconds\n"
+                logger.error(msg)
+                raise RuntimeError(msg)
+            elif isinstance(exc, Exception):
+                msg = "".join(traceback.format_exception(*exc_info))
+                logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
+                error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
+            elif isinstance(exc, SystemExit):
+                if type(exc.code) is int and skip_code < 0:
+                    skip_code = exc.code
+
+        # check exceptions
+        if len(error_msg) > 0:
+            raise RuntimeError(error_msg)
+        # check skip
+        if skip_code > 0:
+            for skip in TEST_SKIPS.values():
+                if skip_code == skip.exit_code:
+                    if IS_SANDCASTLE:
+                        # "pass" the test with an appropriate message.
+                        logger.info(
+                            "Skipping %s on sandcastle for the following reason: %s",
+                            fn,
+                            skip.message,
+                        )
+                        return
+                    else:
+                        raise unittest.SkipTest(skip.message)
+
+    @property
+    def world_size(self) -> int:
+        return DEFAULT_WORLD_SIZE
+
+    @property
+    def _current_test_name(self) -> str:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
+        """
+        The reason why we have this util function instead of
+        self.assertEqual is all threads are sharing one CPU RNG
+        so the assertion result is only reliable on rank 0
+        """
+        if self.rank == rank:
+            self.assertEqual(x, y, msg)
+
+    def assertNotEqualOnRank(self, x, y, msg=None, *, rank=0):
+        if self.rank == rank:
+            self.assertNotEqual(x, y)
+
+
+class SaveForwardInputsModule(nn.Module):
+    def __init__(
+        self,
+        forward_inputs: dict[nn.Module, torch.Tensor],
+        cast_forward_inputs: bool,
+    ) -> None:
+        super().__init__()
+        self.l = nn.Linear(100, 100)
+        self.forward_inputs = forward_inputs
+        self.cast_forward_inputs = cast_forward_inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.forward_inputs[self] = x
+        return self.l(x.to(self.l.weight.dtype) if self.cast_forward_inputs else x)
+
+
+class SaveForwardInputsModel(nn.Module):
+    def __init__(
+        self,
+        forward_inputs: dict[nn.Module, torch.Tensor],
+        cast_forward_inputs: bool,
+    ) -> None:
+        super().__init__()
+        self.c1 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
+        self.c2 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
+        self.forward_inputs = forward_inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.forward_inputs[self] = x
+        return self.c2(self.c1(x))
+
+
+@contextmanager
+def _dynamo_dist_per_rank_init(
+    rank, world_size, backend=None, init_pg=True, fake_pg=False
+):
+    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
+    # Just manually implement the most important part of the dynamo behavior to reset/clear.
+    if not fake_pg:
+        torch.accelerator.set_device_index(rank)
+
+    device_type = (
+        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+    )
+    if backend is None:
+        backend = c10d.get_default_backend_for_device(device_type)
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "6789"
+    if init_pg:
+        if fake_pg:
+            store = torch.testing._internal.distributed.fake_pg.FakeStore()
+            c10d.init_process_group(
+                backend="fake",
+                world_size=world_size,
+                rank=rank,
+                store=store,
+            )
+        else:
+            c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
+    torch._dynamo.reset()
+    torch._dynamo.utils.counters.clear()
+    try:
+        yield
+    finally:
+        torch._dynamo.reset()
+        torch._dynamo.utils.counters.clear()
+        if init_pg:
+            c10d.destroy_process_group()
+
+
+class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
+    """
+    Test harness for single-process dynamo distributed tests,
+    initializes dist process group.
+
+    Prefer this for simple tests, as it's easier to debug.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # _exit_stack is set up in TestCase
+        cls._exit_stack.enter_context(
+            patch.dict(
+                os.environ,
+                {
+                    "MASTER_ADDR": "localhost",
+                    "MASTER_PORT": "12355",
+                },
+            )
+        )
+        cls.rank = 0
+        device = torch.accelerator.current_accelerator().type
+        cls.device = f"{device}:{cls.rank}"
+        cls.device_ids = None if device in cls.device else [cls.rank]
+        c10d.init_process_group(
+            c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        c10d.destroy_process_group()
+        super().tearDownClass()
+
+
+class DynamoDistributedMultiProcTestCase(DistributedTestBase):
+    """
+    Use this for tests that actually run on multiple GPUs.
+
+    Decorate tests with @skip_if_lt_x_gpu(ngpu)
+
+    Note: MultiProcTestCase spawns processes per test and is slow.
+    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
+    sparingly for integration tests.
+    """
+
+    @property
+    def world_size(self) -> int:
+        return torch.accelerator.device_count()
+
+    @classmethod
+    def _run(
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+    ) -> None:
+        trace_log.addHandler(logging.NullHandler())
+
+        # The rest is copypasta from MultiProcessTestCase._run
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+
+class MultiProcContinuousTest(TestCase):
+    # Class variables:
+    MAIN_PROCESS_RANK = -1
+    # number of test processes
+    world_size: int = -2  # unset state
+    # rank of the current process
+    rank: int = -2  # unset state
+    # Rendezvous file
+    rdvz_file: Optional[str] = None
+    # timeout configured per class
+    timeout: timedelta = timedelta(seconds=120)
+    # Poison pill for rest of tests if one of them fails
+    poison_pill: bool = False
+
+    @classmethod
+    def backend_str(cls) -> Optional[str]:
+        """
+        ProcessGroup backend str.
+        To be customized by sub test classes, e.g. "nccl".
+        Otherwise we return None -- lazily decided by tensor.
+        """
+        return None
+
+    # Please override if you intend to test on specific device type
+    @classmethod
+    def device_type(cls) -> str:
+        curr_device = torch.accelerator.current_accelerator()
+        if curr_device is None:
+            return "cpu"
+        return curr_device.type
+
+    @classmethod
+    def opts(cls, high_priority_stream=False):
+        """
+        ProcessGroup init options.
+        To be customized by sub test classes, e.g. ProcessGroupNCCLOpTest
+        Here we return None.
+        """
+        return None
+
+    @classmethod
+    def _init_pg(cls, rank, world_size, rdvz_file):
+        assert rdvz_file is not None
+        # rank should be local_rank for tests running on <= 8 gpus which is how all these tests are designed
+        # and we expect LOCAL_RANK set by torchrun. Setting it lets init_device_mesh set the device without
+        # issuing a warning
+        os.environ["LOCAL_RANK"] = str(rank)
+        store = c10d.FileStore(rdvz_file, world_size)
+        # create nccl processgroup with opts
+        c10d.init_process_group(
+            backend=cls.backend_str(),
+            world_size=world_size,
+            rank=rank,
+            store=store,
+            pg_options=cls.opts(),
+            timeout=cls.timeout,
+        )
+        cls.pg = c10d.distributed_c10d._get_default_group()
+
+    @classmethod
+    def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        test_name = test_id.rsplit(".", maxsplit=1)[-1]
+        # Get the test function from the test class
+        self = cls(test_name)
+        self.rank = cls.rank
+        self.world_size = cls.world_size
+        test_fn = getattr(self, test_name)
+
+        # Ensure all the ranks use the same seed.
+        common_utils.set_rng_seed()
+
+        # Run the test function
+        test_fn(**kwargs)
+
+    @classmethod
+    def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
+        raised_exception = False
+        # Sub tests are going to access these values, check first
+        assert 0 <= rank < world_size
+        # set class variables for the test class
+        cls.rank = rank
+        cls.world_size = world_size
+
+        # Initialize the process group
+        cls._init_pg(rank, world_size, rdvz_file)
+
+        # End of bootstrap
+        logger.debug("Setup complete")
+
+        # Loop forever, waiting for a test name to run
+        while True:
+            test_id = task_queue.get()
+            logger.debug(f"Got test {test_id}")  # noqa: G004
+            # None means exit
+            if test_id is None:
+                break
+
+            # Run the test
+            try:
+                cls._run_test_given_id(test_id)
+                completion_queue.put(test_id)
+            except BaseException as ex:  # noqa: B036
+                if isinstance(ex, SystemExit):
+                    # Get exit code from the process
+                    exit_code = getattr(ex, "code", None)
+
+                    # Look up exit code in TEST_SKIPS to see if it is a valid skip
+                    skip_entry = next(
+                        (v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
+                        None,
+                    )
+
+                    # If we found an entry, we want to skip the test and the object back to the main process
+                    if skip_entry:
+                        completion_queue.put(unittest.SkipTest(skip_entry.message))
+                        # Skip exception handling below, move to main thread for processing the skip
+                        continue
+
+                raised_exception = True
+                # Send the exception and stack trace back to the dispatcher
+                exc_info = sys.exc_info()
+                tb_str = "".join(traceback.format_exception(*exc_info))
+                # Create a new exception with the original exception and traceback
+                enhanced_ex = RuntimeError(f"Exception in worker process:\n{tb_str}")
+                enhanced_ex.__cause__ = ex
+                completion_queue.put(enhanced_ex)
+
+        # Termination
+        logger.debug("Terminating ...")
+        # Calling destroy_process_group when workers have exceptions
+        # while others are doing collectives will cause a deadlock since
+        # it waits for enqueued collectives to finish.
+        # Only call this on a clean exit path
+        if not raised_exception:
+            c10d.destroy_process_group()
+
+    @classmethod
+    def _spawn_processes(cls, world_size) -> None:
+        cls.processes = []
+        cls.task_queues = []
+        cls.completion_queues = []
+        # Need a rendezvous file for `init_process_group` purpose.
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            cls.rdvz_file = f.name
+
+        # CUDA multiprocessing requires spawn instead of fork, to make sure
+        # child processes have their own memory space.
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            # The start method has already been set
+            pass
+
+        for rank in range(int(world_size)):
+            task_queue = torch.multiprocessing.Queue()
+            completion_queue = torch.multiprocessing.Queue()
+            process = torch.multiprocessing.Process(
+                target=cls._worker_loop,
+                name="process " + str(rank),
+                daemon=True,  # so that child processes will exit if parent decides to terminate
+                args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
+            )
+            process.start()
+            cls.processes.append(process)
+            cls.task_queues.append(task_queue)
+            cls.completion_queues.append(completion_queue)
+            logger.debug("Started process %s with pid %s", rank, process.pid)  # noqa: UP031
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Class-scope test fixture. Run once for entire test class, before any test starts.
+        Set up the process group.
+        """
+        super().setUpClass()
+
+        # Use device count as world size
+        device_type = cls.device_type()
+        # If world_size is not set, use device count
+        if cls.world_size == -2:
+            cls.world_size = torch.get_device_module(device_type).device_count()
+            if cls.world_size == 0:
+                raise unittest.SkipTest(f"No {device_type} devices available")
+
+        logger.info(
+            f"Testing class {cls.__name__} on {cls.world_size} {device_type}"  # noqa: G004
+        )
+
+        cls._spawn_processes(cls.world_size)
+
+    @classmethod
+    def tearDownClass(cls):
+        """
+        Class-scope test fixture. Run once for entire test class, after all tests finish.
+        Tear down the process group.
+        """
+        logger.debug(f"Joining {cls.world_size} workers")  # noqa: G004
+        # Enqueue "None" to all workers to tell them to exit
+        for task_queue in cls.task_queues:
+            task_queue.put(None)
+
+        # Wait for all workers to exit
+        for process in cls.processes:
+            process.join()
+
+        # Clear up the rendezvous file
+        try:
+            os.remove(cls.rdvz_file)
+        except OSError:
+            pass
+
+        logger.info(f"Class {cls.__name__} finished")  # noqa: G004
+        super().tearDownClass()
+
+    def setUp(self) -> None:
+        """
+        Test fixture. Run before each test.
+        """
+        super().setUp()
+
+        # I am the dispatcher
+        self.rank = self.MAIN_PROCESS_RANK
+
+        # If this test class hits an exception in one test, skip the rest of tests
+        if self.__class__.poison_pill:
+            raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}")
+
+        # Enqueue "current test" to all workers
+        for i, task_queue in enumerate(self.task_queues):
+            logger.debug(f"Sending Rank {i}: {self.id()}")  # noqa: G004
+            task_queue.put(self.id())
+
+    def _worker_run_main_wait(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
+                # Wait for the workers to finish the test
+                for i, completion_queue in enumerate(self.completion_queues):
+                    rv = completion_queue.get()
+                    if isinstance(rv, unittest.SkipTest):
+                        raise rv
+                    if isinstance(rv, BaseException):
+                        # Hit an exception, re-raise it in the main process.
+                        logger.warning(
+                            f"Detected failure from Rank {i} in: {self.id()}, "  # noqa: G004
+                            f"skipping rest of tests in Test class: {self.__class__.__name__}"  # noqa: G004
+                        )
+                        # Poison rest of tests (because ProcessGroup may be not
+                        # reusable now)
+                        self.__class__.poison_pill = True
+                        raise rv
+
+                    # Success
+                    assert rv == self.id()
+                    logger.debug(
+                        f"Main proc detected rank {i} finished {self.id()}"  # noqa: G004
+                    )
+            else:
+                # Worker just runs the test
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self._worker_run_main_wait(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..474bb689f0ad9bcd7ee171b68de22f7752b37e3c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py
@@ -0,0 +1,227 @@
+# mypy: ignore-errors
+
+
+import torch
+
+
+# Functions and classes for describing the dtypes a function supports
+# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
+
+
+# Verifies each given dtype is a torch.dtype
+def _validate_dtypes(*dtypes):
+    for dtype in dtypes:
+        assert isinstance(dtype, torch.dtype)
+    return dtypes
+
+
+# class for tuples corresponding to a PyTorch dispatch macro
+class _dispatch_dtypes(tuple):
+    __slots__ = ()
+
+    def __add__(self, other):
+        assert isinstance(other, tuple)
+        return _dispatch_dtypes(tuple.__add__(self, other))
+
+
+_empty_types = _dispatch_dtypes(())
+
+
+def empty_types():
+    return _empty_types
+
+
+_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
+
+
+def floating_types():
+    return _floating_types
+
+
+_floating_types_and_half = _floating_types + (torch.half,)
+
+
+def floating_types_and_half():
+    return _floating_types_and_half
+
+
+def floating_types_and(*dtypes):
+    return _floating_types + _validate_dtypes(*dtypes)
+
+
+_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
+
+
+def floating_and_complex_types():
+    return _floating_and_complex_types
+
+
+def floating_and_complex_types_and(*dtypes):
+    return _floating_and_complex_types + _validate_dtypes(*dtypes)
+
+
+_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
+
+
+def double_types():
+    return _double_types
+
+
+# NB: Does not contain uint16/uint32/uint64 for BC reasons
+_integral_types = _dispatch_dtypes(
+    (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
+)
+
+
+def integral_types():
+    return _integral_types
+
+
+def integral_types_and(*dtypes):
+    return _integral_types + _validate_dtypes(*dtypes)
+
+
+_all_types = _floating_types + _integral_types
+
+
+def all_types():
+    return _all_types
+
+
+def all_types_and(*dtypes):
+    return _all_types + _validate_dtypes(*dtypes)
+
+
+_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
+
+
+def complex_types():
+    return _complex_types
+
+
+def complex_types_and(*dtypes):
+    return _complex_types + _validate_dtypes(*dtypes)
+
+
+_all_types_and_complex = _all_types + _complex_types
+
+
+def all_types_and_complex():
+    return _all_types_and_complex
+
+
+def all_types_and_complex_and(*dtypes):
+    return _all_types_and_complex + _validate_dtypes(*dtypes)
+
+
+_all_types_and_half = _all_types + (torch.half,)
+
+
+def all_types_and_half():
+    return _all_types_and_half
+
+
+_all_mps_types = (
+    _dispatch_dtypes({torch.float, torch.half, torch.bfloat16}) + _integral_types
+)
+
+
+def all_mps_types():
+    return _all_mps_types
+
+
+def all_mps_types_and(*dtypes):
+    return _all_mps_types + _validate_dtypes(*dtypes)
+
+
+_float8_types = _dispatch_dtypes(
+    (
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+    )
+)
+
+
+def float8_types():
+    return _float8_types
+
+
+def float8_types_and(*dtypes):
+    return _float8_types + _validate_dtypes(*dtypes)
+
+
+def all_types_complex_float8_and(*dtypes):
+    return _all_types + _complex_types + _float8_types + _validate_dtypes(*dtypes)
+
+
+def custom_types(*dtypes):
+    """Create a list of arbitrary dtypes"""
+    return _empty_types + _validate_dtypes(*dtypes)
+
+
+# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
+
+
+# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
+def get_all_dtypes(
+    include_half=True,
+    include_bfloat16=True,
+    include_bool=True,
+    include_complex=True,
+    include_complex32=False,
+    include_qint=False,
+) -> list[torch.dtype]:
+    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
+        include_half=include_half, include_bfloat16=include_bfloat16
+    )
+    if include_bool:
+        dtypes.append(torch.bool)
+    if include_complex:
+        dtypes += get_all_complex_dtypes(include_complex32)
+    if include_qint:
+        dtypes += get_all_qint_dtypes()
+    return dtypes
+
+
+def get_all_math_dtypes(device) -> list[torch.dtype]:
+    return (
+        get_all_int_dtypes()
+        + get_all_fp_dtypes(
+            include_half=device.startswith("cuda"), include_bfloat16=False
+        )
+        + get_all_complex_dtypes()
+    )
+
+
+def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]:
+    return (
+        [torch.complex32, torch.complex64, torch.complex128]
+        if include_complex32
+        else [torch.complex64, torch.complex128]
+    )
+
+
+def get_all_int_dtypes() -> list[torch.dtype]:
+    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+
+
+def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> list[torch.dtype]:
+    dtypes = [torch.float32, torch.float64]
+    if include_half:
+        dtypes.append(torch.float16)
+    if include_bfloat16:
+        dtypes.append(torch.bfloat16)
+    return dtypes
+
+
+def get_all_qint_dtypes() -> list[torch.dtype]:
+    return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
+
+
+float_to_corresponding_complex_type_map = {
+    torch.float16: torch.complex32,
+    torch.float32: torch.complex64,
+    torch.float64: torch.complex128,
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..74b3cdc78f2d93086cc82886ddf36f5c9cc40184
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py
@@ -0,0 +1,1623 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["oncall: distributed"]
+
+import contextlib
+import os
+import re
+import sys
+import time
+import unittest
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from contextlib import nullcontext
+from copy import deepcopy
+from enum import auto, Enum
+from functools import wraps
+from typing import Any, cast, no_type_check, Optional, Union
+from unittest import mock
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed._composable import checkpoint
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import (
+    CPUOffload,
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+)
+from torch.distributed.fsdp._common_utils import TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+    FSDPParamGroup,
+    RegisterPostBackwardFunction,
+)
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
+from torch.distributed.tensor import distribute_tensor, DTensor, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    MultiThreadedTestCase,
+    run_subtests,
+    TEST_SKIPS,
+)
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    get_cycles_per_ms,
+    set_rng_seed,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_XPU,
+)
+from torch.utils._triton import has_triton
+
+
+DEVICE_COUNT = 4  # default
+
+if TEST_CUDA:
+    DEVICE_TYPE = "cuda"
+    DISTRIBUTED_BACKEND = "nccl"
+    DEVICE_COUNT = torch.cuda.device_count()
+elif TEST_HPU:
+    DEVICE_TYPE = "hpu:0"
+    DISTRIBUTED_BACKEND = "hccl"
+elif TEST_XPU:
+    DEVICE_TYPE = "xpu"
+    DISTRIBUTED_BACKEND = "xccl"
+    DEVICE_COUNT = torch.xpu.device_count()
+else:
+    DEVICE_TYPE = "cpu"
+    DISTRIBUTED_BACKEND = "gloo"
+    DEVICE_COUNT = 1
+
+
+class FSDPInitMode(Enum):
+    # No FSDP wrapping
+    NO_FSDP = auto()
+    # FSDP recursive wrapping
+    RECURSIVE = auto()
+    # TODO: FSDP non-recursive wrapping
+    # NONRECURSIVE = auto()
+
+
+class DEVICEInitMode(Enum):
+    # Move model to DEVICE before passing to the FSDP constructor
+    DEVICE_BEFORE = auto()
+    # Move model to DEVICE after passing to the FSDP constructor
+    DEVICE_AFTER = auto()
+    # Keep on CPU
+    DEVICE_NEVER = auto()
+
+
+class FSDPTestModel(nn.Module, ABC):
+    """This defines the interface expected from all models used commonly for
+    FSDP unit tests."""
+
+    @abstractmethod
+    def get_input(self, device) -> tuple[torch.Tensor, ...]:
+        """Returns an input for the model as as tuple."""
+        ...
+
+    @abstractmethod
+    def get_loss(self, input, output) -> torch.Tensor:
+        """Returns the loss given the input and output."""
+        ...
+
+    @abstractmethod
+    def run_backward(self, loss) -> None:
+        """Runs the backward pass (e.g. including ``loss.backward()``)."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def init(*args: Any, **kwargs: Any) -> nn.Module:
+        """Initializes an instance of this model."""
+        ...
+
+
+def _assert_module_states(
+    model: nn.Module,
+    process_group: dist.ProcessGroup,
+    assert_fn: Callable,
+):
+    """
+    All-gathers module states across ranks and calls ``assert_fn`` on each pair
+    of corresponding states from rank 0 and a nonzero rank. For example, if
+    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
+    states are equal across ranks.
+    """
+    # Include names for debugging convenience
+    named_module_states = [
+        (param_name, param.detach().cpu())
+        for param_name, param in model.named_parameters()
+    ]
+    named_module_states += [
+        (buffer_name, buffer.detach().cpu())
+        for buffer_name, buffer in model.named_buffers()
+    ]
+    world_size = dist.get_world_size(process_group)
+    olist = [None for _ in range(world_size)]
+    dist.all_gather_object(olist, named_module_states, group=process_group)
+    rank0_states = olist[0]
+    assert rank0_states is not None  # mypy
+    for state in olist[1:]:
+        assert state is not None  # mypy
+        for (_, p1), (_, p2) in zip(rank0_states, state, strict=True):
+            assert_fn(p1, p2)
+
+
+def get_devtype():
+    return torch.device(DEVICE_TYPE)
+
+
+def _zero_model(
+    model: nn.Module,
+    zero_buffers: bool = False,
+    summon_full=True,
+):
+    """Zeros the parameters and optionally buffers of ``model`` in place."""
+    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
+    with ctx:
+        for param in model.parameters():
+            with torch.no_grad():
+                param.zero_()
+        if zero_buffers:
+            for buffer in model.buffers():
+                with torch.no_grad():
+                    buffer.zero_()
+
+
+def _get_state_dict(model, cpu_offload=False, half=False):
+    if not cpu_offload:
+        model = model.to(DEVICE_TYPE)
+    if half:
+        model.half()
+
+    return model.state_dict()
+
+
+def subtest_name(test_name_mapping, *args):
+    return "_".join(
+        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
+    )
+
+
+def _broadcast_state_dict(rank, state_dict):
+    # For non-FSDP roots, some parts of the model state on rank 0 may
+    # not be on CPU, so we move everything to CPU to avoid issues like:
+    # https://github.com/pytorch/pytorch/issues/77113.
+    for param_name, param in state_dict.items():
+        if param.device != torch.device("cpu"):
+            state_dict[param_name] = param.cpu()
+
+    olist = [state_dict if rank == 0 else None]
+    dist.broadcast_object_list(olist)
+    state_dict = cast(dict[str, torch.Tensor], olist[0])
+    # Ensure that the state is on DEVICE
+    for param_name in state_dict:
+        state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)
+    return state_dict
+
+
+def get_full_params(model: nn.Module, recurse: bool = True):
+    """
+    Returns the full unsharded parameters of ``model``. Any FSDP-managed
+    parameters offloaded to CPU are moved to GPU in the returned list.
+
+    Args:
+        recurse (bool): If ``False``, only unshards the parameters immediate to
+            ``model``; if ``True``, recurses through the module hierarchy
+            rooted at ``model``.
+    """
+    with FSDP.summon_full_params(model, recurse=recurse):
+        return deepcopy(list(model.parameters()))
+
+
+def _move_to_device(model: nn.Module, move_to_device: bool):
+    return model.to(DEVICE_TYPE) if move_to_device else model
+
+
+def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
+    return model if not wrap_fsdp else FSDP(model, *args, **kwargs)
+
+
+class DummyProcessGroup:
+    def __init__(self, rank: int, size: int):
+        self._rank = rank
+        self._size = size
+
+    def rank(self) -> int:
+        return self._rank
+
+    def size(self) -> int:
+        return self._size
+
+    def allreduce(self, *args, **kwargs):
+        dist_wait = mock.Mock()
+
+        def get_future():
+            future: torch.futures.Future = torch.futures.Future()
+            future.set_result(1)
+            return future
+
+        dist_wait.get_future = get_future
+        return dist_wait
+
+
+class TransformerWithSharedParams(FSDPTestModel):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        device_init_mode: DEVICEInitMode,
+        add_bn: bool,
+        deterministic: bool,
+    ):
+        super().__init__()
+        self.rank = group.rank()
+        self.world_size = group.size()
+        if deterministic:
+            torch.manual_seed(0)
+        d_vocab = 23
+        d_model = 16
+
+        self.embed_tokens = nn.Embedding(d_vocab, d_model)
+        self.transformer = nn.Transformer(
+            d_model=d_model,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=8,
+            dropout=0.1,
+        )
+        self.output_proj = nn.Linear(d_model, d_vocab)
+
+        # share the embedding and output projection weights
+        self.output_proj.weight = self.embed_tokens.weight
+        self.register_buffer(
+            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
+        )
+        self.register_buffer(
+            "long_buffer",
+            torch.zeros_like(self.vocab_bias, dtype=torch.long),  # type: ignore[arg-type]
+        )  # type: ignore[arg-type]
+
+        self.bs = 2
+        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
+        if device_init_mode == DEVICEInitMode.DEVICE_BEFORE:
+            self = self.to(DEVICE_TYPE)
+        if deterministic:
+            self.eval()
+
+    def get_input(self, device):
+        torch.manual_seed(1 + self.rank)  # keep everything deterministic
+        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
+        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
+        return (src, tgt)
+
+    def forward(self, src_ids, tgt_ids):
+        src = self.embed_tokens(src_ids)
+        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
+        tgt = self.embed_tokens(tgt_ids)
+        tgt = self.bn(tgt)
+        x = self.transformer(src, tgt)
+        return self.output_proj(x)
+
+    def get_loss(self, input, output):
+        _, tgt = input
+        return nn.functional.cross_entropy(
+            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
+        )
+
+    def run_backward(self, loss):
+        loss.backward()
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        add_bn: bool = True,
+    ) -> Union[nn.Module, FSDP]:
+        """
+        Initializes a :class:`TransformerWithSharedParams` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps with
+                top-level FSDP. By default, the top-level FSDP uses the
+                ``ModuleWrapPolicy`` for encoder and decoder layers, but a
+                different auto wrap policy may be specified via
+                ``fsdp_kwargs``.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+            add_bn (bool): Whether to include batch norm in the model.
+        """
+
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            if isinstance(group, tuple):
+                pg = group[0]
+            else:
+                pg = group
+            return TransformerWithSharedParams(
+                pg, device_init_mode, add_bn, deterministic
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Default to the `ModuleWrapPolicy`
+            if "auto_wrap_policy" not in fsdp_kwargs:
+                auto_wrap_policy = ModuleWrapPolicy(
+                    {
+                        TransformerEncoderLayer,
+                        TransformerDecoderLayer,
+                    }
+                )
+            else:
+                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
+
+            if (
+                "sharding_strategy" in fsdp_kwargs
+                and fsdp_kwargs["sharding_strategy"]
+                in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
+                and not isinstance(group, tuple)
+            ):
+                fsdp_pg = None
+            else:
+                fsdp_pg = group
+
+            if isinstance(group, tuple):
+                tformer_pg = group[0]
+            else:
+                tformer_pg = group
+
+            m = TransformerWithSharedParams(
+                tformer_pg, device_init_mode, add_bn, deterministic
+            )
+            fsdp_model = FSDP(
+                m,
+                fsdp_pg,
+                auto_wrap_policy=auto_wrap_policy,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+    def get_ignored_modules(self):
+        return [self.transformer]
+
+
+class NestedWrappedModule(FSDPTestModel):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super().__init__()
+        self.rank = group.rank()
+        self.world_size = group.size()
+        move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+
+        def _maybe_wrap(layer):
+            if wrap_fsdp:
+                return FSDP(layer, group, **fsdp_kwargs)
+            return layer
+
+        if deterministic:
+            torch.manual_seed(0)
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(8, 4), move_to_device),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)),
+                    _move_to_device(nn.Linear(16, 16), move_to_device),
+                ),
+            ),
+            _maybe_wrap(_move_to_device(nn.Linear(16, 4), move_to_device)),
+            _move_to_device(nn.Linear(4, 8), move_to_device),
+        )
+
+    def get_input(self, device):
+        torch.manual_seed(1 + self.rank)  # keep everything deterministic
+        return (torch.rand(4, 8, device=device),)
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = output.sum()
+        return loss
+
+    def run_backward(self, loss):
+        loss.backward()
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ) -> nn.Module:
+        """
+        Initializes a :class:`NestedWrappedModule` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
+                modules with FSDP but not the top-level module. The model may
+                later be wrapped with a top-level FSDP external to this method
+                if desired.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+        """
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return NestedWrappedModule(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Does not wrap with top-level FSDP
+            fsdp_model = NestedWrappedModule(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ):
+        """
+        Initializes a :class:`NestedWrappedModule` instance, but unlike
+        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
+        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
+        policy.
+        """
+        model = super(
+            AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
+        ).init(
+            group=group,
+            fsdp_init_mode=FSDPInitMode.NO_FSDP,
+            device_init_mode=device_init_mode,
+            fsdp_kwargs=fsdp_kwargs,
+            deterministic=deterministic,
+        )
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return model
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            fsdp_kwargs = fsdp_kwargs or {}
+            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+
+
+class NonUniformReqGradNWM(NestedWrappedModule):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super(NestedWrappedModule, self).__init__()
+        # This `__init__` only differs from `NestedWrappedModule.__init__` in that
+        # the last two `nn.Linear` layers are FSDP wrapped in a `nn.Sequential`
+        # container. This arrangement results in all elements of the last two parameters
+        # residing on a single rank. Freezing all parameters except those two allows us
+        # to verify that `ShardedGradScaler` accommodates situations where some ranks
+        # have no (non-zero sized) parameter shards.
+        self.rank = group.rank()
+        self.world_size = group.size()
+        move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+
+        def _maybe_wrap(layer):
+            if wrap_fsdp:
+                return FSDP(layer, group, **fsdp_kwargs)
+            return layer
+
+        if deterministic:
+            torch.manual_seed(0)
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(8, 4), move_to_device),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)),
+                    _move_to_device(nn.Linear(16, 16), move_to_device),
+                ),
+            ),
+            _maybe_wrap(
+                nn.Sequential(
+                    _move_to_device(nn.Linear(16, 4), move_to_device),
+                    _move_to_device(nn.Linear(4, 8), move_to_device),
+                ),
+            ),
+        )
+
+    @staticmethod
+    def _set_nonuniform_req_grad(model, req_grad_mask) -> None:
+        for n, p in model.named_parameters():
+            if not re.match(req_grad_mask, n):
+                p.requires_grad_(False)
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ):
+        """
+        Initializes a :class:`NestedWrappedModule` instance, but unlike
+        :meth:`NestedWrappedModule.init`, it wraps a second :class:`torch.nn.Sequential`
+        container to enable the desired non-uniform ``requires_grad``
+        ``use_orig_params=True`` tests. For both ``RECURSIVE`` and ``NO_FSDP``
+        init modes, freezes all parameters except the last two to validate
+        ``ShardedGradScaler`` support for ranks with no (non-zero sized) local shards in
+        FSDP ``use_orig_params=True`` mode.
+        """
+        # The parameters that should remain unfrozen are in `module.2.1`. The regex
+        # pattern below matches the relevant parameter names both with and without
+        # an interstitial FSDP module indicator (`_fsdp_wrapped_module`) present.
+        req_grad_pattern = re.compile(r"module\.2.*\.1.*")
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            ddp_model = NonUniformReqGradNWM(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+            )
+            NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern)
+            return ddp_model
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            if fsdp_kwargs is None:
+                fsdp_kwargs = {}
+            fsdp_model = NonUniformReqGradNWM(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class ModuleWithDelay(FSDPTestModel):
+    """This class wraps a :class:`FSDPTestModel` to optionally add a delay
+    after computing the loss and/or before the gradient reduction."""
+
+    def __init__(
+        self,
+        module: nn.Module,
+        delay_after_loss_ms: int,
+        delay_before_reduction_ms: int,
+    ):
+        super().__init__()
+        self.delay_after_loss_ms = delay_after_loss_ms
+        self.delay_before_reduction_ms = delay_before_reduction_ms
+        self.module = module
+
+    def get_input(self, device):
+        return self.module.get_input(device)  # type: ignore[operator]
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = self.module.get_loss(input, output)  # type: ignore[operator]
+        if self.delay_after_loss_ms > 0:
+            if TEST_HPU or TEST_XPU:
+                time.sleep(self.delay_after_loss_ms / 1000)
+            elif TEST_CUDA:
+                torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
+
+        return loss
+
+    def run_backward(self, loss):
+        orig_reduce_scatter = torch.distributed.reduce_scatter_tensor
+
+        def _delayed_reduce_scatter(*args, **kwargs):
+            if self.delay_before_reduction_ms > 0:
+                if TEST_CUDA:
+                    torch.cuda._sleep(
+                        int(self.delay_before_reduction_ms * get_cycles_per_ms())
+                    )
+                elif TEST_HPU or TEST_XPU:
+                    time.sleep(self.delay_before_reduction_ms / 1000)
+            return orig_reduce_scatter(*args, **kwargs)
+
+        with mock.patch(
+            "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
+        ):
+            self.module.run_backward(loss)  # type: ignore[operator]
+
+    @staticmethod
+    def init(
+        module_class: type[FSDPTestModel],
+        *model_args: Any,
+        delay_after_loss_ms: int,
+        delay_before_reduction_ms: int,
+        **model_kwargs: Any,
+    ):
+        """
+        Args:
+            module_class (Type[FSDPTestModel]): Wrapped module class to which
+                to add delays.
+            model_args: Positional arguments forwarded to the ``module_class``
+                ``init()``.
+            delay_after_loss_ms (int): Delay after computing the loss/before
+                the optimizer step (in ms).
+            delay_before_reduction_ms (int): Delay before reduce-scattering
+                gradients (in ms).
+            model_kwargs: Keyword arguments forwarded to the ``module_class``
+                ``init()``.
+        """
+        return ModuleWithDelay(
+            module_class.init(*model_args, **model_kwargs),
+            delay_after_loss_ms,
+            delay_before_reduction_ms,
+        )
+
+
+class NestedWrappedModuleWithDelay(ModuleWithDelay):
+    @staticmethod
+    def init(  # type: ignore[override]
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode = DEVICEInitMode.DEVICE_AFTER,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        delay_after_loss_ms: int = 0,
+        delay_before_reduction_ms: int = 0,
+    ):
+        return ModuleWithDelay.init(
+            NestedWrappedModule,
+            group=group,
+            fsdp_init_mode=fsdp_init_mode,
+            device_init_mode=device_init_mode,
+            fsdp_kwargs=fsdp_kwargs,
+            deterministic=deterministic,
+            delay_after_loss_ms=delay_after_loss_ms,
+            delay_before_reduction_ms=delay_before_reduction_ms,
+        )
+
+
+class DummyDDP(nn.Module):
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+
+class MixtureOfExperts(NestedWrappedModule):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        delay_before_free_ms: int,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super().__init__(
+            group=group,
+            wrap_fsdp=wrap_fsdp,
+            device_init_mode=device_init_mode,
+            deterministic=deterministic,
+        )
+        self.group = group
+        self.delay_before_free_ms = delay_before_free_ms
+        self.wrap_fsdp = wrap_fsdp
+        self.move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+        if deterministic:
+            # Give each rank different expert parameters
+            torch.manual_seed(42 + self.rank)
+        d_expert = 23
+        d_shared = 12
+        d_input = 8
+        expert = _move_to_device(nn.Linear(d_expert, d_shared), self.move_to_device)
+
+        self.num_expert_params = sum(p.numel() for p in expert.parameters())
+        for p in expert.parameters():
+            p.expert = True  # type: ignore[attr-defined]
+
+        if deterministic:
+            # Keep all other parameters the same across ranks
+            torch.manual_seed(0)
+
+        shared = _move_to_device(nn.Linear(d_shared, d_expert), self.move_to_device)
+
+        if wrap_fsdp:
+            # we create a process group of size 1 for the expert params
+            expert_group = torch.distributed.new_group(
+                [group.rank()]
+            )  # world size 1 means no shard
+            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
+            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]
+
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(d_input, d_shared), self.move_to_device),
+            shared,
+            expert,
+            _move_to_device(nn.Linear(d_shared, d_input), self.move_to_device),
+        )
+
+    def forward(self, x):
+        if self.delay_before_free_ms > 0:
+            expert = self.module[2]
+            if isinstance(expert, FSDP):
+                orig_reshard = torch.distributed.fsdp._runtime_utils._reshard
+
+                def _delayed_reshard(*args, **kwargs):
+                    if TEST_CUDA:
+                        torch.cuda._sleep(
+                            int(self.delay_before_free_ms * get_cycles_per_ms())
+                        )
+                    elif TEST_HPU or TEST_XPU:
+                        time.sleep(self.delay_before_free_ms / 1000)
+
+                    return orig_reshard(*args, **kwargs)
+
+                # This patch covers any `import torch..._reshard` uses.
+                with mock.patch(
+                    "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
+                ):
+                    return self.module(x)
+
+        return self.module(x)
+
+    def run_backward(self, loss):
+        loss.backward()
+        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
+        if not self.wrap_fsdp:
+            with torch.no_grad():
+                for p in self.parameters():
+                    if hasattr(p, "expert"):
+                        continue  # these params don't need grad reduction
+                    if p.grad is not None:
+                        p.grad.div_(self.world_size)
+                        torch.distributed.all_reduce(p.grad, group=self.group)
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        delay_before_free_ms: int = 0,
+    ):
+        """
+        Initializes a :class:`MixtureOfExperts` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
+                modules with FSDP, including the expert and shared layers, but
+                not the top-level module. The model may later be wrapped with a
+                top-level FSDP external to this method if desired.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+            delay_before_free_ms (int): Delay before resharding expert
+                parameters in the forward pass (in ms).
+        """
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return MixtureOfExperts(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                delay_before_free_ms=delay_before_free_ms,
+                deterministic=deterministic,
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Does not wrap with top-level FSDP
+            fsdp_model = MixtureOfExperts(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                delay_before_free_ms=delay_before_free_ms,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        device: Optional[torch.device] = None,
+        *,
+        bias: bool = True,
+        with_buffer: bool = False,
+        dim_multiplier: int = 4,
+    ):
+        super().__init__()
+        self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias)
+        self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias)
+        if with_buffer:
+            self.register_buffer("buffer", torch.randn((dim,), device=device))
+        else:
+            self.buffer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        z = self.in_proj(x)
+        z = F.relu(z)
+        z = self.out_proj(z)
+        z = F.relu(z)
+        if self.buffer is not None:
+            z = z + self.buffer
+        return z
+
+    def reset_parameters(self):
+        if self.buffer is not None:
+            torch.nn.init.normal_(self.buffer)
+
+
+class MLPStack(nn.Sequential):
+    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
+        modules: list[nn.Module] = [
+            # Use multiplier of 3 to exercise uneven case
+            MLP(mlp_dim, dim_multiplier=3),
+            MLP(mlp_dim),
+            MLP(mlp_dim, dim_multiplier=3),
+        ]
+        if with_seq_parallel:
+            modules.append(nn.LayerNorm(mlp_dim, bias=False))
+        super().__init__(*modules)
+        self.with_seq_parallel = with_seq_parallel
+
+    def parallelize(
+        self,
+        tp_mesh: DeviceMesh,
+        dp_mesh: DeviceMesh,
+        use_activation_checkpointing: bool,
+        **fsdp_kwargs,
+    ) -> "MLPStack":
+        parallelize_plan = {
+            # Pass `use_local_output=False` to keep as DTensor to preserve
+            # uneven activation dims
+            "0.in_proj": ColwiseParallel(use_local_output=False),
+            "0.out_proj": RowwiseParallel(use_local_output=False),
+            "1.in_proj": ColwiseParallel(use_local_output=False),
+            "1.out_proj": RowwiseParallel(use_local_output=False),
+            "2.in_proj": ColwiseParallel(use_local_output=False),
+            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
+            if self.with_seq_parallel
+            else RowwiseParallel(),
+        }
+        if self.with_seq_parallel:
+            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
+        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
+        for module in self:
+            if isinstance(module, nn.LayerNorm):
+                continue
+            if use_activation_checkpointing:
+                checkpoint(module)
+            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
+        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
+        return self
+
+
+class DoubleLinear(nn.Module):
+    """
+    This can be used for returning multiple outputs from a module
+    (``use_second_linear=True``) or for having an unused module (``False``).
+    """
+
+    def __init__(self, dim: int, use_second_linear: bool = True):
+        super().__init__()
+        self.lin1 = nn.Linear(dim, dim)
+        self.lin2 = nn.Linear(dim, dim)
+        self.relu = nn.ReLU()
+        self.use_second_linear = use_second_linear
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        if self.use_second_linear:
+            return self.relu(self.lin1(x)), self.relu(self.lin2(x))
+        return self.relu(self.lin1(x))
+
+
+# NOTE: For these patch methods, if we want safety under multi-threading (e.g.
+# when using multi-threaded process group), then we want:
+# (1) a barrier immediately after reading the original value to ensure that all
+# threads see the same original value
+# (2) a barrier immediately before restoring the original value to ensure that
+# all threads use the patched value inside the context
+@contextlib.contextmanager
+def patch_all_gather(new_all_gather_into_tensor: Callable):
+    orig_all_gather = dist.all_gather_into_tensor
+    dist.barrier()
+    dist.all_gather_into_tensor = new_all_gather_into_tensor
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_gather_into_tensor = orig_all_gather
+
+
+@contextlib.contextmanager
+def patch_foreach_all_gather(new_foreach_all_gather: Callable):
+    orig_foreach_all_gather = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+        new_foreach_all_gather
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
+            orig_foreach_all_gather
+        )
+
+
+@contextlib.contextmanager
+def patch_foreach_reduce(new_foreach_reduce: Callable):
+    orig_foreach_foreach_reduce = (
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce
+    )
+    dist.barrier()
+    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+        new_foreach_reduce
+    )
+    try:
+        yield
+    finally:
+        dist.barrier()
+        torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
+            orig_foreach_foreach_reduce
+        )
+
+
+@contextlib.contextmanager
+def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
+    orig_reduce_scatter = dist.reduce_scatter_tensor
+    dist.barrier()
+    dist.reduce_scatter_tensor = new_reduce_scatter_tensor
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.reduce_scatter_tensor = orig_reduce_scatter
+
+
+@contextlib.contextmanager
+def patch_all_reduce(new_all_reduce: Callable):
+    orig_all_reduce = dist.all_reduce
+    dist.barrier()
+    dist.all_reduce = new_all_reduce
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_reduce = orig_all_reduce
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_unshard(new_unshard: Callable):
+    orig_unshard = FSDPParamGroup.unshard
+    dist.barrier()
+    FSDPParamGroup.unshard = new_unshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.unshard = orig_unshard
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_reshard(new_reshard: Callable):
+    orig_reshard = FSDPParamGroup.reshard
+    dist.barrier()
+    FSDPParamGroup.reshard = new_reshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.reshard = orig_reshard
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_post_backward(new_post_backward: Callable):
+    orig_post_backward = FSDPParamGroup.post_backward
+    dist.barrier()
+    FSDPParamGroup.post_backward = new_post_backward
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.post_backward = orig_post_backward
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_register_post_backward_hook_backward(new_backward: Callable):
+    orig_backward = RegisterPostBackwardFunction.backward
+    dist.barrier()
+    RegisterPostBackwardFunction.backward = new_backward
+    try:
+        yield
+    finally:
+        dist.barrier()
+        RegisterPostBackwardFunction.backward = orig_backward
+
+
+def reduce_scatter_with_assert(
+    cls,
+    orig_reduce_scatter: Callable,
+    assert_fn: Callable,  # `assert_fn(output: Tensor)`
+    *args: Any,
+    **kwargs: Any,
+):
+    if len(args) > 0:
+        output = args[0]
+    elif "output" in kwargs:
+        output = kwargs["output"]
+    else:
+        raise AssertionError(
+            f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}"
+        )
+    assert_fn(output)
+    return orig_reduce_scatter(*args, **kwargs)
+
+
+def check_sharded_parity(
+    cls,  # unit test class
+    replicated_module: nn.Module,
+    sharded_module: nn.Module,
+    prefixes_to_ignore: tuple[str, ...] = (),
+):
+    for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip(
+        replicated_module.named_parameters(),
+        sharded_module.named_parameters(),
+        strict=True,
+    ):
+        clean_sharded_name = sharded_name
+        for prefix in prefixes_to_ignore:
+            clean_sharded_name = clean_sharded_name.replace(prefix, "")
+        cls.assertEqual(replicated_name, clean_sharded_name)
+        cls.assertIsInstance(sharded_param, DTensor)
+        assert isinstance(sharded_param, DTensor)  # mypy
+        mesh, placements = sharded_param.device_mesh, sharded_param.placements
+        if tuple(placements) == (Shard(0), Shard(0)):
+            raise AssertionError(
+                "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), "
+                "so we cannot check for equality using it"
+            )
+        sharded_ref_param = distribute_tensor(replicated_param, mesh, placements)
+        cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local())
+        if replicated_param.grad is None:
+            cls.assertIsNone(sharded_param.grad)
+            continue
+        cls.assertIsNotNone(sharded_param.grad)
+        sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements)
+        cls.assertIsInstance(sharded_param.grad, DTensor)
+        assert isinstance(sharded_param.grad, DTensor)  # mypy
+        cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())
+
+
+@unittest.skipIf(TEST_XPU, "not-support-multithread")
+class FSDPTestMultiThread(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return DEVICE_COUNT
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    def run_subtests(self, *args, **kwargs):
+        return run_subtests(self, *args, **kwargs)
+
+    def perThreadSetUp(self):
+        torch._dynamo.reset()
+
+    def perThreadTearDown(self):
+        torch._dynamo.reset()
+
+
+class FSDPTest(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
+        # which can cause unit test flakiness:
+        # https://github.com/pytorch/pytorch/issues/90848
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    @property
+    def world_size(self):
+        return DEVICE_COUNT
+
+    @property
+    def process_group(self):
+        return dist.distributed_c10d._get_default_group()
+
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
+    @property
+    def init_method(self):
+        return f"{FILE_SCHEMA}{self.file_name}"
+
+    def _check_cpu_offload(self, fsdp_model, cpu_offload):
+        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
+
+    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
+        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)
+
+    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
+        self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)
+
+    def run_subtests(self, *args, **kwargs):
+        return run_subtests(self, *args, **kwargs)
+
+    @classmethod
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):  # type: ignore[override]
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        fake_pg = kwargs.get("fake_pg", False)
+
+        print(f"dist init r={self.rank}, world={self.world_size}")
+        if torch.accelerator.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        # Specify gloo backend to make 'init_process_group()' succeed,
+        # Actual tests will be skipped if there is no enough GPUs.
+        try:
+            if fake_pg:
+                store = torch.testing._internal.distributed.fake_pg.FakeStore()
+                dist.init_process_group(
+                    backend="fake",
+                    world_size=self.world_size,
+                    rank=rank,
+                    store=store,
+                )
+            else:
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=DISTRIBUTED_BACKEND,
+                    world_size=int(self.world_size),
+                    rank=self.rank,
+                )
+        except RuntimeError as e:
+            if "recompile" in e.args[0]:
+                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
+
+            raise
+
+        device_ids = None
+        device_id = self.rank % DEVICE_COUNT
+        if TEST_CUDA or TEST_XPU:
+            torch.accelerator.set_device_index(device_id)
+        device_ids = [device_id]
+
+        # Execute barrier prior to running test to ensure that every process
+        # has finished initialization and that the following test
+        # immediately exiting due to a skip doesn't cause flakiness.
+        dist.barrier(device_ids=device_ids)
+
+        torch._dynamo.reset()
+        set_rng_seed()
+        self.run_test(test_name, pipe)
+        torch._dynamo.reset()
+
+        dist.barrier(device_ids=device_ids)
+
+        dist.destroy_process_group()
+
+    def _train_for_several_steps(
+        self,
+        model: nn.Module,
+        num_steps: int,
+        autocast: bool,
+        lr: float = 0.01,
+        fsdp_cpu_offload: Optional[CPUOffload] = None,
+        save_model: bool = False,
+        mixed_precision: Optional[MixedPrecision] = None,
+        enable_sharded_grad_scaler: bool = False,
+        use_pure_fp16: bool = False,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
+
+        model_device = next(model.parameters()).device
+        if sharded_grad_scaler_kwargs is None:
+            sharded_grad_scaler_kwargs = {}
+        sharded_grad_scaler = ShardedGradScaler(
+            enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs
+        )
+        # use SGD with momentum instead of Adam, since Adam is scale invariant
+        # and this makes it bad for tests
+        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+        for _ in range(num_steps):
+            optim.zero_grad()
+            with torch.amp.autocast(DEVICE_TYPE, enabled=autocast):
+                # Inputs always cuda regardless of cpu offloading, or model.device
+                input = model.module.get_input(torch.device(DEVICE_TYPE))  # type: ignore[operator, union-attr]
+                if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
+                    if isinstance(input, torch.Tensor):
+                        input = input.half()
+                    else:
+                        input = tuple(x.half() for x in input)
+                output = model(*input)
+                # Post-forward, if CPU offloading model param should be on CPU.
+                if (
+                    cpu_offload_params
+                    and isinstance(model, FSDP)
+                    # If not resharding after forward, the parameters are still
+                    # exposed as unsharded views into the GPU flat parameter
+                    and model.sharding_strategy
+                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+                ):
+                    for p in model.parameters():
+                        # Params should always be on CPU
+                        self.assertEqual(p.device, torch.device("cpu"))
+
+                loss = model.module.get_loss(input, output).to(model_device)  # type: ignore[operator, union-attr]
+            loss = sharded_grad_scaler.scale(loss)
+
+            if not mixed_precision and not use_pure_fp16:
+                assert loss.dtype == torch.float32, (
+                    "loss data type should be float32, as the original \
+                    parameter data type is float32."
+                )
+            else:
+                if use_pure_fp16:
+                    self.assertEqual(loss.dtype, torch.float16)
+                # FSDP loss is fp16, DDP AMP loss is fp32
+                elif isinstance(model, FSDP):
+                    assert mixed_precision is not None  # mypy
+                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
+                else:
+                    self.assertEqual(loss.dtype, torch.float32)
+            model.module.run_backward(loss)  # type: ignore[operator, union-attr]
+            # Post-backward, if CPU offloading model params should be on CPU.
+            if cpu_offload_params and isinstance(model, FSDP):
+                for p in model.parameters():
+                    # Params should always be on CPU
+                    self.assertEqual(p.device, torch.device("cpu"))
+            # Unscale the gradients and step
+            sharded_grad_scaler.step(optim)
+            # Update the scale factor
+            sharded_grad_scaler.update()
+            # if save_model, simulate save + load.
+            if save_model:
+                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
+                # Zero params, if save/load state_dict did not work properly, this
+                # would break the parity test with DDP.
+                _zero_model(model)
+                model.load_state_dict(state_dict)
+
+        if isinstance(model, FSDP):
+            model._assert_state(TrainingState.IDLE)
+        return loss.detach()  # type: ignore[possibly-undefined]
+
+    def _test_fsdp_parity(
+        self,
+        model_class: type[FSDPTestModel],
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        ref_init_fn: Optional[Callable] = None,
+        num_iters: int = 2,
+        save_model: bool = True,
+        cpu_offload: CPUOffload = CPUOffload(),
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        sharding_strategy: Optional[ShardingStrategy] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        forward_prefetch: bool = False,
+        use_orig_params: bool = False,
+        enable_sharded_grad_scaler: bool = False,
+        use_pure_fp16: bool = False,
+        init_kwargs: Optional[dict[str, Any]] = None,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
+        **fsdp_kwargs,
+    ):
+        """
+        Tests FSDP training against a reference, which defaults to DDP but
+        may be customized with ``ref_init_fn``.
+
+        Args:
+            model_class (Type[FSDPTestModel]): A model class that inherits from
+                ``FSDPTestModel``, which defines the expected interface.
+            fsdp_init_mode (FSDPInitMode): The mode to initialize the
+                FSDP-wrapped model. This should not be ``NO_FSDP``.
+            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
+                non-wrapped model to construct the reference model, where this
+                wrapper should provide data parallel semantics. If ``None``,
+                then the callable defaults to the DDP constructor.
+        """
+        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, (
+            "Expects an FSDP init mode that wraps with FSDP"
+        )
+        if init_kwargs is None:
+            init_kwargs = {}
+        lr = 1e-2
+        rank = self.process_group.rank()
+        # Establish reference behavior with DDP
+        model = model_class.init(
+            self.process_group,
+            FSDPInitMode.NO_FSDP,
+            DEVICEInitMode.DEVICE_BEFORE,
+            deterministic=True,
+            **init_kwargs,
+        )
+        if ref_init_fn is None:
+            if TEST_HPU:
+                ref_model = DDP(
+                    model, device_ids=[DEVICE_TYPE], output_device=DEVICE_TYPE
+                )
+            else:
+                ref_model = DDP(model, device_ids=[rank], output_device=rank)
+        else:
+            ref_model = ref_init_fn(model)
+        if use_pure_fp16:
+            ref_model = ref_model.half()
+        ref_loss = self._train_for_several_steps(
+            ref_model,
+            num_iters,
+            autocast=mixed_precision is not None,
+            lr=lr,
+            fsdp_cpu_offload=cpu_offload,
+            mixed_precision=mixed_precision,
+            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
+            use_pure_fp16=use_pure_fp16,
+            sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
+        )
+        ddp_params = list(ref_model.parameters())
+        # Check against FSDP behavior
+        fsdp_kwargs.update(
+            {
+                "cpu_offload": cpu_offload,
+                "backward_prefetch": backward_prefetch,
+                "sharding_strategy": sharding_strategy,
+                "mixed_precision": mixed_precision,
+                "forward_prefetch": forward_prefetch,
+                "use_orig_params": use_orig_params,
+            }
+        )
+        try:
+            fsdp_model = model_class.init(
+                self.process_group,
+                fsdp_init_mode,
+                device_init_mode,
+                fsdp_kwargs,
+                deterministic=True,
+                **init_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
+        if not isinstance(fsdp_model, FSDP):
+            # Enforce that we wrap with top-level FSDP since we are comparing
+            # assuming a data parallel reference and some test models may not
+            # do so in their `init()` method
+            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
+        if use_pure_fp16:
+            # Change the model parameter dtype after FSDP initialization
+            fsdp_model = fsdp_model.half()
+        if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+            fsdp_model = fsdp_model.to(DEVICE_TYPE)
+        offload_params = cpu_offload is not None and cpu_offload.offload_params
+        # Offloading parameters with `DEVICE_AFTER` should raise an error during
+        # lazy initialization due to the parameter devices not being CPU;
+        # otherwise, all parameter devices should be CPU
+        expects_device_error = (
+            offload_params and device_init_mode == DEVICEInitMode.DEVICE_AFTER
+        )
+        expects_cpu_device = (
+            offload_params and device_init_mode != DEVICEInitMode.DEVICE_AFTER
+        )
+        if expects_cpu_device:
+            cpu_device = torch.device("cpu")
+            for param in fsdp_model.parameters():
+                self.assertEqual(param.device, cpu_device)
+        context = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "An FSDP-managed module with parameter CPU offloading enabled "
+                f"has parameters on {DEVICE_TYPE}",
+            )
+            if expects_device_error
+            else nullcontext()
+        )
+        with context:
+            fsdp_loss = self._train_for_several_steps(
+                fsdp_model,
+                num_iters,
+                autocast=False,
+                lr=lr,
+                fsdp_cpu_offload=cpu_offload,
+                save_model=save_model,
+                mixed_precision=mixed_precision,
+                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
+                use_pure_fp16=use_pure_fp16,
+                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
+            )
+        # No need to check for parameter and loss parity if expecting an error
+        if expects_device_error:
+            return
+        # Check parameter devices are CPU if offloading to CPU before calling
+        # `get_full_params()`, which will cast the parameters to FP32
+        if offload_params:
+            cpu_device = torch.device("cpu")
+            for param in fsdp_model.parameters():
+                self.assertEqual(param.device, cpu_device)
+            fsdp_loss = fsdp_loss.to(DEVICE_TYPE)
+        fsdp_unsharded_params = get_full_params(fsdp_model)
+        # Do not check dtype since the reference DDP loss may not be the same
+        # dtype as the FSDP loss in the case of mixed precision
+        torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
+        # Do not check for parameter parity if using mixed precision since (1)
+        # the DDP parameters are in FP16 (from `half()`) while the FSDP
+        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
+        # the optimizer in FP16 while FSDP runs it in FP32
+        # TODO: Disable checking the parameters for pure FP16 due to floating
+        # point inaccuracy. Note that this means that the backward pass is not
+        # checked: https://github.com/pytorch/pytorch/issues/90784
+        if mixed_precision is None and not use_pure_fp16:
+            self.assertEqual(
+                ddp_params,
+                fsdp_unsharded_params,
+                exact_device=True,
+                msg="FSDP did not match DDP",
+            )
+
+
+def compiled_fsdp_test(compile_compute_on_module: Optional[type] = None):
+    def fully_shard_with_compiled_compute(*args, **kwargs):
+        torch.distributed.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
+        if compile_compute_on_module is None or isinstance(
+            args[0], compile_compute_on_module
+        ):
+            args[0].compile()
+
+    class FullyShardMode(Enum):
+        EAGER = auto()
+        COMPILED_COMPUTE = auto()
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            original_fully_shard: Any = torch.distributed.fsdp.fully_shard
+            for mode in FullyShardMode:
+                if mode != FullyShardMode.EAGER and not has_triton():
+                    warnings.warn(
+                        "Inductor on GPU needs Triton and recent GPU arch", stacklevel=2
+                    )
+                    continue
+                # barrier to ensure thread reading the same value
+                original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
+                original_compile_threads = torch._inductor.config.compile_threads
+                torch.distributed.barrier()
+
+                if mode == FullyShardMode.EAGER:
+                    fully_shard_patch = original_fully_shard
+                elif mode == FullyShardMode.COMPILED_COMPUTE:
+                    torch._dynamo.config.skip_fsdp_hooks = True
+                    torch._inductor.config.compile_threads = 1
+                    fully_shard_patch = fully_shard_with_compiled_compute  # type: ignore[assignment]
+                else:
+                    raise NotImplementedError(
+                        f"Need to implement FullyShardMode={mode}"
+                    )
+
+                # fully_shard is imported as a global
+                # through `from ... import fully_shard`
+                func.__globals__[original_fully_shard.__name__] = fully_shard_patch
+                func(*args, **kwargs)
+                # other threads use patched func before this thread restores
+                torch.distributed.barrier()
+                func.__globals__[original_fully_shard.__name__] = original_fully_shard
+                torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
+                torch._inductor.config.compile_threads = original_compile_threads
+
+        return wrapper
+
+    return decorator
+
+
+class SkipModule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin = nn.Linear(10, 10, bias=False)
+
+    def forward(self, x):
+        return self.lin(x)
+
+
+class NestedLinear(nn.Module):
+    def __init__(self, fsdp_wrap):
+        super().__init__()
+        if fsdp_wrap:
+            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).to(DEVICE_TYPE))
+        else:
+            self.nested_linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)
+
+    def forward(self, x):
+        return self.nested_linear(x)
+
+
+class SkipModel(nn.Module):
+    def __init__(self, double_nest):
+        super().__init__()
+        self.linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)
+        self.linear_skip = SkipModule().to(DEVICE_TYPE)
+        self.nested_linear = wrap(
+            NestedLinear(fsdp_wrap=double_nest), device_id=DEVICE_TYPE
+        )
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.linear_skip(x)
+        x = self.nested_linear(x)
+        return x
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..70ab98137bd712de4c5b0e998e26bd585ff4433c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py
@@ -0,0 +1,113 @@
+# mypy: ignore-errors
+
+import contextlib
+import functools
+import inspect
+
+import torch
+
+
+def bf32_is_not_fp32():
+    if not torch.backends.mkldnn.is_available():
+        return False
+    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+        return False
+    return True
+
+
+def tf32_is_not_fp32():
+    if not torch.backends.mkldnn.is_available():
+        return False
+    if not torch._C._cpu._is_amx_fp16_supported():
+        return False
+    return True
+
+
+@contextlib.contextmanager
+def reduced_f32_off():
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
+    try:
+        torch.backends.mkldnn.matmul.fp32_precision = "ieee"
+        torch.backends.mkldnn.conv.fp32_precision = "ieee"
+        yield
+    finally:
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
+
+
+@contextlib.contextmanager
+def bf32_on(self, bf32_precision=1e-2):
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
+    old_precision = self.precision
+    try:
+        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
+        torch.backends.mkldnn.conv.fp32_precision = "bf16"
+        self.precision = bf32_precision
+        yield
+    finally:
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
+        self.precision = old_precision
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
+    old_precision = self.precision
+    try:
+        torch.backends.mkldnn.matmul.fp32_precision = "tf32"
+        torch.backends.mkldnn.conv.fp32_precision = "tf32"
+        self.precision = tf32_precision
+        yield
+    finally:
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
+        self.precision = old_precision
+
+
+# This is a wrapper that wraps a test to run this test three times, one with
+# reduced_f32 OFF, the others with reduced_f32 ON (including bf32 ON and tf32
+# ON). When running with reduced_f32 ON, it will use reduced precision (bf16/
+# tf32) as specified by the argument.
+def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5):
+    def with_reduced_f32_disabled(self, function_call):
+        with reduced_f32_off():
+            function_call()
+
+    def with_bf32_enabled(self, function_call):
+        with bf32_on(self, bf32_precision):
+            function_call()
+
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            kwargs.update(zip(arg_names, args, strict=False))
+            cond = True
+            if "device" in kwargs:
+                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
+            if "dtype" in kwargs:
+                cond = cond and (kwargs["dtype"] == torch.float)
+            bf32_cond = cond and bf32_is_not_fp32()
+            tf32_cond = cond and tf32_is_not_fp32()
+            if bf32_cond or tf32_cond:
+                with_reduced_f32_disabled(kwargs["self"], lambda: f(**kwargs))
+                if bf32_cond:
+                    with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
+                if tf32_cond:
+                    with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+
+    return wrapper
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..83fca0b973856ad05dcdd417f1f46f85bcd8591f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py
@@ -0,0 +1,4380 @@
+# mypy: ignore-errors
+
+import torch
+import unittest
+from copy import deepcopy
+from enum import Enum
+from functools import wraps, partial
+from itertools import chain, product
+import itertools
+import math
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pack_padded_sequence
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import TEST_CUDNN
+from torch.testing._internal.common_dtype import (
+    floating_types, floating_and_complex_types_and, get_all_fp_dtypes)
+from torch.testing._internal.common_device_type import (
+    _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol,
+    precisionOverride, skipMeta, skipMPS)
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_nn import (
+    cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference,
+    hingeembeddingloss_reference, huberloss_reference, kldivloss_reference,
+    marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference,
+    nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction)
+from torch.testing._internal.common_utils import (
+    freeze_rng_state, skipIfMPS, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
+    skipIfTorchDynamo)
+from types import ModuleType
+import operator
+
+# List of all namespaces containing modules to test.
+MODULE_NAMESPACES: list[ModuleType] = [
+    torch.nn.modules,
+    torch.ao.nn.qat.modules,
+    torch.ao.nn.quantizable.modules,
+    torch.ao.nn.quantized.modules,
+    torch.ao.nn.quantized.modules,
+]
+
+# Modules that shouldn't be tested for one reason or another.
+MODULES_TO_SKIP: set[type] = {
+    torch.nn.Module,  # abstract base class
+    torch.nn.Container,  # deprecated
+    torch.nn.NLLLoss2d,  # deprecated
+    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
+    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
+}
+
+# List of all module classes to test.
+MODULE_CLASSES: list[type] = [*chain.from_iterable([
+    [getattr(namespace, module_name) for module_name in namespace.__all__]  # type: ignore[attr-defined]
+    for namespace in MODULE_NAMESPACES])]
+MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
+
+# Dict of module class -> common name. Useful for making test names more intuitive.
+# Example: torch.nn.modules.linear.Linear -> "nn.Linear"
+MODULE_CLASS_NAMES: dict[type, str] = {}
+for namespace in MODULE_NAMESPACES:
+    for module_name in namespace.__all__:  # type: ignore[attr-defined]
+        module_cls = getattr(namespace, module_name)
+        namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '')
+
+        # Deal with any aliases by preferring earlier names.
+        if module_cls not in MODULE_CLASS_NAMES:
+            MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'
+
+
+# Specifies the modes (i.e. train, eval) to test over.
+TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
+
+
+class modules(_TestParametrizer):
+    """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
+
+    def __init__(self, module_info_iterable, allowed_dtypes=None,
+                 train_eval_mode=TrainEvalMode.train_and_eval, skip_if_dynamo=True):
+        self.module_info_list = list(module_info_iterable)
+        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
+        self.train_eval_mode = train_eval_mode
+        self.skip_if_dynamo = skip_if_dynamo
+
+    def _get_training_flags(self, module_info):
+        training_flags = []
+        if (self.train_eval_mode == TrainEvalMode.train_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(True)
+
+        if (self.train_eval_mode == TrainEvalMode.eval_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(False)
+
+        # If train and eval modes don't differ for the module, don't bother using more than one.
+        if not module_info.train_and_eval_differ:
+            training_flags = training_flags[:1]
+
+        return training_flags
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if device_cls is None:
+            raise RuntimeError('The @modules decorator is only intended to be used in a device-specific '
+                               'context; use it with instantiate_device_type_tests() instead of '
+                               'instantiate_parametrized_tests()')
+
+        for module_info in self.module_info_list:
+            dtypes = set(module_info.supported_dtypes(device_cls.device_type))
+            if self.allowed_dtypes is not None:
+                dtypes = dtypes.intersection(self.allowed_dtypes)
+
+            training_flags = self._get_training_flags(module_info)
+            for (training, dtype) in product(training_flags, dtypes):
+                # Construct the test name; device / dtype parts are handled outside.
+                # See [Note: device and dtype suffix placement]
+                test_name = module_info.formatted_name
+                if len(training_flags) > 1:
+                    test_name += f"_{'train_mode' if training else 'eval_mode'}"
+
+                # Construct parameter kwargs to pass to the test.
+                param_kwargs = {'module_info': module_info}
+                _update_param_kwargs(param_kwargs, 'dtype', dtype)
+                _update_param_kwargs(param_kwargs, 'training', training)
+
+                try:
+
+                    @wraps(test)
+                    def test_wrapper(*args, **kwargs):
+                        return test(*args, **kwargs)
+
+                    if self.skip_if_dynamo and not torch.testing._internal.common_utils.TEST_WITH_TORCHINDUCTOR:
+                        test_wrapper = skipIfTorchDynamo("Policy: we don't run ModuleInfo tests w/ Dynamo")(test_wrapper)
+
+                    decorator_fn = partial(module_info.get_decorators, generic_cls.__name__,
+                                           test.__name__, device_cls.device_type, dtype)
+
+                    yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+                except Exception as ex:
+                    # Provides an error message for debugging before rethrowing the exception
+                    print(f"Failed to instantiate {test_name} for module {module_info.name}!")
+                    raise ex
+
+
+def get_module_common_name(module_cls):
+    if module_cls in MODULE_CLASS_NAMES:
+        # Example: "nn.Linear"
+        return MODULE_CLASS_NAMES[module_cls]
+    else:
+        return module_cls.__name__
+
+
+class FunctionInput:
+    """ Contains args and kwargs to pass as input to a function. """
+    __slots__ = ['args', 'kwargs']
+
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+
+
+class ModuleInput:
+    """ Contains args / kwargs for module instantiation + forward pass. """
+    __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn']
+
+    def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None):
+        self.constructor_input = constructor_input  # Inputs to pass during construction
+        self.forward_input = forward_input  # Inputs to pass to forward()
+        self.desc = desc  # Description for this set of inputs
+        self.reference_fn = reference_fn  # Reference with signature: reference_fn(module, parameters, *args, **kwargs)
+
+        if reference_fn is not None:
+
+            @wraps(reference_fn)
+            def copy_reference_fn(m, *args, **kwargs):
+                # Copy inputs to avoid undesired side effects from calling the reference.
+                args, kwargs = deepcopy(args), deepcopy(kwargs)
+
+                # Note that module parameters are passed in for convenience.
+                return reference_fn(m, list(m.parameters()), *args, **kwargs)
+
+            self.reference_fn = copy_reference_fn
+
+class ModuleErrorEnum(Enum):
+    """ Enumerates when error is raised when testing modules. """
+    CONSTRUCTION_ERROR = 0
+    FORWARD_ERROR = 1
+
+class ErrorModuleInput:
+    """
+    A ModuleInput that will cause the operation to throw an error plus information
+    about the resulting error.
+    """
+
+    __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]
+
+    def __init__(self,
+                 module_error_input,
+                 *,
+                 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+                 error_type=RuntimeError,
+                 error_regex):
+        self.module_error_input = module_error_input
+        self.error_on = error_on
+        self.error_type = error_type
+        self.error_regex = error_regex
+
+
+class ModuleInfo:
+    """ Module information to be used in testing. """
+
+    def __init__(self,
+                 module_cls,  # Class object for the module under test
+                 *,
+                 module_inputs_func,  # Function to generate module inputs
+                 skips=(),  # Indicates which tests to skip
+                 decorators=None,  # Additional decorators to apply to generated tests
+                 dtypes=floating_types(),  # dtypes this function is expected to work with
+                 dtypesIfMPS=(torch.float16, torch.float32,),  # dtypes this function is expected to work with on MPS
+                 dtypesIfHpu=(torch.bfloat16, torch.float32,),
+                 supports_gradgrad=True,  # whether the op supports second order gradients
+                 gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
+                 module_memformat_affects_out=False,  # whether converting module to channels last will generate
+                                                      # channels last output
+                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
+                 module_error_inputs_func=None,  # Function to generate module inputs that error
+                 gradcheck_fast_mode=None,  # Whether to use the fast implementation for gradcheck/gradgradcheck.
+                                            # When set to None, defers to the default value provided by the wrapper
+                                            # function around gradcheck (testing._internal.common_utils.gradcheck)
+                 ):
+        self.module_cls = module_cls
+        self.module_inputs_func = module_inputs_func
+        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
+        self.dtypes = dtypes
+        self.dtypesIfMPS = dtypesIfMPS
+        self.dtypesIfHpu = dtypesIfHpu
+        self.supports_gradgrad = supports_gradgrad
+        self.gradcheck_nondet_tol = gradcheck_nondet_tol
+        self.module_memformat_affects_out = module_memformat_affects_out
+        self.train_and_eval_differ = train_and_eval_differ
+        self.module_error_inputs_func = module_error_inputs_func
+        self.gradcheck_fast_mode = gradcheck_fast_mode
+        self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(test_class, test_name, device, dtype, param_kwargs):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
+
+    def supported_dtypes(self, device_type):
+        if device_type == 'mps':
+            return self.dtypesIfMPS
+        elif device_type == 'hpu':
+            return self.dtypesIfHpu
+        else:
+            return self.dtypes
+
+    @property
+    def name(self):
+        return get_module_common_name(self.module_cls)
+
+    @property
+    def formatted_name(self):
+        return self.name.replace('.', '_')
+
+# Start of module inputs functions.
+
+def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput(10, 8),
+                    forward_input=FunctionInput(input=make_input((4, 10))),
+                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+        ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='no_bias',
+                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t())),
+        ModuleInput(constructor_input=FunctionInput(3, 5),
+                    forward_input=FunctionInput(make_input(3)),
+                    desc='no_batch_dim',
+                    reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1])
+    ]
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def bilinear_reference_fn(m, p, x1, x2, bias=True):
+        result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2)
+        if bias:
+            if x1.shape[0] == 1:
+                result = result.view(-1) + p[1]
+            else:
+                result = result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0])
+        return result
+
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
+                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
+                    reference_fn=bilinear_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(2, 3, 4, bias=False),
+                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
+                    desc='no_bias',
+                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False)),
+        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
+                    forward_input=FunctionInput(make_input(2), make_input(3)),
+                    desc='no_batch_dim',
+                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1))),
+    ]
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_batchmean', {'reduction': 'batchmean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('log_target', {'log_target': True})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return kldivloss_reference(i, t, **constructor_kwargs)
+
+        input = make_input((10, 10)).log()
+        target = make_input((10, 10)) if kwargs.get('log_target', False) else make_input((10, 10)).log()
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(input, target),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+        scalar_input = make_input(()).log()
+        # FIXME(rec): scalar_target is unused, perhaps should be argument to FunctionInput?
+        scalar_target = (  # noqa: F841
+            make_input(()) if kwargs.get('log_target', False) else make_input(()).log()
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(scalar_input, scalar_input),
+                        desc='scalar_' + desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad):
+        return make_tensor(shape, device=device, dtype=dtype,
+                           requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('ignore_index', {'ignore_index': 2}),
+        ('weights', {'weight': make_weight(4).abs()}),
+        ('weights_ignore_index', {'weight': make_weight(4).abs(), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_weight(4).abs(), 'ignore_index': -1})
+    ]
+
+    # TODO: Uncomment when negative weights is supported.
+    # negative_weight = make_weight(10)
+    # negative_weight[0] = -1
+    # cases.append(('weights_negative', {'weight': negative_weight}))
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return nllloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 4)),
+                                                    torch.empty(15, device=device).uniform_().mul(4).floor().long()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+        def nd_reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return nlllossNd_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(
+                            make_input((2, 4, 5, 5)),
+                            torch.empty(2, 5, 5, device=device).uniform_().mul(4).floor().long()),
+                        desc=f"nd_{desc}",
+                        reference_fn=nd_reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(
+                            make_input((2, 4, 5, 5, 2, 2)),
+                            torch.empty(2, 5, 5, 2, 2, device=device).uniform_().mul(4).floor().long()),
+                        desc=f"higher_dim_{desc}",
+                        reference_fn=nd_reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(
+                            make_input((2, 4, 5)),
+                            torch.empty(2, 5, device=device).uniform_().mul(4).floor().long()),
+                        desc=f"3d_{desc}",
+                        reference_fn=nd_reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('homoscedastic', {'homoscedastic': True}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        homoscedastic = constructor_kwargs.pop('homoscedastic', False)
+        var_input = make_input(1, 3).abs() if homoscedastic else make_input(4, 1).abs()
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(4, 3),
+                                                    make_target(4, 3),
+                                                    var_input),
+                        desc=desc,
+                        reference_fn=no_batch_dim_reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('full', {'full': True}),
+        ('no_log_input', {'log_input': False}),
+        ('full_no_log_input', {'full': True, 'log_input': False}),
+    ]
+
+    def poissonnllloss_reference_fn(i, t, log_input=True, full=False, reduction='mean', eps=1e-8):
+        if log_input:
+            result = i.exp() - t.mul(i)
+        else:
+            result = i - t.mul((i + eps).log())
+
+        if full:
+            result += (t.mul(t.log()) - t + 0.5 * (2. * math.pi * t).log()).masked_fill(t <= 1, 0)
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return poissonnllloss_reference_fn(i, t, **constructor_kwargs)
+
+        log_input = constructor_kwargs.get('log_input', True)
+        input = make_input((2, 3, 4, 5)) if log_input else make_input((2, 3, 4, 5)).abs().add(0.001)
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(input,
+                                                    make_target((2, 3, 4, 5)).floor_().abs_()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    def mse_loss_reference_fn(m, p, i, t, reduction='mean'):
+        if reduction == 'none':
+            return (i - t).pow(2)
+        elif reduction == 'mean':
+            return (i - t).pow(2).sum() / i.numel()
+        else:
+            return (i - t).pow(2).sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 4, 5)),
+                                                    make_target((2, 3, 4, 5))),
+                        desc=desc,
+                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_target(())),
+                        desc=f'{desc}_scalar',
+                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def no_batch_dim_reference_fn(m, p, *args, **kwargs):
+    """Reference function for modules supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+
+    Currently it only supports modules which return a single Tensor as output.
+    You can bind the following kwargs.
+    Kwargs:
+        batch_first[bool] : If True, all the Tensors in `args` while be unsqueezed at dim `0` .
+                        and output will be squeezed at dim `0` else dim `1` for both.
+        kwargs_to_batchify[dict] : Dictionary specifying the name of the argument and dimension to unsqueeze.
+                               Useful if there are few arguments whose batch dimension are different
+                               from the ones selected by `batch_first`.
+        is_criterion[bool] : Specify if the module is a criterion and handle the reduction for output accordingly.
+    """
+    def get_and_pop(key, default):
+        v = kwargs.get(key, default)
+        if key in kwargs:
+            kwargs.pop(key)
+        return v
+
+    batch_dim = 0 if get_and_pop('batch_first', True) else 1
+    kwargs_to_batchify = get_and_pop('kwargs_to_batchify', None)
+    is_criterion = get_and_pop('is_criterion', False)
+
+    if kwargs_to_batchify is not None:
+        assert isinstance(kwargs_to_batchify, dict)
+        for k, v in kwargs.items():
+            if k in kwargs_to_batchify and v is not None:
+                bdim = kwargs_to_batchify[k]
+                kwargs[k] = v.unsqueeze(bdim)
+
+    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs).squeeze(batch_dim)
+
+    if is_criterion:
+        reduction = get_reduction(m)
+        if reduction == 'none':
+            return output.squeeze(0)
+    return output
+
+
+def no_batch_dim_reference_mha(m, p, *args, **kwargs):
+    """Reference function for MultiheadAttention supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    batch_dim = 0 if kwargs.get('batch_first', True) else 1
+    if 'batch_first' in kwargs:
+        kwargs.pop('batch_first')
+    if 'key_padding_mask' in kwargs and kwargs['key_padding_mask'] is not None:
+        kwargs['key_padding_mask'] = kwargs['key_padding_mask'].unsqueeze(0)
+    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), output[1].squeeze(0))
+
+
+def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
+    """Reference function for RNN and GRU supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
+        h = h.unsqueeze(1)
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), output[1].squeeze(1))
+
+
+def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
+    """Reference function for LSTM supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
+        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
+
+
+def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
+    """Reference function for LSTMCell supporting no batch dimensions.
+
+    The module is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    inp, (h, c) = args
+    single_batch_input_args = (inp.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(0), output[1].squeeze(0))
+
+
+def generate_regression_criterion_inputs(make_input):
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(reduction=reduction),
+            forward_input=FunctionInput(make_input((4, )), make_input(4,)),
+            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True),
+            desc=f'no_batch_dim_{reduction}'
+        ) for reduction in ['none', 'mean', 'sum']]
+
+
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(kernel_size=2),
+                    forward_input=FunctionInput(make_input((3, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(2),
+                    forward_input=FunctionInput(make_input((2, 3, 6)))),
+        ModuleInput(constructor_input=FunctionInput((2,), (2,)),
+                    forward_input=FunctionInput(make_input((2, 3, 6))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, 1),
+                    forward_input=FunctionInput(make_input((2, 3, 6))),
+                    desc='stride_pad')]
+
+
+def module_inputs_torch_nn_AvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput((2, 2)),
+                    forward_input=FunctionInput(make_input((3, 6, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput((2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='stride_pad'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='divisor'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='divisor_stride'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='divisor_stride_pad')]
+
+
+
+def module_inputs_torch_nn_AvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
+                    forward_input=FunctionInput(make_input((3, 4, 4, 4))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
+        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride_pad'),
+        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride_pad_gpu_fixedkw_output'),
+        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
+                    desc='stride_pad_gpu_general_output'),
+        ModuleInput(constructor_input=FunctionInput(3, 1, 0),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='stride1_pad0_gpu_input'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='stride_pad_gpu_input_nooverlap'),
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor'),
+        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride_pad'),
+        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride_pad_gpu_fixedkw_output'),
+        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
+                    desc='divisor_stride_pad_gpu_general_output'),
+        ModuleInput(constructor_input=FunctionInput(3, 1, 0, divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor_stride1_pad0_gpu_input'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor_stride_pad_gpu_input_nooverlap')]
+
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='one_output')]
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single_1x1output'),
+        ModuleInput(constructor_input=FunctionInput((3, 4)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple_none')]
+
+def module_inputs_torch_nn_AdaptiveAvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 2, 7))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 2, 7))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((None, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
+                    desc='tuple_none'),
+        ModuleInput(constructor_input=FunctionInput((3, 2, 2)),
+                    forward_input=FunctionInput(make_input((1, 1, 3, 2, 6))),
+                    desc='last_dim')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple_none')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6, 7))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='tuple_none'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 12, 9, 3))),
+                    desc='single_nonatomic'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 4, 10))),
+                    desc='tuple_nonatomic')]
+
+
+def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(10,),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='affine'),
+        ModuleInput(constructor_input=FunctionInput(5,),
+                    forward_input=FunctionInput(make_input((4, 5, 3))),
+                    desc='3d_input'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, None),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='affine_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, True, False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((4, 5, 3))),
+                    desc='3d_input_not_affine'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 9))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='2d_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='momentum'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, False),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, True, False),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 2, 2))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='3d_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='momentum'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, False),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, True, False),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))),
+                    desc='zero_batch')]
+
+
+def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    if module_info.module_cls == torch.nn.BatchNorm1d:
+        input_shape = (2, 10)
+    elif module_info.module_cls == torch.nn.BatchNorm2d:
+        input_shape = (2, 10, 5, 5)
+    else:
+        input_shape = (2, 10, 4, 4, 4)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, eps=-1.0),
+                forward_input=FunctionInput(make_input(input_shape)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="eps must be positive"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, eps=0.0),
+                forward_input=FunctionInput(make_input(input_shape)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="eps must be positive"
+        ),
+    ]
+
+
+def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
+    N = kwargs['N']
+    lazy = kwargs.get('lazy', False)
+    transposed = kwargs.get('transposed', False)
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
+    kernel_size, C_in, C_out = 3, 4, 5
+    input_no_batch_shape = (C_in,) + tuple(i + 3 for i in range(N))
+    input_batch_shape = (2,) + input_no_batch_shape
+    return [
+        ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
+                                       FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
+                    forward_input=FunctionInput(make_input(
+                        input_batch_shape if with_batch else input_no_batch_shape)),
+                    desc=('' if with_batch else 'no_batch_dim'),
+                    reference_fn=(None if with_batch else no_batch_dim_reference_fn))
+        for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
+    ]
+
+
+def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.7})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
+            return cosineembeddingloss_reference(i1, i2, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10)), make_input((15, 10)),
+                                                    make_target((15,)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 5))),
+                    desc='4d_input')]
+
+
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_GLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6)))),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    desc='dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((4,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_GELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput('none'),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput('none'),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='channels_last_mem_format'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+                    desc='channels_last_3d_mem_format')]
+
+
+def module_inputs_torch_nn_ReLU6(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='channels_last_mem_format'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+                    desc='channels_last_3d_mem_format')]
+
+
+def module_inputs_torch_nn_LeakyReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(0.5),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    desc='with_negval'),
+        ModuleInput(constructor_input=FunctionInput(0.0),
+                    forward_input=FunctionInput(make_input((10, 10))),
+                    desc='with_zero_negval'),
+        ModuleInput(constructor_input=FunctionInput(0.5),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='with_negval_scalar')]
+
+
+def module_inputs_torch_nn_PReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='1d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='1d_multiparam'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='2d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='2d_multiparam'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='3d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='3d_multiparam')]
+
+
+def module_inputs_torch_nn_SELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar')]
+
+
+def module_inputs_torch_nn_SiLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x))]
+
+
+def module_inputs_torch_nn_Softmax(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(0, True)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softmax2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, False))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_LogSoftmax(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
+                    desc='multiparam'),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
+                    desc='multiparam_scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softmin(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20)))),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 10))),
+                    desc='multidim'),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((3, 4, 10))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.log1p(torch.exp(i))),
+        ModuleInput(constructor_input=FunctionInput(2),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: 1. / 2. * torch.log1p(torch.exp(2 * i)),
+                    desc='beta'),
+        ModuleInput(constructor_input=FunctionInput(2, -100),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=(
+                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
+                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
+                    desc='beta_threshold'),
+        ModuleInput(constructor_input=FunctionInput(2, -100),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=(
+                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
+                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
+                    desc='beta_threshold_scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    desc='lambda'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='lambda_scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softsign(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Tanh(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+
+def module_inputs_torch_nn_Tanhshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Threshold(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='threshold_value'),
+        ModuleInput(constructor_input=FunctionInput(2., 10.),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='large_value'),
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='threshold_value_scalar'),
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Mish(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4)),
+                                                make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
+                                                                         for a, b in zip(i, t, strict=True))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(()), make_input(())),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
+                    desc='scalar')] + generate_regression_criterion_inputs(make_input)
+
+
+def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return smoothl1loss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_input((5, 10))),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_input(())),
+                        desc=f'scalar_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+
+def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weights', {'weight': make_weight((10,))}),
+    ]
+
+    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        result = -(t * i.log() + (1 - t) * (1 - i).log())
+
+        if weight is not None:
+            result = result * weight
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
+                                                    make_target((15, 10)).gt(0).to(dtype)),
+                        desc=desc,
+                        reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs))
+        )
+
+    scalar_weight = make_weight(())
+    module_inputs.append(
+        ModuleInput(constructor_input=FunctionInput(weight=scalar_weight),
+                    forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2),
+                                                make_target(()).gt(0).to(dtype)),
+                    desc='scalar_weight',
+                    reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight))
+    )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weights', {'weight': make_weight((10,))}),
+        ('scalar_weights', {'weight': make_weight(())})
+    ]
+
+    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        # TODO: add pos_weight to the definition here and corresponding SampleInputs
+        max_val = (-i).clamp(min=0)
+        result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
+
+        if weight is not None:
+            result = result * weight
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
+                                                    make_target((15, 10)).gt(0).to(dtype)),
+                        desc=desc,
+                        reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    reductions: list[str] = ['mean', 'sum', 'none']
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('weights', {'weight': make_weight((3,))}),
+        ('ignore_index', {'ignore_index': 1}),
+        ('label_smoothing', {'label_smoothing': 0.15}),
+        ('ignore_index_label_smoothing', {'ignore_index': 1, 'label_smoothing': 0.15})
+    ]
+
+    module_inputs = []
+    for reduction, (desc, constructor_kwargs) in product(reductions, cases):
+        def reference_fn(m, p, i, t, reduction=reduction, constructor_kwargs=constructor_kwargs):
+            return cross_entropy_loss_reference(i, t, reduction=reduction, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5, 5)),
+                                                    make_target((2, 5, 5), low=0, high=3)),
+                        desc=f"4d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5)),
+                                                    make_target((2, 5), low=0, high=3)),
+                        desc=f"3d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3)),
+                                                    make_target((2), low=0, high=3)),
+                        desc=f"2d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
+                                                    make_target((2, 5, 5, 2, 2), low=0, high=3)),
+                        desc=f"higher_dim_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+
+        if constructor_kwargs.get('ignore_index', None) is None:
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3, 4, 2)),
+                                                        make_input((5, 3, 4, 2)).softmax(dim=1)),
+                            desc=f"4d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3, 4)),
+                                                        make_input((5, 3, 4)).softmax(dim=1)),
+                            desc=f"3d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3)),
+                                                        make_input((5, 3)).softmax(dim=1)),
+                            desc=f"2d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
+                                                        make_input((2, 3, 5, 5, 2, 2)).softmax(dim=1)),
+                            desc=f"higher_dim_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((3,)),
+                                                        make_target((), low=0, high=3)),
+                            desc=f"no_batch_dim_{desc}_{reduction}",
+                            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
+            )
+
+    return module_inputs
+
+
+
+def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('blank', {'blank': 14})
+    ]
+    target_dtypes = [torch.int, torch.long]
+
+    module_inputs = []
+    for target_dtype, (desc, constructor_kwargs) in product(target_dtypes, cases):
+        def reference_fn(m, p, i, t, il, tl, constructor_kwargs=constructor_kwargs):
+            return ctcloss_reference(i, t, il, tl, **constructor_kwargs)
+
+        blank = constructor_kwargs.get('blank', 0)
+        low = 0 if blank == 14 else 1
+        high = 14 if blank == 14 else 15
+
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
+                                            (50, 50, 50), (30, 25, 20)),
+                desc=f'{desc}_lengths_intlists',
+                reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
+                                            torch.tensor((50, 50, 50), device=device),
+                                            torch.tensor((30, 25, 20), device=device)),
+                desc=f'{desc}_lengths_tensors',
+                reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
+                                            (50, 50, 50), (30, 25, 20)),
+                desc=f'{desc}_1d_target_lengths_intlists',
+                reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
+                                            torch.tensor((50, 50, 50), device=device),
+                                            torch.tensor((30, 25, 20), device=device)),
+                desc=f'{desc}_1d_target_lengths_tensors',
+                reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(3, 6, 1e-3),
+            forward_input=FunctionInput(make_input((4, 6, 5))),
+            desc='1d_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 12, 1e-3),
+            forward_input=FunctionInput(make_input((4, 12))),
+            desc='1d_affine_GN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 6, 1e-3),
+            forward_input=FunctionInput(make_input((150, 6))),
+            desc='1d_affine_large_batch'),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 5, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_affine_IN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 10, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 10))),
+            desc='1d_no_affine_LN'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 6, 1e-3),
+            forward_input=FunctionInput(make_input((4, 6, 2, 3))),
+            desc='2d_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 3, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
+            desc='2d_no_affine_IN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
+            desc='2d_no_affine_LN'),
+    ]
+
+
+def module_error_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    """
+    Error inputs for GroupNorm that test error messages include actual values.
+    """
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(3, 10),  # num_groups=3, num_channels=10
+                forward_input=FunctionInput(),  # Not needed for construction error
+            ),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex=r"num_channels \(10\) must be divisible by num_groups \(3\)"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(5, 13),  # num_groups=5, num_channels=13
+                forward_input=FunctionInput(),  # Not needed for construction error
+            ),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex=r"num_channels \(13\) must be divisible by num_groups \(5\)"
+        ),
+    ]
+
+
+def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2.),
+            forward_input=FunctionInput(make_input((4, 3, 2, 4))),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(2.),
+            forward_input=FunctionInput(make_input(())),
+            desc='scalar',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 2, 5))),
+            desc='4d_input')
+    ]
+
+
+def module_inputs_torch_nn_Hardtanh(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((3, 2, 5))),
+            reference_fn=lambda m, p, i: i.clamp(-1, 1),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            reference_fn=lambda m, p, i: i.clamp(-1, 1),
+            desc='scalar',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.5})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return hingeembeddingloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((10,)),
+                                                    make_target((10,)).gt(0).to(dtype).mul_(2).sub_(1)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_target(()).gt(0).to(dtype).mul_(2).sub_(1)),
+                        desc=f'scalar_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return huberloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_input((5, 10))),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    lazy = kwargs.get('lazy', False)
+    N = kwargs['N']
+    num_features, eps, momentum, affine, track_running_stats = 3, 1e-3, 0.3, False, True
+    input_no_batch_shape_dict = {1: (3, 15), 2: (3, 6, 6), 3: (3, 4, 4, 4)}
+    input_no_batch_shape = input_no_batch_shape_dict[N]
+    input_batch_shape = (4,) + input_no_batch_shape
+
+    return [
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
+            ),
+            forward_input=FunctionInput(make_input(input_batch_shape))),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
+                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
+            ),
+            forward_input=FunctionInput(make_input(input_batch_shape)),
+            desc='tracking_stats'),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
+            ),
+            forward_input=FunctionInput(make_input(input_no_batch_shape)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='tracking_stats_no_batch_dim'),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
+                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
+            ),
+            forward_input=FunctionInput(make_input(input_no_batch_shape)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim')
+    ]
+
+def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine_no_bias'),
+    ]
+
+def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def rms_norm_reference_fn(m, p, i):
+        eps = m.eps
+        if eps is None:
+            eps = torch.finfo(i.dtype).eps
+        ndim = i.ndim
+        normalized_shape = m.normalized_shape
+        weight = m.weight
+        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+        upcasted_i = i.float()
+        result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        if weight is not None:
+            result *= weight
+        return result.type_as(i)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+    ]
+
+
+def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(3,),
+            forward_input=FunctionInput(make_input((1, 5, 7))),
+            desc='1d'),
+        ModuleInput(
+            constructor_input=FunctionInput(2,),
+            forward_input=FunctionInput(make_input((1, 5, 7, 7))),
+            desc='2d_uneven_pad'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 1., 0.5, 2.),
+            forward_input=FunctionInput(make_input((1, 5, 7, 7, 7))),
+            desc='3d_custom_params'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7))),
+            desc='norm'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 3),
+            forward_input=FunctionInput(make_input((1, 3, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 3),
+            forward_input=FunctionInput(make_input((3, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+    ]
+
+
+
+def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='norm'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))),
+            desc='norm'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(4),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='3d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 4),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='stride'),
+        ModuleInput(
+            constructor_input=FunctionInput(4, return_indices=True),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            desc='3d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='4d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='return_indices'),
+    ]
+
+def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, (2, 2, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='stride'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='stride_padding'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, (1, 1, 1), return_indices=True),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_random_samples():
+        return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_()
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
+            desc='ratio'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((1, 3, 7, 6))),
+            desc='size'),
+        ModuleInput(
+            constructor_input=FunctionInput(
+                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
+            ),
+            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
+            desc='ratio_return_indices'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((3, 5, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='ratio_no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((3, 7, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='size_no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_random_samples():
+        return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_()
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
+            desc='ratio'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 7, 7, 7))),
+            desc='size'),
+        ModuleInput(
+            constructor_input=FunctionInput((4, 2, 3), output_size=(10, 3, 2), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 16, 7, 5))),
+            desc='asymsize'),
+        ModuleInput(
+            constructor_input=FunctionInput(
+                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
+            ),
+            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
+            desc='ratio_return_indices'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((4, 5, 5, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='ratio_no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((4, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='size_no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            desc='scalar'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            desc='channels_last_mem_format'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+            desc='channels_last_3d_mem_format'
+        )
+    ]
+
+
+def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            reference_fn=lambda m, p, i: i.sigmoid().log(),
+            desc='scalar'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 4))),
+            reference_fn=lambda m, p, i: i.sigmoid().log(),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+    ]
+
+
+def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.5})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
+            return marginrankingloss_reference(i1, i2, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((50,)), make_input((50,)),
+                                                    make_target((50,)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return multilabelmarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((10,)),
+                                                    make_target((10), low=0, high=10)),
+                        desc=f'1d_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5, 10), low=0, high=10)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('p', {'p': 2}),
+        ('margin', {'margin': 0.5}),
+        ('weights', {'weight': make_weight(10)})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return multimarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5), low=0, high=10)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weight', {'weight': make_weight(10)}),
+    ]
+
+    def multilabelsoftmargin_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        result = t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()
+        if weight is not None:
+            result *= weight
+        result = (-result).sum(i.dim() - 1) / i.size(-1)
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.mean()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5, 10), low=0, high=2)),
+                        desc=desc,
+                        reference_fn=partial(multilabelsoftmargin_loss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return softmarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 5)),
+                                                    make_target((5, 5)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
+    samples = []
+    for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
+            None, device, dtype, requires_grad, training):
+        # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
+        l_args, l_kwargs = (layer_module_input.constructor_input.args,
+                            layer_module_input.constructor_input.kwargs)
+        l_kwargs['device'] = device
+        l_kwargs['dtype'] = dtype
+        encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
+        num_layers = 2
+        # Note: TransformerEncoderLayer takes a "src_mask" while
+        # TransformerEncoder takes a "mask"; rename kwarg appropriately.
+        forward_input = layer_module_input.forward_input
+        if 'src_mask' in forward_input.kwargs:
+            forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
+            del forward_input.kwargs['src_mask']
+        samples.append(ModuleInput(
+            constructor_input=FunctionInput(encoder_layer, num_layers),
+            forward_input=forward_input,
+            desc=layer_module_input.desc
+        ))
+    return samples
+
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ), ]
+
+    # Samples below are for validating the no-batch-dim support.
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+
+    # Samples below where we pass reference_fn are for validating the fast path,
+    # since the fast path requires no_grad mode, we run the fast path in .eval()
+    # and no_grad() in the reference_fn and verify that against the results in train mode.
+    def fast_path_reference_fn(module, parameters, *args, **kwargs):
+        assert module.training
+        module.train(False)
+        with torch.no_grad():
+            output = module(*args, **kwargs)
+        module.train(True)
+        return output
+
+    if training:
+        for norm_first, bias in itertools.product((True, False), (True, False)):
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(
+                        4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
+                    ),
+                    forward_input=FunctionInput(
+                        make_input((2, 3, 4)),
+                    ),
+                    # fastpath doesn't run when bias=False
+                    reference_fn=fast_path_reference_fn if bias else None,
+                    desc=f'fastpath_{bias}_norm_first_{norm_first}'
+                )
+            )
+
+    return samples
+
+
+def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ), ]
+
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        # Using same mask for tgt and memory
+        memory_mask = tgt_mask
+        memory_key_padding_mask = tgt_key_padding_mask
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first,
+                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
+        if tgt_key_padding_mask is not None:
+            memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
+                ),
+                desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}'
+            ))
+
+    return samples
+
+
+def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = []
+    # Samples below are for validating the no-batch-dim support.
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for mask, key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        # Using same mask for tgt and memory
+        src_mask , tgt_mask = (mask,) * 2
+        src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                num_encoder_layers=1, num_decoder_layers=1,
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first,
+                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src = src.transpose(0, 1)
+            tgt = tgt.transpose(0, 1)
+        if key_padding_mask is not None:
+            src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                num_encoder_layers=1, num_decoder_layers=1,
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    src, tgt, tgt_mask=tgt_mask, src_mask=src_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+            ))
+    return samples
+
+
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(2, 3).random_(4))
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
+            desc='discontiguous'
+        ),
+    ]
+
+
+def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = []
+    bool_vals = (True, False)
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3)))
+    products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks)
+    for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products:
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True,
+                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
+                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
+                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
+                reference_fn=no_batch_dim_reference_mha,
+            )
+        )
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False,
+                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
+                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
+                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
+                reference_fn=partial(no_batch_dim_reference_mha, batch_first=False),
+            )
+        )
+
+    return samples
+
+
+def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10),
+            forward_input=FunctionInput(make_input(5), make_input(10)),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10, bias=True),
+            forward_input=FunctionInput(make_input(5), make_input(10)),
+            reference_fn=no_batch_dim_reference_fn,
+        )
+    ]
+
+    is_rnn = kwargs.get('is_rnn', False)
+    if is_rnn:
+        # RNN also supports `nonlinearity` argument.
+        # `tanh` is the default, so we check with `relu`
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'),
+                forward_input=FunctionInput(make_input(5), make_input(10)),
+                reference_fn=no_batch_dim_reference_fn,
+            )
+        )
+
+    return samples
+
+
+def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = (
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10),
+            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
+            reference_fn=no_batch_dim_reference_lstmcell,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10, bias=True),
+            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
+            reference_fn=no_batch_dim_reference_lstmcell,
+        ),
+    )
+
+    return samples
+
+def make_packed_sequence(inp, batch_sizes):
+    required_grad = inp.requires_grad
+    inp.requires_grad_(False)  # user won't have access to inp so won't be able to get its grads
+    seq = pack_padded_sequence(inp, batch_sizes)
+    seq.data.requires_grad_(required_grad)
+    return seq
+
+
+def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    is_rnn = kwargs['is_rnn']
+    nonlinearity = ('relu', 'tanh')
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+
+    samples = []
+    if is_rnn:
+        prod_gen = product(nonlinearity, bias, batch_first, bidirectional)
+    else:
+        prod_gen = product(bias, batch_first, bidirectional)
+
+    for args in prod_gen:
+        if is_rnn:
+            nl, b, b_f, bidir = args
+        else:
+            b, b_f, bidir = args
+
+        cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        if is_rnn:
+            cons_args['nonlinearity'] = nl
+            cons_args_hidden['nonlinearity'] = nl
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((3, 2))),
+                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+            )
+        )
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))),
+                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+            )
+        )
+        if with_packed_sequence:
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(**cons_args),
+                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))),
+                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+                )
+            )
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(**cons_args),
+                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))),
+                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+                )
+            )
+
+    return samples
+
+
+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+    proj_sizes = (0, 2)
+
+    samples = []
+    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
+
+    for args in prod_gen:
+        b, b_f, bidir, proj_size = args
+        hidden_size = 3
+        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((2, 2))),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+        h_out = proj_size if proj_size > 0 else hidden_size
+        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), hx),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+
+    return samples
+
+
+
+def module_inputs_torch_nn_ReflectionPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((2, 3))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReflectionPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReflectionPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
+            forward_input=FunctionInput(make_input((3, 3, 3, 3, 3))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6, 7))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 2),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2), 3),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4), 5),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6), 7),
+            forward_input=FunctionInput(make_input((1, 2, 1, 2, 1))),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def padding1d_circular_ref(inp, pad):
+        r""" input:
+                [[[0., 1., 2.],
+                  [3., 4., 5.]]]
+                pad: (1, 2)
+                output:
+                    [[[2., 0., 1., 2., 0., 1.],
+                      [5., 3., 4., 5., 3., 4.]]]
+            """
+        return torch.cat([inp[:, :, -pad[0]:], inp, inp[:, :, :pad[1]]], dim=2)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 1)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def padding2d_circular_ref(inp, pad):
+        r"""input:
+                [[[[0., 1., 2],
+                   [3., 4., 5.]]]]
+                pad: (1, 2, 2, 1)
+        output:
+            [[[[2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.]]]]
+        """
+        inp = torch.cat([inp[:, :, -pad[2]:], inp, inp[:, :, :pad[3]]], dim=2)
+        return torch.cat([inp[:, :, :, -pad[0]:], inp, inp[:, :, :, :pad[1]]], dim=3)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 2, 1)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3, 2, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3, 3, 1)),
+            forward_input=FunctionInput(make_input((1, 1, 3, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+
+    def padding3d_circular_ref(inp, pad):
+        r"""input:
+                [[[[[ 0.,  1.,  2.],
+                    [ 3.,  4.,  5.]],
+                   [[ 6.,  7.,  8.],
+                    [ 9., 10., 11.]]]]]
+            pad: (1, 2, 2, 1, 1, 2)
+            output: [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
+        """
+        inp = torch.cat([inp[:, :, -pad[4]:], inp, inp[:, :, :pad[5]]], dim=2)
+        inp = torch.cat([inp[:, :, :, -pad[2]:], inp, inp[:, :, :, :pad[3]]], dim=3)
+        return torch.cat([inp[:, :, :, :, -pad[0]:], inp, inp[:, :, :, :, :pad[1]]], dim=4)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 2, 2, 1, 1, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3, 2, 1, 2, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+    ]
+
+
+# All these operators share similar issues on cuDNN and MIOpen
+rnn_gru_lstm_module_info_decorators = (
+    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
+    # We could not generate a fallback
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_grad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
+    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_gradgrad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # CUDNN GRU doesn't accept non-contiguous hx
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
+    )
+)
+
+# Start of module error inputs functions.
+
+def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="input has inconsistent input_size: got 11 expected 10"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(5, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 1, 1, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="Expected hidden to be 1D or 2D, got 4D instead"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20, 'relu'),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20, 'tanh'),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+    ]
+    return samples
+
+def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 11), (make_input(3, 20), make_input(3, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="input has inconsistent input_size: got 11 expected 10"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 21), make_input(3, 21))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(5, 20), make_input(5, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 1, 1, 20), make_input(3, 1, 1, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="Expected hx\\[0\\] to be 1D or 2D, got 4D instead"
+        ),
+    ]
+    return samples
+
+
+def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex="hidden_size must be greater than zero"
+        ),
+        ErrorModuleInput(
+            ModuleInput(constructor_input=FunctionInput(10, 10, 0)),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex="num_layers must be greater than zero"
+        ),
+    ]
+    return samples
+
+def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 2D or 3D input \(got 4D input\)",
+
+        ),
+    ]
+
+def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 3D or 4D input \(got 2D input\)",
+
+        ),
+    ]
+
+def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 4D or 5D input \(got 2D input\)",
+
+        ),
+    ]
+
+
+_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0)
+
+
+# Database of ModuleInfo entries in alphabetical order.
+module_db: list[ModuleInfo] = [
+    ModuleInfo(torch.nn.AdaptiveAvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
+               skips=(
+                   # Fails on MPS backend if input/output sizes are not divisible
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveAvgPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
+               skips=(
+                   # Fails on MPS backend if input/output sizes are not divisible
+                   DecorateInfo(skipMPS),
+                   # Fails on backward check if output size is 1x1
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                   ),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveAvgPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool1d,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d,
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool2d,
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool1d,
+               ),
+    ModuleInfo(torch.nn.AvgPool2d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool2d,
+               skips=(
+                   # The difference between channels last backward and
+                   # channels first backward of AvgPool2d on CUDA is too large
+                   # See https://github.com/pytorch/pytorch/issues/107201
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='cuda',),
+               ),),
+    ModuleInfo(torch.nn.AvgPool3d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # No channels_last support for AvgPool1d as it does not take 4D inputs
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # backward not supported on MPS backend
+                   DecorateInfo(skipMPS, 'TestModule', 'test_non_contiguous_tensors'),)
+               ),
+    ModuleInfo(torch.nn.BatchNorm1d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
+               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
+               skips=(
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ))
+               ),
+    ModuleInfo(torch.nn.BatchNorm2d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
+               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
+               skips=(
+                   # See https://github.com/pytorch/pytorch/issues/134580
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')),
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),)
+               ),
+    ModuleInfo(torch.nn.BatchNorm3d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
+               module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),)
+               ),
+    ModuleInfo(torch.nn.CELU,
+               module_inputs_func=module_inputs_torch_nn_CELU,
+               # not MPS specific, will be xfailed for all devices in next PR
+               skips=(
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.Conv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Conv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS, device_type="mps"),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               skips=(
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               skips=(
+                   # Fails on backward check because ViewAsRealBackward apply contiguous for grad
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64, torch.complex128]),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # ConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.CosineEmbeddingLoss,
+               module_inputs_func=module_inputs_torch_nn_CosineEmbeddingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.ELU,
+               module_inputs_func=module_inputs_torch_nn_ELU,
+               # not MPS specific, will be xfailed for all devices in next PR
+               skips=(
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.FractionalMaxPool2d,
+               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.FractionalMaxPool3d,
+               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.L1Loss,
+               module_inputs_func=module_inputs_torch_nn_L1Loss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.SmoothL1Loss,
+               module_inputs_func=module_inputs_torch_nn_SmoothL1Loss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   # NS: Still fails on MacOS15.1
+                   DecorateInfo(skipIfMPS, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=[torch.float16], device_type='mps'),),
+               ),
+    ModuleInfo(torch.nn.LazyConv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # LazyConv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # LazyConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Linear,
+               module_inputs_func=module_inputs_torch_nn_Linear,
+               skips=(
+                   # No channels_last support for Linear currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Bilinear,
+               module_inputs_func=module_inputs_torch_nn_Bilinear,
+               decorators=[
+                   DecorateInfo(
+                       toleranceOverride({
+                           torch.float32: tol(atol=1e-4, rtol=1e-4),
+                           torch.float64: tol(atol=1e-4, rtol=1e-4)}),
+                       'TestModule', 'test_forward', device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for Bilinear currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.LPPool1d,
+               module_inputs_func=module_inputs_torch_nn_LPPool1d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.LPPool2d,
+               module_inputs_func=module_inputs_torch_nn_LPPool2d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training') and not _macos15_or_newer,
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LPPool3d,
+               module_inputs_func=module_inputs_torch_nn_LPPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(skipIfMPS, device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.MaxPool1d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool1d,
+               ),
+    ModuleInfo(torch.nn.MaxPool2d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool2d,
+               ),
+    ModuleInfo(torch.nn.MaxPool3d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               ),
+    ModuleInfo(torch.nn.KLDivLoss,
+               module_inputs_func=module_inputs_torch_nn_KLDivLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # https://github.com/pytorch/pytorch/issues/115588
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.MSELoss,
+               module_inputs_func=module_inputs_torch_nn_MSELoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.MarginRankingLoss,
+               module_inputs_func=module_inputs_torch_nn_MarginRankingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.MultiLabelMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiLabelMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # 'aten::multilabel_margin_loss_forward' is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'),
+                   # derivative for aten::multilabel_margin_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.MultiMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # 'aten::multi_margin_loss' is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'),
+                   # RuntimeError: derivative for aten::multi_margin_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.SoftMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_SoftMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.MultiLabelSoftMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiLabelSoftMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.NLLLoss,
+               module_inputs_func=module_inputs_torch_nn_NLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.GaussianNLLLoss,
+               module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
+    ModuleInfo(torch.nn.PoissonNLLLoss,
+               module_inputs_func=module_inputs_torch_nn_PoissonNLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
+    ModuleInfo(torch.nn.HingeEmbeddingLoss,
+               module_inputs_func=module_inputs_torch_nn_HingeEmbeddingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.HuberLoss,
+               module_inputs_func=module_inputs_torch_nn_HuberLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: seemingly incorrect output dtype
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.BCELoss,
+               module_inputs_func=module_inputs_torch_nn_BCELoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.BCEWithLogitsLoss,
+               module_inputs_func=module_inputs_torch_nn_BCEWithLogitsLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # see #119108: tolerance issue
+                   DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.CrossEntropyLoss,
+               module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss,
+               dtypes=get_all_fp_dtypes(include_half=True, include_bfloat16=False),
+               decorators=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
+                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
+                                "test_forward", dtypes=[torch.float16], device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
+                                device_type='cuda'),),
+               ),
+    ModuleInfo(torch.nn.CTCLoss,
+               module_inputs_func=module_inputs_torch_nn_CTCLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # The operator aten::_ctc_loss is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps',),
+                   # derivative for aten::_ctc_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   # https://github.com/pytorch/pytorch/issues/115585
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),)
+               ),
+    ModuleInfo(torch.nn.GELU,
+               module_inputs_func=module_inputs_torch_nn_GELU,
+               skips=(
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.GLU,
+               module_inputs_func=module_inputs_torch_nn_GLU,
+               ),
+    ModuleInfo(torch.nn.GroupNorm,
+               module_inputs_func=module_inputs_torch_nn_GroupNorm,
+               module_error_inputs_func=module_error_inputs_torch_nn_GroupNorm,
+               dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True),
+               skips=(
+                   # Tracking at https://github.com/pytorch/pytorch/issues/98089
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_memory_format', device_type='cpu'),
+                   # No channels_last support for GroupNorm currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'),
+                   DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad",
+                                active_if=TEST_WITH_ROCM, device_type='cuda'),)
+               ),
+    ModuleInfo(torch.nn.Hardshrink,
+               module_inputs_func=module_inputs_torch_nn_Hardshrink,
+               ),
+    ModuleInfo(torch.nn.Hardswish,
+               module_inputs_func=module_inputs_torch_nn_Hardswish,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.Hardtanh,
+               module_inputs_func=module_inputs_torch_nn_Hardtanh,
+               ),
+    ModuleInfo(torch.nn.InstanceNorm1d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=1),
+               train_and_eval_differ=True,
+               skips=(
+                   # No channels_last support for InstanceNorm1d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.InstanceNorm2d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=2),
+               train_and_eval_differ=True,
+               skips=(
+                   # No channels_last support for InstanceNorm2d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.InstanceNorm3d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3),
+               train_and_eval_differ=True,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_memory_format'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous_tensors'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_forward'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_save_load'),
+                   # No channels_last support for InstanceNorm3d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.LocalResponseNorm,
+               module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
+               ),
+    ModuleInfo(torch.nn.LayerNorm,
+               module_inputs_func=module_inputs_torch_nn_LayerNorm,
+               skips=(
+                   # No channels_last support for LayerNorm currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.RMSNorm,
+               module_inputs_func=module_inputs_torch_nn_RMSNorm,
+               ),
+    # TransformerEncoder takes the same inputs as TransformerEncoderLayer
+    ModuleInfo(torch.nn.TransformerEncoder,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoder,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerEncoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # Doesn't support device / dtype kwargs directly because it is just a
+                   # container of TransformerEncoderLayers.
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),)
+               ),
+    ModuleInfo(torch.nn.TransformerEncoderLayer,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='cpu', active_if=IS_WINDOWS),
+                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}),
+                                'TestModule', 'test_forward',
+                                device_type='mps'),
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerEncoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.TransformerDecoderLayer,
+               module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerDecoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Transformer,
+               module_inputs_func=module_inputs_torch_nn_Transformer,
+               # Inputs are too large to run with slow gradcheck
+               # https://github.com/pytorch/pytorch/issues/117140
+               gradcheck_fast_mode=True,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for Transformer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.MultiheadAttention,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
+               skips=(
+                   # No channels_last support for MultiheadAttention currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Embedding,
+               module_inputs_func=module_inputs_torch_nn_Embedding,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='mps')],
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.ReLU,
+               module_inputs_func=module_inputs_torch_nn_ReLU,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LeakyReLU,
+               module_inputs_func=module_inputs_torch_nn_LeakyReLU,
+               ),
+    ModuleInfo(torch.nn.ReLU6,
+               module_inputs_func=module_inputs_torch_nn_ReLU6,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.PReLU,
+               module_inputs_func=module_inputs_torch_nn_PReLU,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.RNNCell,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
+               ),
+    ModuleInfo(torch.nn.GRUCell,
+               module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
+               ),
+    ModuleInfo(torch.nn.LSTMCell,
+               module_inputs_func=module_inputs_torch_nn_LSTMCell,
+               module_error_inputs_func=module_error_inputs_torch_nn_LSTMCell,
+               ),
+    ModuleInfo(torch.nn.Sigmoid,
+               module_inputs_func=module_inputs_torch_nn_Sigmoid,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LogSigmoid,
+               module_inputs_func=module_inputs_torch_nn_LogSigmoid,
+               skips=(
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.SiLU,
+               module_inputs_func=module_inputs_torch_nn_SiLU,
+               ),
+    ModuleInfo(torch.nn.Softmax,
+               module_inputs_func=module_inputs_torch_nn_Softmax,
+               ),
+    ModuleInfo(torch.nn.Softmax2d,
+               module_inputs_func=module_inputs_torch_nn_Softmax2d,
+               skips=(
+                   # no channels last support for Softmax2d currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.LogSoftmax,
+               module_inputs_func=module_inputs_torch_nn_LogSoftmax,
+               skips=(
+                   # no channels last support for LogSoftmax currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: inf nan error
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.Softmin,
+               module_inputs_func=module_inputs_torch_nn_Softmin,
+               skips=(
+                   # no channels last support for Softmin currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Softplus,
+               module_inputs_func=module_inputs_torch_nn_Softplus,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Softshrink,
+               module_inputs_func=module_inputs_torch_nn_Softshrink,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Softsign,
+               module_inputs_func=module_inputs_torch_nn_Softsign,
+               ),
+    ModuleInfo(torch.nn.Tanh,
+               module_inputs_func=module_inputs_torch_nn_Tanh,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.Tanhshrink,
+               module_inputs_func=module_inputs_torch_nn_Tanhshrink,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.Threshold,
+               module_inputs_func=module_inputs_torch_nn_Threshold,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Mish,
+               module_inputs_func=module_inputs_torch_nn_Mish,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.RNN,
+               train_and_eval_differ=True,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               decorators=rnn_gru_lstm_module_info_decorators
+               ),
+    ModuleInfo(torch.nn.GRU,
+               train_and_eval_differ=True,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.LSTM,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_LSTM,
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               skips=(
+                   # LSTM with projections is not currently supported with MPS
+                   DecorateInfo(skipMPS),),
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.ReflectionPad1d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
+               ),
+    ModuleInfo(torch.nn.ReflectionPad2d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReflectionPad3d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReplicationPad1d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad1d,
+               ),
+    ModuleInfo(torch.nn.ReplicationPad2d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReplicationPad3d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.SELU,
+               module_inputs_func=module_inputs_torch_nn_SELU,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.ZeroPad1d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
+               ),
+    ModuleInfo(torch.nn.ZeroPad2d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ZeroPad3d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.CircularPad1d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad1d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
+               ),
+    ModuleInfo(torch.nn.CircularPad2d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad2d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
+               ),
+    ModuleInfo(torch.nn.CircularPad3d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad3d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),)
+               ),
+    ModuleInfo(torch.nn.ConstantPad1d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
+               ),
+    ModuleInfo(torch.nn.ConstantPad2d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ConstantPad3d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               )
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedd0c92b6a4da6d7a0e1d30efa3551c05e11208
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py
@@ -0,0 +1,840 @@
+import unittest
+from collections.abc import Sequence
+from typing import Optional
+
+import torch
+
+from .common_utils import MACOS_VERSION
+from .opinfo.core import DecorateInfo, OpInfo
+
+
+if torch.backends.mps.is_available():
+
+    def mps_ops_modifier(
+        ops: Sequence[OpInfo],
+        device_type: str = "mps",
+        xfail_exclusion: Optional[list[str]] = None,
+        sparse: bool = False,
+    ) -> Sequence[OpInfo]:
+        if xfail_exclusion is None:
+            xfail_exclusion = []
+
+        # Supported complex OPS
+        SUPPORTED_COMPLEX_OPS = {
+            "__radd__",
+            "__rmul__",
+            "__rsub__",
+            "__getitem__",
+            "_unsafe_masked_index",
+            "_unsafe_masked_index_put_accumulate",
+            "abs",
+            "add",
+            "alias_copy",
+            "argwhere",
+            "atleast_1d",
+            "atleast_2d",
+            "atleast_3d",
+            "as_strided",
+            "as_strided_copy",
+            "as_strided_scatter",
+            "asin",
+            "asinh",
+            "acos",
+            "atan",
+            "broadcast_tensors",
+            "broadcast_to",
+            "chalf",
+            "cfloat",
+            "chunk",
+            "clone",
+            "conj",
+            "conj_physical",
+            "contiguous",
+            "cos",
+            "cosh",
+            "diag",
+            "diag_embed",
+            "diagflat",
+            "diagonal",
+            "diagonal_copy",
+            "diagonal_scatter",
+            "divno_rounding_mode",
+            "dsplit",
+            "empty",
+            "empty_permuted",
+            "empty_strided",
+            "exp",
+            "expm1",
+            "exp2",
+            "expand",
+            "expand_as",
+            "expand_copy",
+            "flatten",
+            "fill",
+            "full",
+            "full_like",
+            "H",
+            "hsplit",
+            "imag",
+            "index_add",
+            "index_copy",
+            "index_select",
+            "index_put",
+            "isfinite",
+            "isinf",
+            "isreal",
+            "item",
+            "kron",
+            "linalg.diagonal",
+            "linalg.householder_product",
+            "linalg.svd",
+            "log10",
+            "log1p",
+            "log2",
+            "log",
+            "logaddexp",
+            "logaddexp2",
+            "mH",
+            "mT",
+            "masked_fill",
+            "masked_scatter",
+            "masked_select",
+            "meshgridlist_of_tensors",
+            "meshgridvariadic_tensors",
+            "movedim",
+            "mul",
+            "narrow",
+            "narrow_copy",
+            "neg",
+            "new_full",
+            "new_ones",
+            "new_zeros",
+            "nn.functional.conv1d",
+            "nn.functional.conv2d",
+            "nn.functional.conv_transpose1d",
+            "nn.functional.conv_transpose2d",
+            "nn.functional.conv_transpose3d",
+            "nn.functional.feature_alpha_dropoutwithout_train",
+            "nn.functional.padcircular",
+            "nn.functional.softsign",
+            "nn.functional.tanhshrink",
+            "nn.functional.unfold",
+            "nonzero",
+            "ones",
+            "ones_like",
+            "outer",
+            "permute",
+            "permute_copy",
+            "positive",
+            "randn",
+            "ravel",
+            "real",
+            "repeat_interleave",
+            "reshape_as",
+            "reshape",
+            "resolve_conj",
+            "resolve_neg",
+            "rsqrt",
+            "rsub",
+            "scalar_tensor",
+            "select",
+            "sgn",
+            "sigmoid",
+            "sin",
+            "sinc",
+            "sinh",
+            "slice",
+            "special.spherical_bessel_j0",
+            "special.entr",
+            "special.xlog1py",
+            "special.zeta",
+            "split",
+            "split_with_sizes",
+            "split_with_sizes_copy",
+            "splitlist_args",
+            "sqrt",
+            "squeeze",
+            "squeeze_copy",
+            "squeezemultiple",
+            "sub",
+            "svd",
+            "t",
+            "t_copy",
+            "tanh",
+            "tan",
+            "tensor_split",
+            "transpose",
+            "transpose_copy",
+            "tril",
+            "triu",
+            "true_divide",
+            "T",
+            "unbind",
+            "unbind_copy",
+            "unflatten",
+            "unfold",
+            "unfold_copy",
+            "unsafe_chunk",
+            "unsafe_split",
+            "unsqueeze",
+            "unsqueeze_copy",
+            "view_as",
+            "view_as_real",
+            "view",
+            "view_copy",
+            "vsplit",
+            "zero_",
+            "zeros",
+            "zeros_like",
+            "__rdiv__",
+            "__rmatmul__",
+            "_chunk_cat",
+            "acosh",
+            "all",
+            "allclose",
+            "angle",
+            "any",
+            "addcdiv",
+            "addcmul",
+            "addmmdecomposed",
+            "addmv",
+            "atanh",
+            "bfloat16",
+            "bmm",
+            "bool",
+            "cartesian_prod",
+            "cat",
+            "char",
+            "column_stack",
+            "combinations",
+            "corrcoef",
+            "constant_pad_nd",
+            "cov",
+            "count_nonzero",
+            "diff",
+            "div",
+            "dot",
+            "dstack",
+            "einsum",
+            "eq",
+            "equal",
+            "eye",
+            "fft.fft",
+            "fft.fft2",
+            "fft.fftn",
+            "fft.fftshift",
+            "fft.ifft",
+            "fft.ifft2",
+            "fft.ifftn",
+            "fft.ifftshift",
+            "fft.irfftn",
+            "fft.irfft2",
+            "fft.irfft",
+            "fft.hfftn",
+            "fft.hfft2",
+            "fft.hfft",
+            "flip",
+            "fliplr",
+            "flipud",
+            "float",
+            "gradient",
+            "half",
+            "hstack",
+            "inner",
+            "int",
+            "isclose",
+            "isnan",
+            "ldexp",
+            "lerp",
+            "linalg.multi_dot",
+            "linalg.pinv",
+            "linspace",
+            "linspacetensor_overload",
+            "logical_and",
+            "logical_not",
+            "logical_or",
+            "logical_xor",
+            "logsumexp",
+            "long",
+            "masked.mean",
+            "masked.prod",
+            "masked.std",
+            "masked.sum",
+            "masked.var",
+            "masked.logsumexp",
+            "matmul",
+            "mean",
+            "mm",
+            "mv",
+            "ne",
+            "nn.functional.padconstant",
+            "nn.functional.padreflect",
+            "nn.functional.padreplicate",
+            "nn.functional.pixel_shuffle",
+            "nn.functional.pixel_unshuffle",
+            "nn.functional.rms_norm",
+            "pinverse",
+            "prod",
+            "reciprocal",
+            "roll",
+            "rot90",
+            "short",
+            "square",
+            "stack",
+            "stft",
+            "sum",
+            "sum_to_size",
+            "tensordot",
+            "trace",
+            "trapz",
+            "trapezoid",
+            "vstack",
+            "where",
+            "byte",
+        }
+
+        MACOS_BEFORE_14_4_XFAILLIST = {
+            # These ops work fine in 14.4 but fail in 14.2 or 13.x
+            "fft.hfft2": [torch.complex64],
+        }
+
+        # Those ops are not expected to work
+        UNIMPLEMENTED_XFAILLIST: dict[str, Optional[list]] = {
+            # Failures due to lack of op implementation on MPS backend
+            "logspace": None,
+            "logspacetensor_overload": None,
+            "linalg.eig": None,
+            "linalg.eigvals": None,
+            "put": None,
+            "cauchy_": None,
+            "cauchy": None,
+            "cholesky_inverse": None,
+            "cholesky_solve": None,
+            "frexp": None,
+            "gcd": None,
+            "geqrf": None,
+            "nn.functional.grid_sample": None,  # Unsupported Border padding mode
+            "hash_tensor": None,
+            "heaviside": None,
+            "index_reduceprod": None,
+            "index_reducemean": None,
+            "index_reduceamax": None,
+            "index_reduceamin": None,
+            # "kthvalue": None,
+            "lcm": None,
+            "linalg.cond": None,
+            "linalg.eigh": None,
+            "linalg.eigvalsh": None,
+            "linalg.ldl_factor": None,
+            "linalg.ldl_factor_ex": None,
+            "linalg.ldl_solve": None,
+            "linalg.lstsq": None,
+            "linalg.lstsqgrad_oriented": None,
+            "linalg.matrix_norm": [torch.float32],
+            "linalg.norm": [torch.float32],
+            "linalg.normsubgradients_at_zero": [torch.float32],
+            "linalg.qr": None,
+            "linalg.svdvals": None,
+            "linalg.vecdot": None,
+            "masked.median": None,
+            "matrix_exp": None,
+            "mode": None,
+            "normnuc": None,
+            "nn.functional.fractional_max_pool2d": None,
+            "nn.functional.fractional_max_pool3d": None,
+            "nn.functional.adaptive_avg_pool3d": None,
+            "nn.functional.adaptive_max_pool3d": None,
+            "nn.functional.interpolatearea": None,
+            "nn.functional.interpolatebicubic": [torch.uint8],
+            "nn.functional.ctc_loss": None,
+            "nn.functional.multi_margin_loss": None,
+            "nn.functional.multilabel_margin_loss": None,
+            "nn.functional.pdist": None,
+            "nn.functional.rrelu": None,
+            "nn.functional.norm": None,
+            "ormqr": None,
+            "pca_lowrank": None,
+            "qr": None,
+            "scatter_reduceamax": [torch.int32, torch.int64]
+            if MACOS_VERSION < 15.0
+            else [torch.int64],
+            "scatter_reduceamin": [torch.int32, torch.int64]
+            if MACOS_VERSION < 15.0
+            else [torch.int64],
+            "segment_reduce": None,
+            "_segment.reduce": None,
+            "segment.reduce": None,
+            "segment_reduce_offsets": None,
+            "_segment_reduce_offsets": None,
+            "_segment_reduce_lengths": None,
+            "_segment_reducelengths": None,
+            "_segment_reduceoffsets": None,
+            "sparse.mm": None,
+            "sparse.sampled_addmm": None,
+            "sparse.mmreduce": None,
+            "special.airy_ai": None,
+            "special.erfcx": None,
+            "special.laguerre_polynomial_l": None,
+            "special.legendre_polynomial_p": None,
+            "special.log_ndtr": None,
+            "special.ndtri": None,
+            "svd_lowrank": None,
+            "symeig": None,
+            "take": None,
+            "to": None,
+            "vdot": None,
+            "segment_reduce_": None,
+            "_upsample_bilinear2d_aa": [torch.uint8],  # uint8 is for CPU only
+            "_upsample_bicubic2d_aa": [torch.uint8],  # uint8 is for CPU only
+            "geometric": None,
+            "geometric_": None,
+            "log_normal_": None,
+            "log_normal": None,
+            "cdouble": None,
+            "double": None,
+            "nn.functional.softminwith_dtype": None,
+            "log_softmaxwith_dtype": None,
+            "softmaxwith_dtype": None,
+            "float_power": None,
+            "linalg.matrix_rankhermitian": None,
+            "linalg.pinvhermitian": None,
+            "nonzero_static": None,
+            # MPS: input sizes must be divisible by output sizes
+            "nn.functional.adaptive_avg_pool1d": None,
+            "nn.functional.adaptive_avg_pool2d": None,
+            # Convolution for integral types is not supported on MPS
+            "nn.functional.conv1d": [torch.int64],
+            "nn.functional.conv2d": [torch.int64],
+            "nn.functional.conv3d": [torch.int64],
+            "nn.functional.conv_transpose1d": [torch.int64],
+            "nn.functional.conv_transpose2d": [torch.int64, torch.bfloat16],
+            "nn.functional.conv_transpose3d": [
+                torch.int64,
+                torch.bfloat16,
+                torch.float16,
+            ],
+            # Unsupported dtypes
+            "histc": [torch.float16, torch.bfloat16],
+            # GEMM on MPS is not supported for integral types
+            "nn.functional.linear": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            # returned output on CPU is float64
+            "bincount": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+        }
+        UNIMPLEMENTED_XFAILLIST_SPARSE: dict[str, Optional[list]] = {
+            "logspace": None,
+            "logspacetensor_overload": None,
+            "linalg.eig": None,
+            "linalg.eigvals": None,
+            "put": None,
+        }
+
+        if MACOS_VERSION < 15.0:
+            UNIMPLEMENTED_XFAILLIST.update(
+                {
+                    "quantile": None,
+                    "nanquantile": None,
+                }
+            )
+        if sparse:
+            UNIMPLEMENTED_XFAILLIST.update(UNIMPLEMENTED_XFAILLIST_SPARSE)
+
+        UNDEFINED_XFAILLIST: dict[str, Optional[list]] = {
+            # Top 60 operators
+            # topk fails with duplicate indices
+            "topk": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            # Failures due to random output that they generate using
+            # Philox engine causing mismatch with CPU results
+            "multinomial": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],  # random results
+            "uniform": [torch.float16, torch.float32, torch.bfloat16],
+            "rand_like": [torch.float16, torch.float32, torch.bfloat16],
+            "randint": None,
+            "randint_like": None,
+            "randn": None,
+            "randn_like": None,
+            "bernoulli": [torch.float16, torch.float32, torch.bfloat16],
+            "exponential": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.feature_alpha_dropoutwith_train": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],
+            "normal": [torch.float16, torch.float32, torch.bfloat16],
+            "normalin_place": [torch.float16, torch.float32, torch.bfloat16],
+            "normalnumber_mean": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.alpha_dropout": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],
+            "nn.functional.dropout": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.dropout2d": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.dropout3d": [torch.float16, torch.float32, torch.bfloat16],
+            # See https://github.com/pytorch/pytorch/issues/111479
+            "nn.functional.multi_head_attention_forward": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],
+            # zero to negative integer powers are undefined
+            "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
+            "resize_": [torch.float16, torch.float32, torch.bfloat16],
+            "resize_as_": [torch.float16, torch.float32, torch.bfloat16],
+            # CPU Errors:
+            "addr": [
+                torch.bool,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],  # "addmv_impl_cpu" not implemented for 'Half'
+            "as_stridedpartial_views": None,  # cpu result off, showing random values
+            # random results
+            # mps vs cpu:
+            # Mismatched elements: 40 / 96 (41.7%)
+            # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
+            # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
+            # cuda(2.0.0.dev20230301+cu117) vs cpu:
+            # Mismatched elements: 56 / 96 (58.3%)
+            # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
+            # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
+            "nn.functional.scaled_dot_product_attention": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],
+        }
+
+        ON_MPS_XFAILLIST: dict[str, Optional[list]] = {
+            # Failures due to lack of implementation of downstream functions on MPS backend
+            # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
+            "linalg.matrix_rank": None,
+            # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")`
+            "arange": [torch.uint8],
+            # before macOS 13.2 it falls back to cpu and pass the forward pass
+            "grid_sampler_2d": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],  # Unsupported Border padding mode
+            # Failure due to precision issue for fp16
+            # on both cpu and mps there are test cases that might produce inf result
+            # 'nn.functional.pairwise_distance': [torch.float16],
+            # test blow pass on macOS 12 as it falls back to cpu
+            # Argsort case using duplicate indices (undefined behaviour):
+            #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
+            #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
+            # Elements from index 30 and 5133 are both equal.
+            # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
+            "argsort": [
+                torch.float16,
+                torch.int8,
+                torch.uint8,
+                torch.bool,
+                torch.bfloat16,
+            ],
+            # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
+            # The values of the sorted tensor match the CPU,
+            # but in case of the returned indices this results in undefined behaviour.
+            "sort": [
+                torch.int8,
+                torch.uint8,
+                torch.bool,
+                torch.float16,
+                torch.bfloat16,
+            ],
+        }
+
+        EMPTY_OPS_SKIPLIST = {
+            # Fill tensors with uninitialized data, causing mismatch with CPU.
+            # They occasionally match, thus skipping them.
+            # See https://github.com/pytorch/pytorch/issues/100175
+            "new_empty": None,
+            "new_empty_strided": None,
+            "empty_strided": None,
+            # CPU: empty is returning all 0's and there is a mismatch with MPS
+            # allocation (MacOS 13). According to
+            # https://pytorch.org/docs/2.0/generated/torch.empty.html
+            "empty": None,
+            "empty_like": None,
+            "empty_permuted": None,
+        }
+
+        SKIPLIST = {
+            # Unsupported
+            # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
+            "nn.functional.conv3d": None,
+            # The CPU impl of grid_sampler_3d does not use opmath_t, so it has a
+            # large amount of error compared with the MPS impl for half
+            # precision types. So we have to skip these for now.
+            "grid_sampler_3d": [torch.float16, torch.bfloat16],
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            if device_type is not None:
+                d.device_type = device_type
+
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            addDecorator(
+                op,
+                DecorateInfo(
+                    unittest.expectedFailure,
+                    dtypes=[
+                        torch.double,
+                        torch.cdouble,
+                    ],
+                ),
+            )
+            if sparse:
+                # Skipped due to test_sparse_zero_dims test in test_sparse.py which allocates empty tensor
+                # which leads to unexpected success with it
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.skip(
+                            "Skipped due to MPS not supporting complex128 tensors"
+                        ),
+                        dtypes=[
+                            torch.complex128,
+                        ],
+                    ),
+                )
+            if key in EMPTY_OPS_SKIPLIST:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.skip("Skipping empty ops."),
+                        dtypes=EMPTY_OPS_SKIPLIST[key],
+                    ),
+                )
+            if key in SKIPLIST:
+                addDecorator(
+                    op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key])
+                )
+            for xfaillist in [
+                UNIMPLEMENTED_XFAILLIST,
+                UNDEFINED_XFAILLIST,
+                ON_MPS_XFAILLIST,
+            ]:
+                if key in xfaillist and key not in xfail_exclusion:
+                    addDecorator(
+                        op,
+                        DecorateInfo(unittest.expectedFailure, dtypes=xfaillist[key]),
+                    )
+
+            if (
+                key in MACOS_BEFORE_14_4_XFAILLIST
+                and key not in xfail_exclusion
+                and (MACOS_VERSION < 14.4)
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=MACOS_BEFORE_14_4_XFAILLIST[key],
+                    ),
+                )
+
+            # If ops is not supported for complex types, expect it to fail
+            if key not in SUPPORTED_COMPLEX_OPS:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=[torch.complex32, torch.complex64],
+                    ),
+                )
+
+        return ops
+
+    def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
+        XFAILLIST_GRAD = {
+            # Unimplemented ops
+            "_segment_reduce": [torch.float16, torch.float32],
+            "_chunk_cat": [torch.float16, torch.float32],
+            "_upsample_bilinear2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
+            "_upsample_bicubic2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
+            "sparse.mmreduce": [torch.float32],  # csr not supported
+            "linalg.householder_product": None,
+            "unique_consecutive": [torch.float16, torch.float32],
+            "scalar_tensor": [torch.float16, torch.float32],
+            "cdist": [torch.float32],
+            "masked.scatter": [torch.float16, torch.float32],
+            "grid_sampler_3d": None,
+            "index_fill": [torch.float16, torch.float32],  # missing `aten::_unique`.
+            "igamma": None,  # currently not supported for any device
+            "igammac": None,  # currently not supported for any device
+            "linalg.solve": [torch.float16, torch.float32],  # missing `aten::lu_solve`.
+            "linalg.solve_ex": [
+                torch.float16,
+                torch.float32,
+            ],  # missing `aten::lu_solve`.
+            "linalg.tensorsolve": [
+                torch.float16,
+                torch.float32,
+            ],  # missing `aten::lu_solve`.
+            "aminmax": [torch.float32, torch.float16],
+            "special.i1": [torch.float16],  # "i1_backward" not implemented for 'Half'
+            "special.i1e": [torch.float16],  # "i1e_backward" not implemented for 'Half'
+            # Correctness issues
+            "atanh": [torch.float32],
+            # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
+            # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
+            # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS.
+            # Running `msort` with stable `sort` passes.
+            "msort": [torch.float16],
+            # Random output
+            "exponential": [torch.float16, torch.float32],
+            # CPU errors
+            # derivative for zeta is not implemented
+            "special.zeta": None,
+            # derivative for aten::nextafter is not implemented on CPU
+            "nextafter": None,
+            # derivative for aten::floor_divide is not implemented on CPU
+            "floor_divide": [torch.float16, torch.float32],
+            # derivative for aten::narrow_copy is not implemented on CPU
+            "narrow_copy": [torch.float16, torch.float32],
+            # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
+            "histogramdd": [torch.float16, torch.float32],
+            # derivative for aten::histogram is not implemented
+            "histogram": [torch.float16, torch.float32],
+            # 'bool' object is not iterable
+            "allclose": [torch.float16, torch.float32],
+            "equal": [torch.float16, torch.float32],
+            # 'float' object is not iterable
+            "item": [torch.float16, torch.float32],
+            # cpu error: grad requires non-empty inputs
+            "randn": [torch.float16, torch.float32],
+            "signal.windows.bartlett": [torch.float32],
+            "signal.windows.blackman": [torch.float32],
+            "signal.windows.cosine": [torch.float32],
+            "signal.windows.exponential": [torch.float32],
+            "signal.windows.gaussian": [torch.float32],
+            "signal.windows.general_cosine": [torch.float32],
+            "signal.windows.general_hamming": [torch.float32],
+            "signal.windows.hamming": [torch.float32],
+            "signal.windows.hann": [torch.float32],
+            "signal.windows.kaiser": [torch.float32],
+            "signal.windows.nuttall": [torch.float32],
+            "eye": [torch.float16, torch.float32],
+            # topk fails with duplicate indices
+            "topk": [torch.float16],
+            # Could not run 'aten::uniform_' with arguments from the 'SparseCPU' backend
+            "to_sparse": None,
+            # Exception: the derivative for '_unique2' is not implemented.
+            "unique": None,
+        }
+
+        SKIPLIST_GRAD = {
+            "nn.functional.pairwise_distance": [torch.float16],
+            # failed assertion `destination datatype must be fp32'
+            "nn.functional.conv1d": [torch.float16],
+            "nn.functional.conv2d": [torch.float16],
+            "nn.functional.conv3d": [torch.float16],
+            "nn.functional.conv_transpose1d": [torch.float16],
+            "nn.functional.conv_transpose2d": [torch.float16],
+            "nn.functional.conv_transpose3d": [torch.float16],
+        }
+
+        ON_MPS_XFAILLIST = {
+            # Failures due to lack of implementation of downstream functions on MPS backend
+            # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
+            "linalg.matrix_rank": None,
+            # Exception: Caused by sample input at index 3 on MPS
+            "nn.functional.conv3d": [torch.float32],
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in XFAILLIST_GRAD:
+                addDecorator(
+                    op,
+                    DecorateInfo(unittest.expectedFailure, dtypes=XFAILLIST_GRAD[key]),
+                )
+
+            if key in SKIPLIST_GRAD:
+                addDecorator(op, DecorateInfo(unittest.skip, dtypes=SKIPLIST_GRAD[key]))
+
+            if key in ON_MPS_XFAILLIST:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure, dtypes=ON_MPS_XFAILLIST[key]
+                    ),
+                )
+
+        return ops
+
+    def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
+        # Error input samples do not take a dtype argument.
+        XFAILLIST = {
+            # Exceptions are not raised
+            "__rmod__",
+            "__rsub__",
+            "__rpow__",
+            "clamp_max",
+            "clamp_min",
+            "masked_scatter",
+            # unsupported float64 dtype
+            "multinomial",
+            "nn.functional.conv1d",
+            "nn.functional.conv2d",
+            "nn.functional.conv3d",
+            "gather",
+            "scatter",
+            "scatter_add",
+            # MPS does not support tensor dimensions > 16
+            "amax",
+            "amin",
+            "aminmax",
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in XFAILLIST:
+                addDecorator(op, DecorateInfo(unittest.expectedFailure))
+
+        return ops
+else:
+
+    def mps_ops_modifier(
+        ops: Sequence[OpInfo],
+        device_type: str = "mps",
+        xfail_exclusion: Optional[list[str]] = None,
+        sparse: bool = False,
+    ) -> Sequence[OpInfo]:
+        return ops
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a276144e53bd3145590775ecb13573bda3eb12f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py
@@ -0,0 +1,3998 @@
+# mypy: ignore-errors
+
+from abc import abstractmethod
+import tempfile
+import unittest
+
+from copy import deepcopy
+from functools import reduce, partial
+from itertools import product
+from operator import mul
+
+
+import torch
+import torch.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import _reduction as _Reduction
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
+    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
+from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
+from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
+from torch.autograd import Variable
+from torch.types import _TensorOrTensors
+import torch.backends.cudnn
+
+from typing import Union, Any
+from collections.abc import Callable
+from collections.abc import Sequence
+
+TemporaryFile = tempfile.TemporaryFile
+PRECISION = 1e-5
+
+
+def get_reduction(m):
+    result = getattr(m, 'reduction', None)
+    if result is None:
+        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
+    assert result is not None
+    return result
+
+
+def get_weight(m):
+    result = getattr(m, 'weight', None)
+    if result is not None:
+        return result
+    return getattr(m, 'weights', None)
+
+# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
+#
+# The way to check API parity is to add parity tests for the NN module / functional of interest.
+# Here are the detailed steps:
+#
+# For NN module:
+# 1. Make sure you already have a test dict with the module configuration you want to test.
+# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
+#    the Python module constructor arguments. For example, if in the test dict we pass
+#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
+#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
+# 3. If in the process of performing the above step you referenced any variables
+#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
+#    to the test dict to make sure that those variables are populated with the right Python values.
+#    For example, if the Python constructor call is
+#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
+#    the corresponding C++ constructor argument is
+#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
+#    and the `cpp_var_map` entry must be
+#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
+#    used in the C++ constructor argument with the Python tensor value `random_samples`.
+#
+# For NN functional:
+# 1. Make sure you already have a test dict with the functional configuration you want to test.
+# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
+#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python
+#    functional optional arguments. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
+#    then the `cpp_options_args` entry should be
+#    "F::InterpolateFuncOptions().size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)".
+# 3. Otherwise, if the test dict's `constructor` entry looks like
+#    `wrap_functional(lambda i: F.some_functional_name(...))`,
+#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python
+#    functional function call. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
+#    then the `cpp_function_call` entry should be
+#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
+# 4. If in the process of performing the above two steps you referenced any variables
+#    in the `cpp_options_args` or `cpp_function_call` entry, you must
+#    add `cpp_var_map` entry to the test dict to make sure that those variables
+#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
+#    then the `cpp_function_call` entry should be
+#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
+#    Notice that there are two variables `i` and `t` that need to have their values provided,
+#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
+#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
+#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
+#
+# There are also a few optional flags in the test dict to control the C++ parity test behavior:
+#
+# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
+# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
+
+
+module_tests = [
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8),
+        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
+        input_size=(4, 10),
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
+        with_tf32=True,
+        tf32_precision=0.005,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8, False),
+        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
+        input_size=(4, 10),
+        desc='no_bias',
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
+        with_tf32=True,
+        tf32_precision=0.005,
+        # ROCM: skipping tf32 test on gfx94 archs due to tolerance issue.
+        test_cuda=not (TEST_WITH_ROCM and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName),
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='RReLU',
+        input_size=(1, 2, 2),
+        test_cuda=False,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='RReLU',
+        constructor_args=(0.1, 0.9),
+        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
+        input_size=(4, 4, 5),
+        desc='with_up_down',
+        test_cuda=False,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='Flatten',
+        input_size=(2, 3, 4, 5),
+        reference_fn=lambda i, *_: torch.flatten(i, 1),
+        default_dtype=torch.double,
+    ),
+    # TODO: reference function
+    dict(
+        module_name='CrossMapLRN2d',
+        constructor_args=(5, 5e-3, 1e-3, 2),
+        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
+        input_size=(2, 3, 6, 6),
+        check_gradgrad=False,
+        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
+        check_batched_grad=False,
+        default_dtype=torch.double,
+    ),
+]
+
+
+# Generates rand tensor with non-equal values. This ensures that duplicate
+# values won't be causing test failure for modules like MaxPooling.
+# size should be small, otherwise randperm fails / long overflows.
+def _rand_tensor_non_equal(*size):
+    total = reduce(mul, size, 1)
+    return torch.randperm(total).view(*size).double()
+
+
+def wrap_functional(fn, **kwargs):
+    class FunctionalModule(nn.Module):
+        def forward(self, *args):
+            return fn(*args, **kwargs)
+    return FunctionalModule
+
+
+def poissonnllloss_no_reduce_test():
+    t = torch.randn(10, 10)
+    return dict(
+        fullname='PoissonNLLLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::poisson_nll_loss('
+                          'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: i.exp() - t.mul(i),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def bceloss_no_reduce_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    return dict(
+        fullname='BCELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
+        pickle=False,
+        precision=7e-4,
+        default_dtype=torch.double)
+
+
+def bceloss_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    return dict(
+        fullname='BCELoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def bceloss_weights_no_reduce_test():
+    t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double))
+    weights = torch.rand(10, dtype=torch.double)
+    return dict(
+        fullname='BCELoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i),
+                                             weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), '
+                          'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        pickle=False,
+        precision=3e-4,
+        default_dtype=torch.double,
+    )
+
+
+def bceloss_weights_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    weights = torch.rand((), dtype=torch.double)
+    return dict(
+        fullname='BCELoss_weights_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i),
+                                             weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy(
+            i, t.to(i.options()),
+            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_legacy_enum_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_legacy_enum',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_no_reduce_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_with_target_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_with_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_no_reduce_scalar_test():
+    t = torch.rand((), dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(()).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_with_log_target_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_with_log_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_no_reduce_log_target_test():
+    t = torch.rand(10, 10, dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_no_reduce_log_target',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_no_reduce_scalar_log_target_test():
+    t = torch.rand((), dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_no_reduce_scalar_log_target',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(()).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def l1loss_no_reduce_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='L1Loss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def l1loss_no_reduce_complex_test():
+    t = torch.randn(2, 3, 4, dtype=torch.cdouble)
+    return dict(
+        fullname='L1Loss_no_reduce_complex',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False)
+
+
+def l1loss_no_reduce_scalar_test():
+    t = torch.randn((), dtype=torch.double)
+    return dict(
+        fullname='L1Loss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def mseloss_no_reduce_test():
+    input_size = (2, 3, 4, 5)
+    target = torch.randn(*input_size, dtype=torch.double)
+    return dict(
+        fullname='MSELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
+        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
+        input_size=input_size,
+        cpp_var_map={'i': '_get_input()', 'target': target},
+        reference_fn=lambda i, *_: (i - target).pow(2),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def mseloss_no_reduce_scalar_test():
+    input_size = ()
+    target = torch.randn(input_size, dtype=torch.double)
+    return dict(
+        fullname='MSELoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
+        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
+        input_size=input_size,
+        cpp_var_map={'i': '_get_input()', 'target': target},
+        reference_fn=lambda i, *_: (i - target).pow(2),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_ignore_index_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_ignore_index_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none',
+                'ignore_index': 2}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
+        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_ignore_index_neg_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none',
+                'ignore_index': -1}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
+        input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss2d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss2d_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLoss2d_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLossNd_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLossNd_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLossNd_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_no_reduce_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_no_reduce_scalar_test():
+    t = torch.randn((), dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_beta_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_beta',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_zero_beta_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_zero_beta',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def huberloss_delta_test():
+    t = torch.randn(2, 3, 4)
+    return dict(
+        fullname='HuberLoss_delta',
+        constructor=wrap_functional(
+            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
+        cpp_function_call='''F::huber_loss(
+            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_0d_no_reduce_test():
+    t = torch.zeros(()).long()
+    return dict(
+        fullname='MultiLabelMarginLoss_0d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False)
+
+
+def multilabelmarginloss_1d_no_reduce_test():
+    t = Variable(torch.rand(10).mul(10).floor().long())
+    return dict(
+        fullname='MultiLabelMarginLoss_1d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_index_neg_test():
+    t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
+    return dict(
+        fullname='MultiLabelMarginLoss_index_neg',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_no_reduce_test():
+    t = Variable(torch.rand(5, 10).mul(10).floor().long())
+    return dict(
+        fullname='MultiLabelMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def hingeembeddingloss_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::hinge_embedding_loss(
+            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
+        check_sum_reduction=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def hingeembeddingloss_margin_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
+        cpp_function_call='''F::hinge_embedding_loss(
+            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
+        check_sum_reduction=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def softmarginloss_no_reduce_test():
+    t = torch.randn(5, 5, dtype=torch.double)
+    return dict(
+        fullname='SoftMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::soft_margin_loss(
+            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 5),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelsoftmarginloss_no_reduce_test():
+    t = torch.rand(5, 10).mul(2).floor()
+    return dict(
+        fullname='MultiLabelSoftMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::multilabel_soft_margin_loss(
+            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelsoftmarginloss_weights_no_reduce_test():
+    t = torch.rand(5, 10).mul(2).floor()
+    weights = torch.rand(10)
+    return dict(
+        fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
+                                                    weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='''F::multilabel_soft_margin_loss(
+            i, t.to(i.options()),
+            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, *_:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_1d_no_reduce_test():
+    t = torch.rand(1).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_1d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_1d_input_0d_target_no_reduce_test():
+    t = torch.rand(()).mul(8).floor().long()
+    return dict(
+        fullname='multimarginloss_1d_input_0d_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_p_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_p_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_margin_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
+                                                  margin=0.5, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_weights_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    weights = torch.rand(10, dtype=torch.double)
+    return dict(
+        fullname='MultiMarginLoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
+                                          reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
+                                                  weight=weights, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def single_batch_reference_fn(input, parameters, module):
+    """Reference function for modules supporting no batch dimensions.
+
+    The module is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    single_batch_input = unsqueeze_inp(input)
+    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
+    with freeze_rng_state():
+        return module(*single_batch_input).squeeze(0)
+
+
+def get_new_module_tests():
+    common_utils.set_rng_seed()
+    new_module_tests = [
+        poissonnllloss_no_reduce_test(),
+        bceloss_no_reduce_test(),
+        bceloss_weights_no_reduce_test(),
+        bce_with_logistic_legacy_enum_test(),
+        bce_with_logistic_no_reduce_test(),
+        bceloss_no_reduce_scalar_test(),
+        bceloss_weights_no_reduce_scalar_test(),
+        bce_with_logistic_no_reduce_scalar_test(),
+        kldivloss_with_target_no_reduce_test(),
+        kldivloss_no_reduce_test(),
+        kldivloss_no_reduce_scalar_test(),
+        kldivloss_with_log_target_no_reduce_test(),
+        kldivloss_no_reduce_log_target_test(),
+        kldivloss_no_reduce_scalar_log_target_test(),
+        l1loss_no_reduce_test(),
+        l1loss_no_reduce_complex_test(),
+        l1loss_no_reduce_scalar_test(),
+        mseloss_no_reduce_test(),
+        mseloss_no_reduce_scalar_test(),
+        nllloss_no_reduce_test(),
+        nllloss_no_reduce_ignore_index_test(),
+        nllloss_no_reduce_weights_test(),
+        nllloss_no_reduce_weights_ignore_index_test(),
+        nllloss_no_reduce_weights_ignore_index_neg_test(),
+        nllloss2d_no_reduce_test(),
+        nllloss2d_no_reduce_weights_test(),
+        nllloss2d_no_reduce_ignore_index_test(),
+        nlllossNd_no_reduce_test(),
+        nlllossNd_no_reduce_weights_test(),
+        nlllossNd_no_reduce_ignore_index_test(),
+        smoothl1loss_no_reduce_test(),
+        smoothl1loss_no_reduce_scalar_test(),
+        smoothl1loss_beta_test(),
+        smoothl1loss_zero_beta_test(),
+        huberloss_delta_test(),
+        multilabelmarginloss_0d_no_reduce_test(),
+        multilabelmarginloss_1d_no_reduce_test(),
+        multilabelmarginloss_index_neg_test(),
+        multilabelmarginloss_no_reduce_test(),
+        hingeembeddingloss_no_reduce_test(),
+        hingeembeddingloss_margin_no_reduce_test(),
+        softmarginloss_no_reduce_test(),
+        multilabelsoftmarginloss_no_reduce_test(),
+        multilabelsoftmarginloss_weights_no_reduce_test(),
+        multimarginloss_no_reduce_test(),
+        multimarginloss_1d_no_reduce_test(),
+        multimarginloss_1d_input_0d_target_no_reduce_test(),
+        multimarginloss_p_no_reduce_test(),
+        multimarginloss_margin_no_reduce_test(),
+        multimarginloss_weights_no_reduce_test(),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='stride',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3, 1, 1),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='pad1',
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 5, 1, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='pad2',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 4, 3, 1, 1),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
+            input_size=(1, 4, 1),
+            cudnn=True,
+            desc='pad1size1',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 4, 5, 1, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
+            input_size=(1, 4, 1),
+            cudnn=True,
+            desc='pad2size1',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
+            input_size=(0, 4, 10),
+            cudnn=True,
+            desc='zero_batch',
+            with_tf32=True,
+            tf32_precision=0.005,
+        ),
+        dict(
+            fullname='Conv1d_dilated',
+            constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
+            input_size=(2, 4, 10),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_groups',
+            constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
+            input_size=(2, 4, 6),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_valid',
+            constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same',
+            constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same2',
+            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same_dilated',
+            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose1d',
+            constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
+            cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
+            cudnn=True,
+            input_size=(1, 3, 7),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose1d',
+            constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
+                                    .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
+            input_size=(1, 3, 6),
+            cudnn=True,
+            desc='no_bias',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose1d',
+            constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
+                                    .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
+            input_size=(1, 3, 6),
+            cudnn=True,
+            desc='dilated',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose1d_groups',
+            constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
+                                    .stride(3).padding(1).output_padding(1).groups(2)''',
+            cudnn=True,
+            input_size=(2, 4, 7),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
+            input_size=(2, 3, 7, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 3), (2, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
+            input_size=(2, 3, 6, 6),
+            cudnn=True,
+            desc='strided',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
+            input_size=(2, 3, 6, 6),
+            cudnn=True,
+            desc='padding',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
+            input_size=(2, 3, 8, 8),
+            cudnn=True,
+            desc='dilated',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(2, 3, 6, 5),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
+            input_size=(0, 3, 7, 5),
+            cudnn=True,
+            desc='zero_batch',
+            check_with_long_tensor=True,
+            with_tf32=True,
+        ),
+        dict(
+            fullname='Conv2d_groups',
+            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
+            input_size=(2, 4, 6, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_groups_thnn',
+            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
+            input_size=(2, 4, 6, 5),
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_valid',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_same',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_same_dilated',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({3, 2}).padding(1).output_padding({1, 1})''',
+            cudnn=True,
+            input_size=(1, 3, 7, 6),
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({2, 3})
+                                    .padding(1)
+                                    .output_padding({1, 1})
+                                    .groups(1)
+                                    .bias(false)
+                                    .dilation({2, 2})''',
+            input_size=(1, 3, 6, 7),
+            cudnn=True,
+            desc='dilated',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
+            input_size=(1, 3, 6, 7),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose2d_groups',
+            constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
+            cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
+            input_size=(1, 2, 4, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_with_multiplier',
+            constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_strided',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_padded',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_dilated',
+            constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
+            input_size=(2, 4, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (2, 3, 2)),
+            cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
+            input_size=(1, 2, 4, 5, 4),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            cudnn=True,
+            desc='1x1x1_no_bias',
+            check_with_long_tensor=False,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, 2, 2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
+            input_size=(2, 3, 5, 5, 5),
+            cudnn=True,
+            desc='stride',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, 2, 2, 1),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
+            input_size=(2, 3, 5, 5, 5),
+            cudnn=True,
+            desc='stride_padding',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, (2, 3, 4)),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
+            input_size=(0, 3, 3, 4, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            desc='zero_batch',
+            with_tf32=True,
+        ),
+        dict(
+            fullname='Conv3d_groups',
+            constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
+            input_size=(1, 2, 4, 5, 4),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_dilated',
+            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
+            input_size=(2, 3, 5, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_dilated_strided',
+            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
+            input_size=(2, 3, 5, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_valid',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_same',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_same_dilated',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose3d',
+            constructor_args=(2, 3, (2, 3, 2)),
+            cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
+            cudnn=True,
+            input_size=(1, 2, 4, 5, 4),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose3d',
+            constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
+            cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
+                                    .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
+            cudnn=True,
+            input_size=(1, 2, 4, 5, 4),
+            desc='dilated',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_size=(2, 3, 2, 2, 2),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_size=(3, 2, 2, 2),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
+            skip_half=True,
+            desc='complex'
+        ),
+        dict(
+            module_name='Embedding',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
+        ),
+        dict(
+            module_name='Embedding',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+            check_gradgrad=False,
+            desc='discontiguous',
+            default_dtype=torch.double,
+            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='mean',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+            check_gradgrad=False,
+            desc='discontiguous',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3, None, 2., False, 'sum'),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='sum',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3, None, 2., False, 'max'),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='max',
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_mean_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_sum_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_max_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_sparse',
+            constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''',
+            input_fn=lambda: torch.randperm(2).repeat(1, 2),
+            check_gradgrad=False,
+            has_sparse_gradients=True,
+        ),
+        dict(
+            constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
+            input_fn=lambda: torch.randperm(2).repeat(1, 2),
+            fullname='Embedding_sparse',
+            check_gradgrad=False,
+            has_sparse_gradients=True,
+        ),
+        dict(
+            module_name='PixelShuffle',
+            constructor_args=(3,),
+            cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
+            input_size=(1, 9, 4, 4),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PixelUnshuffle',
+            constructor_args=(3,),
+            cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
+            input_size=(1, 1, 12, 12),
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_nearest_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(0, 2, 4),
+            fullname='interpolate_nearest_1d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(1, 2, 3),
+            fullname='interpolate_nearest_tuple_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt).scale_factor(std::vector({4.})).mode(torch::kNearest)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_nearest_scale_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 3),
+            fullname='interpolate_linear_tuple_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4.}))
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_scale_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4),
+            fullname='interpolate_linear_1d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_1d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4.}))
+                                .mode(torch::kLinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_scale_1d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({2, 2}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 128, 1, 1),
+            fullname='interpolate_nearest_2d_launch_configs',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_nearest_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 16}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 3, 4),
+            fullname='interpolate_nearest_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4.}))
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_nearest_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_nearest_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_bilinear_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3),
+            fullname='interpolate_bilinear_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 2.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_shared_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_skewed_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_tuple_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_bicubic_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3),
+            fullname='interpolate_bicubic_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 2.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_shared_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_skewed_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_tuple_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bicubic', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_nearest_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(0, 2, 4, 4, 4),
+            fullname='interpolate_nearest_3d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 16, 16}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 3, 4, 4),
+            fullname='interpolate_nearest_tuple_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4., 4.}))
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_nearest_scale_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_trilinear_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4, 4),
+            fullname='interpolate_trilinear_3d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
+                                        scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3, 3),
+            fullname='interpolate_trilinear_tuple_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({3., 3., 3.}))
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            fullname='interpolate_trilinear_scale_3d',
+            # See https://github.com/pytorch/pytorch/issues/5006
+            precision=3e-4,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
+                                        mode='trilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 2, 3, 3),
+            fullname='interpolate_trilinear_tuple_3d_align_corners',
+            pickle=False,
+            default_dtype=torch.double
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({3., 3., 3.}))
+                                .mode(torch::kTrilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 3, 4, 4),
+            fullname='interpolate_trilinear_scale_3d_align_corners',
+            # See https://github.com/pytorch/pytorch/issues/5006
+            precision=3e-4,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=-1),
+            cpp_options_args='F::SoftmaxFuncOptions(-1)',
+            input_size=(2, 128),  # trigger the last-dim algo in CUDA
+            fullname='softmax_lastdim',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
+            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
+            input_size=(2, 128),
+            fullname='softmax_lastdim_dtype',
+            pickle=False,
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1),
+            cpp_options_args='F::SoftmaxFuncOptions(1)',
+            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+            fullname='softmax_spatial_special',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1),
+            cpp_options_args='F::SoftmaxFuncOptions(1)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='softmax_spatial',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
+            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='softmax_spatial_dtype',
+            pickle=False,
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=0),
+            cpp_options_args='F::SoftmaxFuncOptions(0)',
+            input_size=(2, 3, 4, 5),
+            fullname='softmax_functional_dim0',
+            test_cuda=False,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=3),
+            cpp_options_args='F::SoftmaxFuncOptions(3)',
+            input_size=(2, 3, 4, 5),
+            fullname='softmax_functional_dim3',
+            test_cuda=False,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=-1),
+            cpp_options_args='F::SoftmaxFuncOptions(-1)',
+            input_size=(),
+            fullname='softmax_functional_scalar',
+            test_cuda=False,
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=-1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
+            input_size=(2, 128),  # trigger the last-dim algo in CUDA
+            fullname='log_softmax_lastdim',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
+            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+            fullname='log_softmax_spatial_special',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='log_softmax_spatial',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=0),
+            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
+            input_size=(2, 3, 4, 5),
+            fullname='log_softmax_dim0',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=3),
+            cpp_options_args='F::LogSoftmaxFuncOptions(3)',
+            input_size=(2, 3, 4, 5),
+            fullname='log_softmax_dim3',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=0),
+            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
+            input_size=(),
+            fullname='log_softmax_scalar',
+            pickle=False,
+        ),
+        dict(
+            fullname='Unfold',
+            constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(2, 4, 3, 3),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold',
+            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(2, 16, 4),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_no_batch_dim_input',
+            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(16, 4),
+            check_gradgrad=False,
+            ref=single_batch_reference_fn,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Unfold_int_input',
+            constructor=lambda: nn.Unfold(2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
+            input_size=(2, 4, 3, 3),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_int_input',
+            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
+            input_size=(2, 16, 4),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_no_batch_dim_int_input',
+            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
+            input_size=(16, 4),
+            ref=single_batch_reference_fn,
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='RReLU',
+            constructor_args=(0.1, 0.9),
+            cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
+            input_size=(),
+            desc='with_up_down_scalar',
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
+            desc='broadcast_lhs',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
+            desc='broadcast_rhs',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            constructor_args=(1.5, 1e-05, True),
+            cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+            desc='with_non_default_args',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(8), torch.randn(8)),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerEncoderLayer',
+            constructor_args=(4, 2, 16, 0.0),
+            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
+                                    .dim_feedforward(16)
+                                    .dropout(0.0)''',
+            input_size=(2, 3, 4),
+            desc='relu_activation',
+            with_tf32=True,
+            tf32_precision=0.1,
+            # TODO(#50743): figure out the error
+            # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
+            # at non-singleton dimension 2
+            check_batched_grad=False,
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerEncoderLayer',
+            constructor_args=(4, 2, 8, 0.0, F.gelu),
+            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kGELU)''',
+            input_size=(2, 3, 4),
+            check_gradgrad=False,
+            desc='gelu_activation',
+            with_tf32=True,
+            tf32_precision=0.08 if SM90OrLater else 0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerDecoderLayer',
+            constructor_args=(4, 2, 8, 0.0),
+            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
+            check_gradgrad=False,
+            desc='relu_activation',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerDecoderLayer',
+            constructor_args=(4, 2, 8, 0.0, F.gelu),
+            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kGELU)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
+            check_gradgrad=False,
+            desc='gelu_activation',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Transformer',
+            constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
+            cpp_constructor_args='''torch::nn::TransformerOptions()
+                                    .d_model(4)
+                                    .nhead(2)
+                                    .num_encoder_layers(2)
+                                    .num_decoder_layers(2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kReLU)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
+            check_gradgrad=False,
+            desc='multilayer_coder',
+            with_tf32=True,
+            tf32_precision=0.05 if SM90OrLater else 0.03,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Linear',
+            constructor_args=(3, 5),
+            cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
+            input_fn=lambda: torch.rand(3),
+            reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
+            desc="no_batch_dim",
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Flatten',
+            cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
+            constructor_args=(-3, -1),
+            input_size=(3, 4, 5),
+            reference_fn=single_batch_reference_fn,
+            desc="no_batch_dim",
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Unflatten',
+            cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
+            constructor_args=(-2, torch.Size([2, 2])),
+            input_size=(3, 4, 5),
+            reference_fn=single_batch_reference_fn,
+            desc="no_batch_dim",
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='LayerNorm',
+            constructor_args=([56, 56, 56], 1e-5, False),
+            cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
+            input_size=(4, 56, 56, 56),
+            cudnn=True,
+            check_eval=True,
+            gradcheck_fast_mode=True,
+            check_half=True,
+            desc='3d_no_affine_large_feature',
+        ),
+    ]
+
+    # add conv padding mode tests:
+    for padding_mode, cpp_padding_mode in zip(
+            ['reflect', 'circular', 'replicate', 'zeros'],
+            ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros'], strict=True):
+        # conv signature:
+        #     in_channels, out_channels, kernel_size, stride=1,
+        #     padding=0, dilation=1, groups=1,
+        #     bias=True, padding_mode='zeros'
+        for d in (1, 2, 3):
+            if d == 3 and padding_mode == 'reflect':
+                # FIXME: remove after implementing reflection pad 3d
+                #        https://github.com/pytorch/pytorch/issues/27655
+                continue
+            padding = tuple(range(1, d + 1))
+            cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
+            input_size = (2, 2) + (4,) * d
+            output_size = (2, 3) + tuple(p + 1 for p in padding)  # simplified from `(4 + 2 * p - 3) // 2 + 1`
+            new_module_tests.append(
+                dict(
+                    module_name=f'Conv{d}d',
+                    constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
+                    cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3)
+                                            .stride(2)
+                                            .padding({cpp_padding})
+                                            .dilation(1)
+                                            .groups(1)
+                                            .bias(true)
+                                            .padding_mode({cpp_padding_mode})''',
+                    input_size=input_size,
+                    output_size=output_size,
+                    cudnn=True,
+                    desc=f'{padding_mode}_stride2_pad2',
+                    with_tf32=True,
+                    tf32_precision=0.05,
+                    default_dtype=torch.double,
+                ),
+            )
+
+    # Check that non linear activations work with no batch dimensions
+    non_linear_activations_no_batch = [
+        'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
+        'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
+        'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
+        'Tanhshrink', 'Threshold'
+    ]
+    non_linear_activations_extra_info: dict[str, dict] = {
+        'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
+        'Threshold': {'constructor_args': (2., 1.)},
+        'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
+        'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
+        # For RRelu, test that compare CPU and GPU results fail because RNG
+        # is different between CPU and GPU
+        'RReLU': {'test_cuda': False, 'default_dtype': torch.double},
+        'ELU': {'default_dtype': torch.double},
+        'GELU': {'default_dtype': torch.double},
+        'GLU': {'default_dtype': torch.double},
+        'Hardshrink': {'default_dtype': torch.double},
+        'Hardtanh': {'default_dtype': torch.double},
+        'LeakyReLU': {'default_dtype': torch.double},
+        'LogSigmoid': {'default_dtype': torch.double},
+        'Mish': {'default_dtype': torch.double},
+        'PReLU': {'default_dtype': torch.double},
+        'ReLU6': {'default_dtype': torch.double},
+        'ReLU': {'default_dtype': torch.double},
+        'SELU': {'default_dtype': torch.double},
+        'SiLU': {'default_dtype': torch.double},
+        'Sigmoid': {'default_dtype': torch.double},
+        'Softplus': {'default_dtype': torch.double},
+        'Softshrink': {'default_dtype': torch.double},
+        'Softsign': {'default_dtype': torch.double},
+        'Tanh': {'default_dtype': torch.double},
+        'Tanhshrink': {'default_dtype': torch.double},
+    }
+    for non_linear_activation in non_linear_activations_no_batch:
+        activation_test_info = dict(
+            module_name=non_linear_activation,
+            input_size=(4,),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            test_cpp_api_parity=False,
+        )
+        extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
+        activation_test_info.update(extra_info)
+        new_module_tests.append(activation_test_info)
+
+
+    return new_module_tests
+
+
+def kldivloss_reference(input, target, reduction='mean', log_target=False):
+    if log_target:
+        result = torch.exp(target) * (target - input)
+    else:
+        result = target * (target.log() - input)
+    if reduction == 'mean':
+        return result.mean()
+    elif reduction == 'sum':
+        return result.sum()
+    elif reduction == 'batchmean' and result.dim() != 0:
+        return result.sum() / result.size(0)
+    return result
+
+
+def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
+                        reduction='mean'):
+    assert input.dim() >= 3
+    N = input.size(0)
+    C = input.size(1)
+    out_size = (N,) + input.size()[2:]
+    output = torch.zeros(out_size).type_as(input)
+
+    if weight is None:
+        weight = torch.ones(C).type_as(input)
+    total_weight = 0
+    for tup in product(*[range(size) for size in out_size]):
+        t_nx = target[tup]
+        norm = 0. if ignore_index == t_nx else weight[t_nx].item()
+        input_index = list(tup)
+        input_index.insert(1, t_nx)
+        output[tup] = -input[tuple(input_index)] * norm
+        total_weight += norm
+
+    if reduction == 'mean':
+        return output.sum() / total_weight
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
+                                             label_smoothing=0.0):
+    assert input.dim() >= 2
+
+    input = torch.log_softmax(input, 1)
+    C = input.size(1)
+    if weight is None:
+        weight = torch.ones(C).type_as(input)
+    weight = weight.view(1, C, *(1 for _ in input.shape[2:]))
+
+    if label_smoothing > 0.0:
+        assert label_smoothing <= 1.0
+        target = (target * (1 - label_smoothing) + label_smoothing / C)
+
+    output = -(input * target * weight).sum(dim=1)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
+                                                reduction='mean', label_smoothing=0.0):
+    log_softmax_input = torch.log_softmax(input, 1)
+    nllloss = F.nll_loss(
+        log_softmax_input,
+        target,
+        weight,
+        ignore_index=ignore_index,
+        reduction=reduction)
+
+    if label_smoothing == 0.0:
+        return nllloss
+
+    assert 0.0 < label_smoothing <= 1.0
+
+    input = torch.log_softmax(input, 1)
+    C = input.size(1)
+    if weight is not None:
+        input = input * weight.view(1, C, *(1 for _ in input.shape[2:]))
+
+    smooth_loss = -torch.sum(input, 1)
+
+    ignore_mask = target == ignore_index
+    smooth_loss.masked_fill_(ignore_mask, 0.0)
+
+    if reduction == 'mean':
+        if weight is not None:
+            # TODO: This code can path can be removed if #61309 is resolved
+            # loss is normalized by the weights to be consistent with nll_loss_nd
+            ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum()
+        else:
+            ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not()))
+    elif reduction == 'sum':
+        ret = torch.sum(smooth_loss)
+    else:
+        ret = smooth_loss
+
+    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)
+
+
+def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
+                                 label_smoothing=0.0):
+    if input.shape == target.shape:
+        return cross_entropy_loss_prob_target_reference(
+            input,
+            target,
+            weight=weight,
+            reduction=reduction,
+            label_smoothing=label_smoothing)
+    else:
+        return cross_entropy_loss_indices_target_reference(
+            input, target, weight=weight, reduction=reduction,
+            ignore_index=ignore_index, label_smoothing=label_smoothing
+        )
+
+
+def nllloss_reference(input, target, weight=None, ignore_index=-100,
+                      reduction='mean'):
+
+    def nll_loss_helper(input, target, weight, ignore_index):
+        if target == ignore_index:
+            return (0, 0)
+        norm = 1 if weight is None else weight[target]
+        result = -input[target] * norm
+        return (result, norm)
+
+    losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
+                          for i, t in zip(input, target, strict=True)]
+    losses, weights = zip(*losses_and_weights, strict=True)
+    losses_tensor = input.new_tensor(losses)
+    if reduction == 'mean':
+        return sum(losses_tensor) / sum(weights)
+    elif reduction == 'sum':
+        return sum(losses_tensor)
+    else:
+        return losses_tensor
+
+
+def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
+    abs_diff = (input - target).abs()
+    ge_beta_mask = (abs_diff >= beta).type_as(abs_diff)
+    lt_beta_mask = (abs_diff < beta).type_as(abs_diff)
+    # when beta <= 0 we should just use l1_loss
+    if beta == 0:
+        output = abs_diff
+    else:
+        output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def huberloss_reference(input, target, reduction='mean', delta=1.0):
+    abs_diff = (input - target).abs()
+    ge_delta_mask = (abs_diff >= delta)
+    lt_delta_mask = (abs_diff < delta)
+    output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def _multilabelmarginloss_reference(input, target):
+    targets = []
+    for target_index in target:
+        if target_index < 0:
+            break
+        targets.append(target_index)
+
+    sum = 0
+    for target_index in targets:
+        for i in range(len(input)):
+            if i not in targets:
+                sum += max(0, 1 - input[target_index] + input[i])
+
+    return sum
+
+
+def multilabelmarginloss_reference(input, target, reduction='mean'):
+    # make everything 2-dimensional
+    input_dim = input.dim()
+    if input.dim() < 2:
+        assert target.dim() < 2
+        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
+        target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0)
+
+    n = input.size(0)
+    dim = input.size(1)
+    output = input.new(n).zero_()
+    for i in range(n):
+        output[i] = _multilabelmarginloss_reference(input[i], target[i])
+
+    if reduction == 'mean':
+        return output.mean() / dim
+    elif reduction == 'sum':
+        return output.sum() / dim
+    elif input_dim < 2:
+        # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us
+        # back to correct dimensionality
+        return output.squeeze() / dim
+    else:
+        return output / dim
+
+
+def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
+    margin_clamp = (margin - input).clamp(min=0).type_as(input)
+    output = torch.where(target == 1, input, margin_clamp)
+
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def softmarginloss_reference(input, target, reduction='mean'):
+    output = (1 + (-input * target).exp()).log()
+
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def _multimarginloss_reference(input, target_idx, p, margin, weight):
+    if weight is None:
+        weight = input.new(len(input)).fill_(1)
+
+    output = 0
+    for i in range(len(input)):
+        if i != target_idx:
+            output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p)
+    return output
+
+
+def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
+    if input.dim() < 2:
+        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
+
+    target_dim = target.dim()
+    if target.dim() == 0:
+        target = target.unsqueeze(0)
+
+    n = input.size(0)
+    dim = input.size(1)
+    output = input.new(n)
+    for x in range(n):
+        output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
+
+    if reduction == 'mean':
+        return output.mean() / dim
+    elif reduction == 'sum':
+        return output.sum() / dim
+    elif target_dim == 0:
+        return output.squeeze(0) / dim
+    return output / dim
+
+
+def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
+    def _cos(a, b):
+        cos = a.new(a.size(0))
+        for i in range(a.size(0)):
+            cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
+        return cos
+
+    output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
+
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
+                                reduction='mean'):
+    d_p = torch.pairwise_distance(anchor, positive, p, eps)
+    d_n = torch.pairwise_distance(anchor, negative, p, eps)
+    if swap:
+        d_s = torch.pairwise_distance(positive, negative, p, eps)
+        d_n = torch.min(d_n, d_s)
+
+    output = torch.clamp(margin + d_p - d_n, min=0.0)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
+    output = (-target * (input1 - input2) + margin).clamp(min=0)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space
+def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
+    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
+    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
+    dt = log_probs.dtype
+    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
+    targets = targets.long()
+    cum_target_lengths = target_lengths.cumsum(0)
+    losses = []
+    for i in range(log_probs.size(1)):
+        input_length = input_lengths[i].item()
+        target_length = target_lengths[i].item()
+        cum_target_length = cum_target_lengths[i].item()
+        targets_prime = targets.new_full((2 * target_length + 1,), blank)
+        if targets.dim() == 2:
+            targets_prime[1::2] = targets[i, :target_length]
+        else:
+            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
+        probs = log_probs[:input_length, i].exp()
+        alpha = log_probs.new_zeros((target_length * 2 + 1,))
+        alpha[0] = probs[0, blank]
+        alpha[1] = probs[0, targets_prime[1]]
+        mask_third = (targets_prime[:-2] != targets_prime[2:])
+        for t in range(1, input_length):
+            alpha_next = alpha.clone()
+            alpha_next[1:] += alpha[:-1]
+            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
+            alpha = probs[t, targets_prime] * alpha_next
+        losses.append(-alpha[-2:].sum().log()[None])
+    output = torch.cat(losses, 0)
+    if reduction == 'mean':
+        output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
+    elif reduction == 'sum':
+        output = output.sum()
+    output = output.to(dt)
+    return output
+
+
+loss_reference_fns: dict['str', Callable] = {
+    'KLDivLoss': kldivloss_reference,
+    'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
+    'NLLLoss': nllloss_reference,
+    'NLLLossNd': nlllossNd_reference,
+    'SmoothL1Loss': smoothl1loss_reference,
+    'HuberLoss': huberloss_reference,
+    'MultiLabelMarginLoss': multilabelmarginloss_reference,
+    'HingeEmbeddingLoss': hingeembeddingloss_reference,
+    'SoftMarginLoss': softmarginloss_reference,
+    'MultiMarginLoss': multimarginloss_reference,
+    'CosineEmbeddingLoss': cosineembeddingloss_reference,
+    'TripletMarginLoss': tripletmarginloss_reference,
+    'MarginRankingLoss': marginrankingloss_reference,
+    'CTCLoss': ctcloss_reference,
+    'CrossEntropyLoss': cross_entropy_loss_reference
+}
+
+
+criterion_tests = []
+
+
+def single_batch_reference_criterion_fn(*args):
+    """Reference function for criterion supporting no batch dimensions.
+
+    The criterion is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    criterion = args[-1]
+
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    def flatten(xs):
+        result = []
+        if isinstance(xs, (list, tuple)):
+            for x in xs:
+                result.extend(flatten(x))
+        else:
+            result.append(xs)
+        return result
+
+    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
+
+    output = criterion(*single_batch_input_args)
+    reduction = get_reduction(criterion)
+
+    if reduction == 'none':
+        return output.squeeze(0)
+    # reduction is 'sum' or 'mean' which results in a scalar
+    return output
+
+
+# Check that regression criterion work with no batch dimensions
+regression_criterion_no_batch = [
+    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
+]
+reductions = ['none', 'mean', 'sum']
+for name, reduction in product(regression_criterion_no_batch, reductions):
+    regression_test_info = dict(
+        fullname=f"{name}_no_batch_dim_{reduction}",
+        constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction),
+        input_size=(3, ),
+        target_size=(3, ),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+        default_dtype=torch.double,
+    )
+    criterion_tests.append(regression_test_info)
+
+
+for reduction in reductions:
+    regression_test_info = dict(
+        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
+        constructor=lambda: nn.KLDivLoss(reduction=reduction),
+        input_fn=lambda: torch.rand((3,)).log(),
+        target_fn=lambda: torch.rand((3,)),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+        default_dtype=torch.double,
+    )
+    criterion_tests.append(regression_test_info)
+
+
+# Check that classification criterion work with no batch dimensions
+# List of tuples of (name, input_fn, target_fn)
+classification_criterion_no_batch = [
+    (
+        'BCELoss',
+        lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)),
+        lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double)
+    ),
+    ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)),
+    ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])),
+    ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)),
+    (
+        'CosineEmbeddingLoss',
+        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
+        lambda: torch.tensor(1, dtype=torch.double)
+    ),
+    # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
+    ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
+    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
+    (
+        'TripletMarginLoss',
+        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
+        lambda: torch.randn(9, dtype=torch.double)
+    ),
+    ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
+]
+classification_criterion_no_batch_extra_info: dict[str, dict] = {
+    'MultiLabelMarginLoss': {'check_gradgrad': False},
+}
+# TODO : Fix these discrepancies
+classification_cpp_parity = {
+    'BCELoss': False,
+    'BCEWithLogitsLoss': False,
+    'HingeEmbeddingLoss': False,
+    'NLLLoss': False,
+    'SoftMarginLoss': False,
+}
+reductions = ['none', 'mean', 'sum']
+for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
+                                                      reductions):
+    classification_test_info = dict(
+        fullname=f"{name}_no_batch_dim_{reduction}",
+        constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction),
+        input_fn=lambda f=input_fn: f(),
+        target_fn=lambda f=target_fn: f(),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=True,
+        has_parity=classification_cpp_parity.get(name, True)
+    )
+    extra_info = classification_criterion_no_batch_extra_info.get(name, {})
+    classification_test_info.update(extra_info)
+    criterion_tests.append(classification_test_info)
+
+
+class NNTestCase(TestCase):
+
+    # _forward is defined in classes inheriting from NNTestCase
+    @abstractmethod
+    def _forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _zero_grad_parameters(self, module: nn.Module) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _backward(self, module: nn.Module,
+                  input: _TensorOrTensors, output: torch.Tensor,
+                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
+                  create_graph: bool = False):
+        raise NotImplementedError
+
+    def _jacobian(self, input, num_out):
+        if isinstance(input, tuple):
+            return tuple(self._jacobian(elem, num_out) for elem in input)
+        elif isinstance(input, list):
+            return [self._jacobian(elem, num_out) for elem in input]
+        else:
+            return torch.zeros(input.nelement(), num_out)
+
+    def _flatten_tensors(self, x):
+        if isinstance(x, torch.Tensor):
+            if x.is_sparse:
+                return x.to_dense().view(-1)
+            else:
+                return x.view(-1)
+        else:
+            return tuple(self._flatten_tensors(a) for a in x)
+
+    def _zero_grad_input(self, input):
+        if isinstance(input, torch.Tensor):
+            if input.requires_grad and input.grad is not None:
+                input.grad.zero_()
+                input.grad.detach_()
+        else:
+            for i in input:
+                self._zero_grad_input(i)
+
+    def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
+        output = self._forward(module, input)
+        output_size = output.nelement()
+
+        if jacobian_input:
+            jacobian_inp = self._jacobian(input, output_size)
+            flat_jacobian_input = list(_iter_tensors(jacobian_inp))
+
+        if jacobian_parameters:
+            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
+            jacobian_param = torch.zeros(num_param, output_size)
+
+        for i in range(output_size):
+            param, d_param = self._get_parameters(module)
+            # make non grad zeros
+            d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param, strict=True)]
+
+            d_out = torch.zeros_like(output)
+            flat_d_out = d_out.view(-1)
+            flat_d_out[i] = 1
+
+            if jacobian_parameters:
+                self._zero_grad_parameters(module)
+            # Tensors will accumulate gradient from multiple steps
+            if jacobian_input:
+                self._zero_grad_input(input)
+            d_input = self._backward(module, input, output, d_out)
+
+            if jacobian_input:
+                for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input), strict=True):
+                    jacobian_x[:, i] = d_x.contiguous().view(-1)
+            if jacobian_parameters:
+                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
+
+        res: tuple[torch.Tensor, ...] = ()
+        if jacobian_input:
+            res += jacobian_inp,
+        if jacobian_parameters:
+            res += jacobian_param,
+
+        return res
+
+    def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
+        def fw(*input):
+            return self._forward(module, input).detach()
+
+        res: tuple[torch.Tensor, ...] = ()
+        if jacobian_input:
+            res += _get_numerical_jacobian(fw, input, eps=1e-6),
+        if jacobian_parameters:
+            param, _ = self._get_parameters(module)
+            to_cat = []
+            for p in param:
+                jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
+                # get_numerical_jacobian returns a list of tuples but we require a tensor
+                to_cat.append(jacobian[0][0])
+            res += (torch.cat(to_cat, 0),)
+        return res
+
+    def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
+        jacobian_parameters = bool(self._get_parameters(module)[0])
+        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
+        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
+        analytical_t = list(_iter_tensors(analytical))
+        numerical_t = list(_iter_tensors(numerical))
+
+        differences = []
+        for a, n in zip(analytical_t, numerical_t, strict=True):
+            if a.numel() != 0:
+                differences.append(a.add(n, alpha=-1).abs().max())
+            # TODO: compare structure (ensure analytic jacobian has correct shape)
+        if len(differences) > 0:
+            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]
+
+
+class TestBase:
+
+    _required_arg_names = {'constructor_args', 'input', 'extra_args'}
+
+    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
+        self.desc = desc
+        self.fullname = fullname
+        self.constructor = constructor
+        self.reference_fn = reference_fn
+        for name in self._required_arg_names:
+            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
+                if name in {'constructor_args', 'extra_args'}:
+                    kwargs[name] = ()
+                else:
+                    raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!")
+        self._extra_kwargs = kwargs
+        self._arg_cache = {}
+
+    def get_name(self):
+        if self.fullname is not None:
+            return 'test_' + self.fullname
+
+        test_name = 'test_' + self.constructor.__name__
+        if self.desc:
+            test_name += '_' + self.desc
+        return test_name
+
+    def _unpack(self, value):
+        if isinstance(value, torch.Tensor):
+            return value
+        elif is_iterable(value):
+            return type(value)(self._unpack(v) for v in value)
+        else:
+            return value
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', True)
+
+    @property
+    def extra_args(self):
+        return self._get_arg('extra_args', True)
+
+    def _get_arg(self, name, unpack):
+        assert name in self._required_arg_names
+
+        if name not in self._arg_cache:
+            fn_name = name + '_fn'
+            size_name = name + '_size'
+
+            if name in self._extra_kwargs:
+                self._arg_cache[name] = self._extra_kwargs[name]
+            elif fn_name in self._extra_kwargs:
+                self._arg_cache[name] = self._extra_kwargs[fn_name]()
+            else:
+                assert size_name in self._extra_kwargs, \
+                    f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}"
+
+                def map_tensor_sizes(sizes):
+                    if isinstance(sizes, list):
+                        return [map_tensor_sizes(s) for s in sizes]
+                    elif isinstance(sizes, torch.Tensor):
+                        return sizes.double()
+                    else:
+                        return torch.randn(sizes)
+
+                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])
+
+        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]
+
+    def _get_input(self, unpack=True):
+        return self._get_arg('input', unpack)
+
+    def __call__(self, test_case):
+        raise NotImplementedError
+
+
+class ModuleTest(TestBase):
+
+    @abstractmethod
+    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
+        raise NotImplementedError
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.jacobian_input = kwargs.get('jacobian_input', True)
+        self.should_test_cuda = kwargs.get('test_cuda', True)
+        self.should_test_pickle = kwargs.get('pickle', True)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.FIXME_no_cuda_gradgrad_comparison = \
+            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
+        self.precision = kwargs.get('precision', 2e-4)
+        self.check_forward_only = kwargs.get('check_forward_only', False)
+        self.default_dtype = kwargs.get('default_dtype')
+        if self.default_dtype is None:
+            self.default_dtype = torch.get_default_dtype()
+
+    def __call__(self, test_case):
+        with set_default_dtype(self.default_dtype):
+            module = self.constructor(*self.constructor_args)
+            input = self._get_input()
+
+            if self.reference_fn is not None:
+                out = test_case._forward(module, input)
+                ref_input = deepcopy(input)
+                ref_module = deepcopy(module)
+                expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
+                test_case.assertEqual(out, expected_out, exact_dtype=False)
+            if self.check_forward_only:
+                return
+            self.test_noncontig(test_case, module, input)
+
+            if self.should_test_pickle:
+                # TODO: do this with in-memory files as soon as torch.save will support it
+                with tempfile.TemporaryFile() as f:
+                    test_case._forward(module, input)
+                    torch.save(module, f)
+                    f.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    module_copy = torch.load(f, weights_only=False)
+                    test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
+
+            self._do_test(test_case, module, input)
+
+    def noncontiguize(self, obj):
+        if isinstance(obj, list):
+            return [self.noncontiguize(o) for o in obj]
+        elif isinstance(obj, tuple):
+            return tuple(self.noncontiguize(o) for o in obj)
+        tensor = obj
+        ndim = tensor.dim()
+        # Always making only the last dimension noncontiguous is easy to hide
+        # bugs because .view(-1) will still work. So try to find a dim with size
+        # > 1 and make that non-contiguous, i.e., stack + select on the
+        # dimension directly after that.
+        dim = ndim
+        for d in range(ndim):
+            if tensor.size(d) > 1:
+                dim = d + 1
+                break
+        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
+        assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
+        noncontig.requires_grad = tensor.requires_grad
+        return noncontig
+
+    def test_noncontig(self, test_case, module, input):
+        # check no scalars, can't make non-contig
+        if isinstance(input, torch.Tensor) and input.dim() == 0:
+            return
+        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
+            return
+
+        test_case._zero_grad_parameters(module)
+        test_case._zero_grad_input(input)
+        with freeze_rng_state():
+            output = test_case._forward(module, input)
+            if getattr(module, "return_indices", False):
+                output = output[0]
+            grad_output = output.new(output.shape).normal_()
+            output = output.clone()
+            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
+            d_param = deepcopy(test_case._get_parameters(module)[1])
+
+        nc_input = self.noncontiguize(input)
+        nc_grad_output = self.noncontiguize(grad_output)
+        for contig_i, contig_g in product((True, False), repeat=2):
+            i = input if contig_i else nc_input
+            # Some ops, e.g., nn.Flatten, return gradient that shares
+            # storage with the grad_output. Hence we copy here.
+            go = deepcopy(grad_output if contig_g else nc_grad_output)
+            test_case._zero_grad_parameters(module)
+            test_case._zero_grad_input(i)
+            with freeze_rng_state():
+                out = test_case._forward(module, i)
+                if getattr(module, "return_indices", False):
+                    out = out[0]
+                grad = test_case._backward(module, i, out, go)
+
+                test_case.assertEqual(out, output)
+                test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
+                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
+
+    def test_cuda(self, test_case):
+        if not TEST_CUDA or not self.should_test_cuda:
+            raise unittest.SkipTest('Excluded from CUDA tests')
+
+        with set_default_dtype(self.default_dtype):
+            cpu_input = self._get_input()
+
+            type_map = {torch.double: torch.float}
+            cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
+
+            is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple)
+
+            gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
+
+            cpu_module = self.constructor(*self.constructor_args)
+            gpu_module = self.constructor(*self.constructor_args).float().cuda()
+            cpu_param = test_case._get_parameters(cpu_module)
+            gpu_param = test_case._get_parameters(gpu_module)
+            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0], strict=True):
+                gpu_p.data.copy_(cpu_p)
+
+            test_case._zero_grad_input(cpu_input_tuple)
+            test_case._zero_grad_input(gpu_input_tuple)
+            test_case._zero_grad_parameters(cpu_module)
+            test_case._zero_grad_parameters(gpu_module)
+            cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
+            gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
+            if getattr(cpu_module, "return_indices", False):
+                cpu_output = cpu_output[0]
+                gpu_output = gpu_output[0]
+            test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)
+
+            # Run backwards on CPU and GPU and compare results
+            for _ in range(5):
+                cpu_gradOutput = cpu_output.clone().normal_()
+                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
+                cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
+                gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
+                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
+                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1], strict=True):
+                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
+
+            # Run double-backwards on CPU and GPU and compare results
+            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
+                cpu_output = cpu_module(*cpu_input_tuple)
+                gpu_output = gpu_module(*gpu_input_tuple)
+                if getattr(cpu_module, "return_indices", False):
+                    cpu_output = cpu_output[0]
+                    gpu_output = gpu_output[0]
+
+                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
+                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
+                gpu_gradOutput.requires_grad = True
+
+                cpu_gradInputs = torch.autograd.grad(
+                    cpu_output,
+                    cpu_input_tuple + tuple(cpu_module.parameters()),
+                    cpu_gradOutput,
+                    create_graph=True)
+                gpu_gradInputs = torch.autograd.grad(
+                    gpu_output,
+                    gpu_input_tuple + tuple(gpu_module.parameters()),
+                    gpu_gradOutput,
+                    create_graph=True)
+
+                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs, strict=True):
+                    test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)
+
+                # We mix output into the second backwards computation so that
+                # torch.autograd.grad doesn't complain that some inputs
+                # are unreachable (which can happen if you differentiate
+                # only on the gradient.
+                if is_any_input_complex:
+                    outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
+                    outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
+                else:
+                    outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
+                    outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
+
+                cpu_gg = torch.autograd.grad(
+                    outputs_cpu,
+                    cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
+                    retain_graph=True)
+                gpu_gg = torch.autograd.grad(
+                    outputs_gpu,
+                    gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
+                    retain_graph=True)
+                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
+                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg, strict=True):
+                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)
+
+            self.test_noncontig(test_case, gpu_module, gpu_input_tuple)
+
+
+class InputVariableMixin:
+    def _get_input(self):
+        input = TestBase._get_input(self, False)  # type: ignore[arg-type]
+
+        def map_variables(i):
+            if isinstance(i, torch.Tensor):
+                if i.is_floating_point() or i.is_complex():
+                    i.requires_grad = True
+                return i
+            else:
+                return type(i)(map_variables(elem) for elem in i)
+
+        return map_variables(input)
+
+
+class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cudnn = kwargs.get('cudnn', False)
+        self.check_inplace = kwargs.get('check_inplace', False)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.skip_double = kwargs.get('skip_double', False)
+        self.skip_half = kwargs.get('skip_half', False)
+        self.with_tf32 = kwargs.get('with_tf32', False)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
+        self.test_cpu = kwargs.get('test_cpu', True)
+        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
+        self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode')
+        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
+        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)
+
+    def _check_gradients(self, test_case, module, input_tuple):
+        params = tuple(x for x in module.parameters())
+        num_inputs = len(input_tuple)
+
+        def fn_to_gradcheck(*inputs_and_params, **kwargs):
+            assert not kwargs
+            return test_case._forward(module, inputs_and_params[:num_inputs])
+
+        # gradcheck doesn't support operators that take in dense inputs but
+        # return sparse parameters. This only happens in the case of nn.Embedding
+        # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
+        # is a slightly different version of gradcheck that can handle this.
+        if self.has_sparse_gradients:
+            assert num_inputs == 1
+            test_input_jacobian = torch.is_floating_point(input_tuple[0])
+            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
+        else:
+            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
+                                           check_batched_grad=self.check_batched_grad,
+                                           fast_mode=self.gradcheck_fast_mode,
+                                           check_forward_ad=self.supports_forward_ad))
+
+        if self.check_gradgrad:
+            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
+                                               check_batched_grad=self.check_batched_grad,
+                                               fast_mode=self.gradcheck_fast_mode,
+                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))
+
+    def _do_test(self, test_case, module, input):
+        num_threads = torch.get_num_threads()
+        torch.set_num_threads(1)
+        input_tuple = input if isinstance(input, tuple) else (input,)
+
+        self._check_gradients(test_case, module, input_tuple)
+
+        # check if module can be printed
+        module.__repr__()
+
+        if self.check_inplace:
+            # check if the inplace variant of the module gives the same result
+            # as the out-of-place
+
+            # check_inplace doesn't support multiple input tensors, since we don't have any modules
+            # that modify the inputs in-place and that accept more than one input
+            assert len(input_tuple) == 1
+            input = input_tuple[0]
+
+            module_ip = self.constructor(*self.constructor_args, inplace=True)
+
+            input_version = input._version
+            with freeze_rng_state():
+                output = module(input)
+            test_case.assertEqual(input._version, input_version)
+
+            input_ip = deepcopy(input)
+            input_ip_clone = input_ip.clone()
+            with freeze_rng_state():
+                output_ip = module_ip(input_ip_clone)
+            test_case.assertNotEqual(input_ip_clone._version, input_version)
+            test_case.assertEqual(output, output_ip)
+            grad = output.data.clone().normal_()
+            if input.grad is not None:
+                with torch.no_grad():
+                    input.grad.zero_()
+            if input_ip.grad is not None:
+                with torch.no_grad():
+                    input_ip.grad.zero_()
+            output.backward(grad)
+            output_ip.backward(grad)
+            test_case.assertEqual(input.grad, input_ip.grad)
+
+        def assert_module_parameters_are(tensor_type, device_id=None):
+            for p in module.parameters():
+                test_case.assertIsInstance(p, tensor_type)
+                if device_id is not None:
+                    test_case.assertEqual(p.get_device(), device_id)
+
+        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
+            # check that cuda() moves module parameters to correct GPU device,
+            # and that float() casts parameters correctly
+            input_tuple = tuple(t.cuda() for t in input_tuple)
+            module.float().cuda()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+            if torch.cuda.device_count() > 1:
+                input_tuple = tuple(t.cuda(1) for t in input_tuple)
+                module.cuda(1)
+                with torch.cuda.device(1):
+                    module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
+        else:
+            # check that float()/double() casters work correctly
+            def to_type(tensor, real, complex):
+                if tensor.is_complex():
+                    return tensor.to(complex)
+                elif tensor.is_floating_point():
+                    return tensor.to(real)
+                else:
+                    return tensor
+
+            def to_half(x):
+                # TODO: torch.complex32 when properly supported
+                return to_type(x, torch.float16, None)
+
+            def to_single(x):
+                return to_type(x, torch.float32, torch.complex64)
+
+            def to_double(x):
+                return to_type(x, torch.float64, torch.complex128)
+
+            # to float
+            input_tuple = tuple(to_single(t) for t in input_tuple)
+            module.float()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.FloatTensor)
+
+            # and back to double
+            input_tuple = tuple(to_double(t) for t in input_tuple)
+            module.double()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.DoubleTensor)
+
+            if TEST_CUDA and self.should_test_cuda:
+                # check that cuda() moves module parameters to correct GPU device,
+                # and that float() casts parameters correctly
+
+                # to GPU0
+                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
+                module.float().cuda()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                # to CPU
+                input_tuple = tuple(t.cpu() for t in input_tuple)
+                module.cpu()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.FloatTensor)
+
+                # back to GPU0
+                input_tuple = tuple(t.cuda() for t in input_tuple)
+                module.cuda()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                # test that forwards of module runs correctly without cuDNN
+                if self.cudnn:
+                    with torch.backends.cudnn.flags(enabled=False):
+                        module(*input_tuple)
+                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                if torch.cuda.device_count() >= 2:
+                    # test cross-GPU transfer works
+                    # to GPU1
+                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
+                    module.cuda(1)
+                    with torch.cuda.device(1):
+                        module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
+
+                if not self.skip_double:
+                    # test double()
+                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
+                    module.double().cuda()
+                    module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]
+
+                # test half()
+                if not self.skip_half:
+                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
+                    module.half().cuda()
+                    module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
+        torch.set_num_threads(num_threads)
+
+    def _get_target(self):
+        return self._get_arg('target', False)
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', False)
+
+
+class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
+    # TODO: check that criterions don't ignore grad_output
+
+    _required_arg_names = TestBase._required_arg_names.union({'target'})
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.should_test_cuda = kwargs.get('test_cuda', True)
+        self.check_forward_only = kwargs.get('check_forward_only', False)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.check_half = kwargs.get('check_half', True)
+        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
+        self.check_complex = kwargs.get('check_complex', False)
+        self.test_cpu = kwargs.get('test_cpu', True)
+        self.with_tf32 = kwargs.get('with_tf32', True)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
+        self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.default_dtype = kwargs.get('default_dtype')
+        if self.default_dtype is None:
+            self.default_dtype = torch.get_default_dtype()
+
+    def __call__(self, test_case):
+        with set_default_dtype(self.default_dtype):
+            module = self.constructor(*self.constructor_args)
+            input = self._get_input()
+
+            # Check that these methods don't raise errors
+            module.__repr__()
+            str(module)
+
+            target = self._get_target()
+
+            if self.reference_fn is not None:
+                out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
+                ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
+                expected_out = self.reference_fn(*ref_args)
+                test_case.assertEqual(out, expected_out)
+
+            if self.check_forward_only:
+                return
+
+            params = tuple(x for x in module.parameters())
+            if not isinstance(input, tuple):
+                inputs = (input,) + params + (target,)
+
+                def apply_fn(input, target, *params):
+                    return module(input, target)
+            else:
+                inputs = input + params + (target,)
+
+                def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
+                    return module(input1, input2, target)
+
+            gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
+
+            if self.check_gradgrad:
+                gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
+
+    def test_cuda(self, test_case, dtype, extra_args=None):
+        def convert_dtype(obj, dtype, requires_grad=False):
+            if isinstance(obj, torch.Tensor):
+                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
+            elif isinstance(obj, tuple):
+                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
+            else:
+                return obj
+
+        if not TEST_CUDA or not self.should_test_cuda:
+            raise unittest.SkipTest('Excluded from CUDA tests')
+
+        with set_default_dtype(self.default_dtype):
+            cpu_input = self._get_input()
+            cpu_target = self._get_target()
+            cpu_module = self.constructor(*self.constructor_args)
+            gpu_module = self.constructor(*self.constructor_args)
+
+            # Convert input, target and module parameters to dtype
+            cpu_input = convert_dtype(cpu_input, dtype, True)
+            if cpu_target.is_floating_point() or cpu_target.is_complex():
+                cpu_target = convert_dtype(cpu_target, dtype)
+            cpu_module.type(dtype)
+            gpu_module.type(dtype)
+
+            # GPU setup
+            gpu_input = to_gpu(cpu_input)
+            gpu_target = to_gpu(cpu_target)
+            gpu_module.cuda()
+
+            # torch.HalfTensor doesn't support most operations, converting back to default
+            if dtype in {torch.half, torch.bfloat16}:
+                cpu_input = self._get_input()
+                cpu_target = self._get_target()
+                # Loss modules with weights require consistent input/module weight types
+                cpu_module = self.constructor(*self.constructor_args)
+
+            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
+            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
+            # dtype used to be able to be None, so set precision in this way instead of a precision map
+            test_case.assertEqual(cpu_output, gpu_output,
+                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
+
+            cpu_gradInput = test_case._backward_criterion(
+                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
+            gpu_gradInput = test_case._backward_criterion(
+                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
+            # dtype used to be able to be None, so set precision in this way instead of a precision map
+            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
+                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
+
+    def _get_target(self):
+        return self._get_arg('target', False)
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', False)
+
+    @property
+    def extra_args(self):
+        return self._get_arg('extra_args', False)
+
+
+def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
+    # fp32 compute
+    input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
+    if scale_factor is not None:
+        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
+    out1 = op(input1)
+    grad_input1 = torch.randn_like(out1, device=device)
+    out1.backward(grad_input1)
+
+    # bfloat16 compute
+    op_bfp16 = op.bfloat16()
+    input2 = input1.detach().bfloat16().requires_grad_()
+    grad_input2 = grad_input1.bfloat16()
+    out2 = op_bfp16(input2)
+    out2.backward(grad_input2)
+
+    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
+    test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False)
+
+def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
+    if not inference:
+        inp.requires_grad_(True)
+    out = module(inp)
+    if not inference:
+        gO = torch.rand_like(out)
+        out.backward(gO)
+    if check_size:
+        test_case.assertEqual(out.size(), inp.size())
+    if not inference:
+        for p in module.parameters():
+            if p.requires_grad:
+                test_case.assertEqual(p.grad, torch.zeros_like(p.grad))
+        test_case.assertEqual(inp.grad, torch.zeros_like(inp))
+
+
+def _create_basic_net():
+    class Layer(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
+            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))
+
+    class Net(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.l1 = Layer()
+            self.dummy_param = nn.Parameter(torch.empty(3, 5))
+            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))
+
+    l = Layer()
+    n = Net()
+    s = nn.Sequential(n, n)
+
+    return l, n, s
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b41e24b96caf24558c6947b6350c7b9c9ac8b7a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py
@@ -0,0 +1,2303 @@
+# mypy: ignore-errors
+
+import functools
+import itertools
+import sys
+import unittest
+from copy import deepcopy
+from enum import Enum
+from typing import Any, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Parameter
+from torch.optim import (
+    Adadelta,
+    Adafactor,
+    Adagrad,
+    Adam,
+    Adamax,
+    AdamW,
+    ASGD,
+    LBFGS,
+    Muon,
+    NAdam,
+    Optimizer,
+    RAdam,
+    RMSprop,
+    Rprop,
+    SGD,
+    SparseAdam,
+)
+from torch.optim.lr_scheduler import (
+    ConstantLR,
+    ExponentialLR,
+    LinearLR,
+    PolynomialLR,
+    ReduceLROnPlateau,
+    StepLR,
+)
+from torch.testing._internal.common_device_type import tol, toleranceOverride
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_utils import (
+    _TestParametrizer,
+    skipIfMPS,
+    skipIfTorchDynamo,
+    TEST_WITH_TORCHDYNAMO,
+)
+from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
+
+
+CUDA_CONFIG_GPUS = ["cuda", "xpu"]
+
+
+class OptimizerInput:
+    """Contains args / kwargs to be passed to an optimizer constructor."""
+
+    __slots__ = ["params", "kwargs", "desc"]
+
+    def __init__(
+        self,
+        params: Union[
+            list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]]
+        ],
+        kwargs: dict[str, Any],
+        desc: str = "",
+    ):
+        # params can be a list of Tensors OR param_groups OR None
+        self.params = params
+        self.kwargs = kwargs
+        self.desc = desc
+
+    def __repr__(self):
+        return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"
+
+
+class OptimizerErrorEnum(Enum):
+    """Enumerates when an error is raised when testing optimizers."""
+
+    CONSTRUCTION_ERROR = 0
+    STEP_ERROR = 1
+
+
+class ErrorOptimizerInput:
+    """
+    An OptimizerInput that will cause the optimizer to throw an error when constructed.
+    Includes the type and string of the resulting error.
+    """
+
+    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]
+
+    def __init__(
+        self,
+        optimizer_error_input,
+        *,
+        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        error_type=RuntimeError,
+        error_regex="",
+    ):
+        self.optimizer_error_input = optimizer_error_input
+        self.error_on = error_on
+        self.error_type = error_type
+        self.error_regex = error_regex
+
+
+class OptimizerInfo:
+    """Optimizer information to be used in testing."""
+
+    def __init__(
+        self,
+        optim_cls: Optimizer,  # Class object for the Optimizer under test
+        *,
+        # Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
+        # to the test using the OptimizerInfo. OptimizerInput.params is likely None.
+        # Can optionally take in device to filter out certain unsupported configs
+        optim_inputs_func,
+        # Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
+        # LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
+        # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
+        # LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
+        # A few optimizers like SGD and Adam will test more LRSchedulers.
+        scheduler_inputs=(
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
+        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
+        supported_impls: tuple[str, ...] = ("foreach", "differentiable"),
+        # A subset of all flags, signifying which ones were only supported after the
+        # original optimizer had already been released. aka impls where we need to check BC.
+        not_og_supported_flags: tuple[str, ...] = (
+            "foreach",
+            "differentiable",
+            "maximize",
+            "capturable",
+        ),
+        # the optim supports passing in sparse gradients as well as dense grads
+        supports_sparse: bool = False,
+        # the optimizer constructor supports passing in capturable as a kwarg
+        has_capturable_arg: bool = False,
+        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
+        only_supports_sparse_grads: bool = False,
+        # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
+        # with especially tuned hyperparameters. These only apply if the optimizer supports
+        # sparse parameters or grads.
+        metadata_for_sparse=({}, []),
+        # the optim supports complex parameters
+        supports_complex: bool = True,
+        # whether the optimizer.step() function requires a closure to be passed
+        step_requires_closure: bool = False,
+        # whether the optimizer supports per-param options with parameter groups
+        supports_param_groups: bool = True,
+        # whether the optimizer supports parameters on multiple devices
+        supports_multiple_devices: bool = True,
+        skips=(),  # Indicates which tests to skip
+        decorators=None,  # Additional decorators to apply to generated tests
+        optim_error_inputs_func=None,  # Function to generate optim inputs that error
+        supports_fused_on: tuple[str, ...] = (),
+    ):
+        self.optim_cls = optim_cls
+        self.optim_inputs_func = optim_inputs_func
+        self.scheduler_inputs = scheduler_inputs
+        self.supported_impls = supported_impls
+        self.not_og_supported_flags = not_og_supported_flags
+        self.supports_sparse = supports_sparse
+        self.has_capturable_arg = has_capturable_arg
+        self.metadata_for_sparse = metadata_for_sparse
+        self.only_supports_sparse_grads = only_supports_sparse_grads
+        self.supports_complex = supports_complex
+        self.step_requires_closure = step_requires_closure
+        self.supports_param_groups = supports_param_groups
+        self.supports_multiple_devices = supports_multiple_devices
+        self.decorators = (
+            *(decorators if decorators else []),
+            *(skips if skips else []),
+        )
+        self.optim_error_inputs_func = optim_error_inputs_func
+        self.supports_fused_on = supports_fused_on
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(
+                    test_class, test_name, device, dtype, param_kwargs
+                ):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
+
+    @property
+    def name(self):
+        return self.optim_cls.__name__
+
+
+class optims(_TestParametrizer):
+    """Decorator for specifying a list of optimizers over which to run a test."""
+
+    def __init__(self, optim_info_iterable, dtypes=None):
+        self.optim_info_list = list(optim_info_iterable)
+
+        # optimizers aren't limited to be one dtype as parameters can have different dtypes
+        # We default to torch.float32, but dtypes should be specified through passed in
+        # parameters.
+        self.dtypes = dtypes if dtypes is not None else [torch.float32]
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if device_cls is None:
+            raise RuntimeError(
+                "The @optims decorator is only intended to be used in a device-specific "
+                "context; use it with instantiate_device_type_tests() instead of "
+                "instantiate_parametrized_tests()"
+            )
+
+        for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
+            # Construct the test name; device / dtype parts are handled outside.
+            # See [Note: device and dtype suffix placement]
+            test_name = optim_info.name
+
+            # Construct parameter kwargs to pass to the test.
+            param_kwargs = {"optim_info": optim_info, "dtype": dtype}
+
+            try:
+
+                @functools.wraps(test)
+                def test_wrapper(*args, **kwargs):
+                    return test(*args, **kwargs)
+
+                decorator_fn = functools.partial(
+                    optim_info.get_decorators,
+                    generic_cls.__name__,
+                    test.__name__,
+                    device_cls.device_type,
+                    dtype,
+                )
+
+                yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+            except Exception as ex:
+                # Provides an error message for debugging before rethrowing the exception
+                print(
+                    f"Failed to instantiate {test_name} for module {optim_info.name}!"
+                )
+                raise ex
+
+
+# Helper function for generating error inputs for all optimizers, used below.
+def get_error_inputs_for_all_optims(device, dtype):
+    if _get_device_type(device) == "cpu":
+        # Creating 2D parameters for compatibility with Muon.
+        sample_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
+        sample_param2 = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
+        return [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=sample_param,
+                    kwargs={},
+                    desc="invalid param type",
+                ),
+                error_type=TypeError,
+                error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_param, sample_param],
+                    kwargs={},
+                    desc="a param group cannot have duplicate parameters",
+                ),
+                error_type=UserWarning,
+                error_regex=".*a parameter group with duplicate parameters.*",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[{"params": sample_param}, {"params": sample_param}],
+                    kwargs={},
+                    desc="duplicate parameters should not occur across param groups either",
+                ),
+                error_type=ValueError,
+                error_regex="some parameters appear in more than one parameter group",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=torch.tensor([0.001, 0.001])),
+                    desc="Tensor lr must be 1-element",
+                ),
+                error_type=ValueError,
+                error_regex="Tensor lr must be 1-element",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[("weight", sample_param), sample_param2],
+                    kwargs={},
+                    desc="all optimizer params should be with/without names",
+                ),
+                error_type=ValueError,
+                error_regex="all optimizer params should be with/without names. Some param names are missing",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        {"params": [sample_param], "lr": 1e-2},
+                        {"params": [("weight", sample_param2)]},
+                    ],
+                    kwargs={},
+                    desc="all optimizer param groups should be with/without names.",
+                ),
+                error_type=ValueError,
+                error_regex="all optimizer param groups should be with/without names. "
+                "cannot add param group with names to the optimizer",
+            ),
+        ]
+    else:
+        return []
+
+
+# ------------------------------------------------------------------------------------------
+# NOTE: [optimizer kwarg categories]
+# We categorize optimizer kwargs as 3 types:
+#  1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
+#     algorithms and thus only show up for certain optimizers. There are many of these, so I
+#     do not bother gathering them all and listing them here. The converse to these would be
+#     global flags that every optimizer ideally _should_ support. We break global flags into
+#     2 further categories and list them all below.
+#  2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
+#     global-friendly flags are global flags who play nicely with all other global flags,
+#     i.e., are mutually exclusive in function. This means that any pair of the following
+#     flags can be toggled at once (e.g., maximize and weight_decay). Furthermore, any of the
+#     following flags theoretically can be enabled with ANY other global flag, including the
+#     cliquey ones (e.g, capturable and foreach).
+#  3. global-cliquey = ["foreach", "fused", "differentiable"]
+#     global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
+#     usually because they contradict each other in function. For example, one should not flip
+#     both foreach AND fused to True, because they are two differing performance optimizations
+#     in which you can only opt into one.
+#
+# The following optim_inputs_func_* sampling functions only return constructor combinations of
+# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
+# well with additional kwargs. On the flip side of the same coin, we reserve setting the
+# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
+
+
+def optim_inputs_func_adadelta(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="capturable with weight decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_adadelta(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, rho=1.1),
+                    desc="rho should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid rho value: 1.1",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adafactor(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "lr": 0.01},
+            desc="nonzero weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"beta2_decay": -1.0},
+            desc="non-default beta2_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"d": 1.5},
+            desc="non-default clipping threshold d",
+        ),
+    ]
+
+
+def optim_error_inputs_func_adafactor(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
+        complex_param.grad = torch.rand_like(complex_param)
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(eps=(-1e-30, 1e-3)),
+                    desc="epsilon1 should be >= 0",
+                ),
+                error_type=ValueError,
+                error_regex="epsilon1 should be >= 0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(d=0.0),
+                    desc="invalid d",
+                ),
+                error_type=ValueError,
+                error_regex="Clipping threshold d should be >= 1",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(beta2_decay=0.8),
+                    desc="invalid beta2_decay",
+                ),
+                error_type=ValueError,
+                error_regex="beta2_decay should be <= 0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[complex_param],
+                    kwargs=dict(),
+                    desc="does not support complex parameters",
+                ),
+                error_type=RuntimeError,
+                error_regex="Adafactor does not support complex parameters",
+                error_on=OptimizerErrorEnum.STEP_ERROR,
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adagrad(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
+        OptimizerInput(
+            params=None,
+            kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
+            desc="initial_accumulator_value",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
+            desc="lr_decay",
+        ),  # TODO: Move out to testing in param_group?
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001)},
+            desc="Tensor lr",
+        ),
+    ]
+
+
+def optim_error_inputs_func_adagrad(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, lr_decay=-0.5),
+                    desc="lr_decay must be bigger than 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid lr_decay value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
+# with all implementation code paths...
+def optim_inputs_func_adam(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
+            desc="capturable, amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
+            desc="Tensor lr with capturable and amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
+                "amsgrad": True,
+                "capturable": True,
+            },
+            desc="Tensor lr, Tensor betas, with capturable and amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "amsgrad": False,
+                "capturable": True,
+            },
+            desc="Tensor lr, Tensor betas, with capturable",
+        ),
+    ]
+    mps_supported_configs = [
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
+        ),
+    ]
+
+    total = (
+        [
+            OptimizerInput(params=None, kwargs={}, desc="default"),
+            OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+            OptimizerInput(
+                params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+            ),
+            OptimizerInput(
+                params=None,
+                kwargs={"weight_decay": 0.1, "maximize": True},
+                desc="maximize",
+            ),
+            OptimizerInput(
+                params=None,
+                kwargs={"weight_decay": 0.1, "amsgrad": True},
+                desc="amsgrad",
+            ),
+        ]
+        + (
+            cuda_supported_configs
+            if _get_device_type(device) in CUDA_CONFIG_GPUS
+            else []
+        )
+        + (mps_supported_configs if _get_device_type(device) == "mps" else [])
+    )
+    if dtype == torch.float16:
+        for input in total:
+            """
+            Too small eps will make denom to be zero for low precision dtype
+            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+            For example,
+            >>> a
+            tensor([0.], dtype=torch.float16)
+            >>> a + 1e-8
+            tensor([0.], dtype=torch.float16)
+            """
+            input.kwargs["eps"] = 0.1
+    return total
+
+
+def optim_error_inputs_func_adam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-1),
+                    desc="weight_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -1",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=torch.tensor(0.001), foreach=True),
+                    desc="lr as Tensor doesn't work with foreach & not capturable",
+                ),
+                error_type=ValueError,
+                error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(0.9, torch.tensor(0.99))),
+                    desc="betas must be either both floats or both Tensors",
+                ),
+                error_type=ValueError,
+                error_regex="betas must be either both floats or both Tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(torch.tensor(0.9), 0.99)),
+                    desc="betas must be either both floats or both Tensors",
+                ),
+                error_type=ValueError,
+                error_regex="betas must be either both floats or both Tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(
+                        lr=1e-2,
+                        betas=(torch.tensor(0.9), torch.tensor(0.99)),
+                        foreach=True,
+                    ),
+                    desc=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
+                ),
+                error_type=ValueError,
+                error_regex=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
+            ),
+        ]
+    if _get_device_type(device) in CUDA_CONFIG_GPUS:
+        sample_tensor = torch.empty((), device=device, dtype=dtype)
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_tensor],
+                    kwargs={"foreach": True, "fused": True},
+                    desc="`fused` and `foreach` cannot be `True` together",
+                ),
+                error_type=RuntimeError,
+                error_regex="`fused` and `foreach` cannot be `True` together",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_tensor],
+                    kwargs={"fused": True, "differentiable": True},
+                    desc="`fused` does not support `differentiable`",
+                ),
+                error_type=RuntimeError,
+                error_regex="`fused` does not support `differentiable`",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adamax(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
+            desc="capturable, maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
+            desc="capturable, maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.9,
+                "maximize": False,
+                "capturable": True,
+            },
+            desc="capturable, weight_decay, tensor LR",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_adamax(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
+                    desc="beta2 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 1: 1.0",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adamw(device, dtype=None):
+    return optim_inputs_func_adam(device, dtype)
+
+
+def optim_error_inputs_func_adamw(device, dtype):
+    return optim_error_inputs_func_adam(device, dtype)
+
+
+def optim_inputs_func_asgd(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True, "capturable": True},
+            desc="maximize, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="maximize, weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.1,
+                "maximize": True,
+                "capturable": True,
+            },
+            desc="maximize, weight_decay, capturable, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
+        OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
+        OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, nonzero weight_decay",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_asgd(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-0.5),
+                    desc="weight_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_lbfgs(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(
+            params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"line_search_fn": "strong_wolfe"},
+            desc="strong_wolfe",
+        ),
+    ]
+
+
+def optim_error_inputs_func_lbfgs(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    return error_inputs
+
+
+def optim_inputs_func_muon(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.2},
+            desc="non-default weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.8},
+            desc="non-default momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"ns_steps": 6},
+            desc="passing alternative ns_steps",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "ns_coefficients": (3.4, -4.7, 2.0),
+            },
+            desc="passing alternative ns_coefficients",
+        ),
+    ]
+
+
+def optim_error_inputs_func_muon(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
+    complex_param.grad = torch.rand_like(complex_param)
+    non_2d_param = torch.rand(2, 3, 4, device=device, dtype=dtype)
+    non_2d_param.grad = torch.rand_like(non_2d_param)
+    param = torch.rand(2, 3, device=device, dtype=dtype)
+    param.grad = torch.rand_like(param)
+    error_inputs += [
+        ErrorOptimizerInput(
+            OptimizerInput(
+                params=[non_2d_param],
+                kwargs=dict(),
+                desc="only support 2D parameters",
+            ),
+            error_type=ValueError,
+            error_regex="Muon only supports 2D parameters",
+            error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        ),
+        ErrorOptimizerInput(
+            OptimizerInput(
+                params=[param],
+                kwargs={"adjust_lr_fn": "arbitrary"},
+                desc="only support `original` and `match_rms_adamw`",
+            ),
+            error_type=ValueError,
+            error_regex="Adjust learning rate function arbitrary is not supported",
+            error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        ),
+        ErrorOptimizerInput(
+            OptimizerInput(
+                params=[complex_param],
+                kwargs=dict(),
+                desc="does not support complex parameters",
+            ),
+            error_type=RuntimeError,
+            error_regex="Muon does not support complex parameters",
+            error_on=OptimizerErrorEnum.STEP_ERROR,
+        ),
+    ]
+    return error_inputs
+
+
+def optim_inputs_func_nadam(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum_decay": 6e-3},
+            desc="non-zero momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+            },
+            desc="weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            desc="weight_decay, momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+            },
+            desc="decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_nadam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum_decay=-0.2),
+                    desc="momentum_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum_decay value: -0.2",
+            ),
+        ]
+    return error_inputs
+
+
+# Weird story bro, NAdam and RAdam do not have maximize.
+def optim_inputs_func_radam(device=None, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+            },
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
+        OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
+            desc="decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_radam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-1),
+                    desc="weight_decay should > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -1",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_rmsprop(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="capturable, maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+            },
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "centered": True},
+            desc="centered",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+                "weight_decay": 0.1,
+            },
+            desc="maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
+            desc="momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+                "centered": True,
+                "momentum": 0.1,
+                "maximize": True,
+            },
+            desc="maximize, centered, weight_decay, w/ momentum",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_rmsprop(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum=-1.0),
+                    desc="momentum should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum value: -1.0",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_rprop(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"step_sizes": (2e-6, 100)},
+            desc="non-default step_sizes",
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+    ] + (cuda_supported_configs if _get_device_type(device) in CUDA_CONFIG_GPUS else [])
+
+
+def optim_error_inputs_func_rprop(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
+                    desc="0 < eta1 < 1 < eta2",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid eta values: 1.0, 0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_sgd(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
+        ),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+        ),
+        OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "dampening": 0.5},
+            desc="dampening",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "weight_decay": 0.1},
+            desc="weight_decay w/ momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
+            desc="nesterov",
+        ),
+    ]
+
+
+def optim_error_inputs_func_sgd(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum=-0.5),
+                    desc="momentum should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_sparseadam(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
+        ),  # TODO: Move out to testing in param_group?
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+    ]
+
+
+def optim_error_inputs_func_sparseadam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        torch.zeros(
+                            3, layout=torch.sparse_coo, device=device, dtype=dtype
+                        )
+                    ],
+                    kwargs={},
+                    desc="dense params required",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam requires dense parameter tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        {
+                            "params": [
+                                torch.zeros(
+                                    3,
+                                    layout=torch.sparse_coo,
+                                    device=device,
+                                    dtype=dtype,
+                                )
+                            ]
+                        }
+                    ],
+                    kwargs={},
+                    desc="dense params required in param_groups",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam requires dense parameter tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
+                    kwargs={},
+                    desc="complex not supported",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam does not support complex parameters",
+            ),
+        ]
+    return error_inputs
+
+
+def _get_device_type(device: Union[str, torch.device]) -> str:
+    # Returns the device type as a string, e.g., "cpu" or "cuda"
+    if isinstance(device, torch.device):
+        device = str(device.type)
+    assert isinstance(device, str)
+    return device.split(":")[0]
+
+
+def _get_optim_inputs_including_global_cliquey_kwargs(
+    device, dtype, optim_info, skip=()
+) -> list[OptimizerInput]:
+    """
+    Return a list of all configs for a given optimizer as a list of OptimizerInputs,
+    including configs that have supported global cliquey kwargs (foreach, fused,
+    differentiable) based on optim_info.supported_impls.
+
+    The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
+    intentionally do NOT include global cliquey kwargs to give flexibility to tests.
+    For example, testing correctness between toggling foreach on and off is now
+    trivial. That said, we sometimes want to test for all possible configs on an
+    optimizer including all supported flags, so this helper returns all optim inputs.
+    """
+    assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
+        "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+    )
+
+    optim_inputs = optim_info.optim_inputs_func(device)
+
+    supported_impls = tuple(
+        x
+        for x in optim_info.supported_impls
+        if x not in skip
+        and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
+        and (
+            _get_device_type(device) in _get_foreach_kernels_supported_devices()
+            or x != "foreach"
+        )
+    )
+
+    all_optim_inputs = []
+    for optim_input in optim_inputs:
+        # Add the base config where all the flags are False
+        base_kwargs = deepcopy(optim_input.kwargs)
+        if len(supported_impls) != 0:
+            for flag in supported_impls:
+                base_kwargs[flag] = False
+            all_optim_inputs.append(
+                OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc)
+            )
+        else:
+            all_optim_inputs.append(optim_input)
+        # Add a config for when each of the global cliquey kwargs is True
+        # Note that in [optimizer kwarg categories], these kwargs are mutually
+        # exclusive, so we do not need to product them together.
+        for flag in supported_impls:
+            new_kwargs = deepcopy(base_kwargs)
+            new_kwargs[flag] = True
+            all_optim_inputs.append(
+                OptimizerInput(
+                    params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
+                )
+            )
+    return all_optim_inputs
+
+
+# Database of OptimizerInfo entries in alphabetical order.
+optim_db: list[OptimizerInfo] = [
+    OptimizerInfo(
+        Adadelta,
+        optim_inputs_func=optim_inputs_func_adadelta,
+        optim_error_inputs_func=optim_error_inputs_func_adadelta,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            # Note on tolerances:
+            # test_correctness_Adadelta_cuda_float32
+            # Mismatched elements: 10 / 100 (10.0%)
+            # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
+            # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
+            # This is due to floating point ordering error + usage of sqrt
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(
+                            rtol=5.5e-4,
+                            atol=5e-5,
+                        )
+                    }
+                ),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adafactor,
+        optim_inputs_func=optim_inputs_func_adafactor,
+        optim_error_inputs_func=optim_error_inputs_func_adafactor,
+        supported_impls=("foreach",),
+        not_og_supported_flags=("foreach",),
+        supports_complex=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip("See #133268 regarding dtype being None"),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                device_type="cuda",
+                active_if=lambda kwargs: kwargs.get("use_closure", False),
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_foreach_large_tensor",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_peak_memory_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_save_load_equality_with_weights_only",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028 regarding copy not supported"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                unittest.skip("See #133268 regarding dtype being None"),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_save_load_equality_with_weights_only",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+                device_type="xpu",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adagrad,
+        optim_inputs_func=optim_inputs_func_adagrad,
+        optim_error_inputs_func=optim_error_inputs_func_adagrad,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu",),
+        supports_sparse=True,
+        metadata_for_sparse=(
+            {"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
+            [
+                lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
+                lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
+            ],
+        ),
+        decorators=(
+            DecorateInfo(
+                #  Note on tolerances:
+                #  difference comes from the fact that the non fused kernel have
+                #  more dtype cast operations. We have another test test_fused_cpu_matches_cuda
+                #  to make sure there is no discrepancies between cuda fused kernel
+                #  and cpu fused kernel
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adam,
+        optim_inputs_func=optim_inputs_func_adam,
+        scheduler_inputs=(
+            [lambda opt: ExponentialLR(opt, gamma=0.9)],
+            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
+            [
+                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+            ],
+            [
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        optim_error_inputs_func=optim_error_inputs_func_adam,
+        supported_impls=("foreach", "differentiable", "fused"),
+        has_capturable_arg=True,
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu", "cuda", "xpu", "mps"),
+        decorators=(
+            # Expected floating point error between fused and compiled forloop
+            DecorateInfo(
+                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
+                and kwargs["dtype"] == torch.float64,
+            ),
+            DecorateInfo(
+                #  Note on tolerances:
+                #  difference comes from the fact that the non fused kernel have
+                #  more dtype cast operations. We have another test test_fused_cpu_matches_cuda
+                #  to make sure there is no discrepancies between cuda fused kernel
+                #  and cpu fused kernel
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+            DecorateInfo(
+                # Note on tolerances:
+                # Tracking through #127000
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=3e-5, rtol=1.3e-06),
+                    }
+                ),
+                "TestCudaOptims",
+                "test_grad_scaling_autocast_fused_optimizers",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adamax,
+        optim_inputs_func=optim_inputs_func_adamax,
+        optim_error_inputs_func=optim_error_inputs_func_adamax,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                unittest.skip("Uses too much memory, even for H100, surprisingly."),
+                "TestOptimRenewed",
+                "test_foreach_large_tensor",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        AdamW,
+        optim_inputs_func=optim_inputs_func_adamw,
+        optim_error_inputs_func=optim_error_inputs_func_adamw,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu", "cuda", "mps"),
+        has_capturable_arg=True,
+        decorators=(
+            # Expected error between compiled forloop and fused optimizers
+            DecorateInfo(
+                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
+                and kwargs["dtype"] == torch.float64,
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    #  Note on tolerances:
+                    #  difference comes from the fact that the non fused kernel have
+                    #  more dtype cast operations. We have another test test_fused_cpu_matches_cuda
+                    #  to make sure there is no discrepancies between cuda fused kernel
+                    #  and cpu fused kernel
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+            # Note on tolerances:
+            # Tracking through #127000
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(
+                            atol=3e-5,
+                            rtol=1.3e-06,
+                        )
+                    }
+                ),
+                "TestCudaOptims",
+                "test_grad_scaling_autocast_fused_optimizers",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        ASGD,
+        optim_inputs_func=optim_inputs_func_asgd,
+        optim_error_inputs_func=optim_error_inputs_func_asgd,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1.5e-5, rtol=1e-5),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+            DecorateInfo(
+                unittest.skip(
+                    "ASGD internally changes the weights even with zero grad"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        LBFGS,
+        optim_inputs_func=optim_inputs_func_lbfgs,
+        optim_error_inputs_func=optim_error_inputs_func_lbfgs,
+        supported_impls=(),
+        step_requires_closure=True,
+        supports_param_groups=False,
+        supports_multiple_devices=False,
+        skips=(
+            # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
+            DecorateInfo(
+                skipIfMPS,
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="mps",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.complex64: tol(
+                            rtol=4.5e-5,
+                            atol=5e-5,
+                        )
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+            ),
+            DecorateInfo(
+                unittest.skip("LBFGS doesn't support multidevice"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction_multigpu",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_group_with_lrscheduler_goes_right_direction",
+            ),
+            # https://github.com/pytorch/pytorch/issues/131398
+            DecorateInfo(
+                unittest.expectedFailure,
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                active_if=lambda kwargs: sys.platform == "darwin"
+                and kwargs["use_closure"],
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Muon,
+        optim_inputs_func=optim_inputs_func_muon,
+        optim_error_inputs_func=optim_error_inputs_func_muon,
+        supported_impls=(),
+        not_og_supported_flags=(),
+        supports_complex=False,
+        skips=(
+            # Note on numerical differences: `compile` applies different matmul tuning,
+            # which leads to deviations compared to eager mode. In the Newton-Schulz
+            # iteration for orthogonalization, computations are done in bfloat16, further
+            # amplifying these numerical differences.
+            DecorateInfo(
+                unittest.skip(
+                    "Expect high difference between compiled and eager due to bfloat16 and iterative process."
+                ),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        NAdam,
+        optim_inputs_func=optim_inputs_func_nadam,
+        optim_error_inputs_func=optim_error_inputs_func_nadam,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors, https://github.com/pytorch/pytorch/issues/117150"
+                ),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        RAdam,
+        optim_inputs_func=optim_inputs_func_radam,
+        optim_error_inputs_func=optim_error_inputs_func_radam,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        # previously atol=1e-7, rtol=1e-7
+                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        RMSprop,
+        optim_inputs_func=optim_inputs_func_rmsprop,
+        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {  # previously atol=5-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
+                        torch.float32: tol(atol=5e-04, rtol=0.01),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+                active_if=TEST_WITH_TORCHDYNAMO,
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Rprop,
+        optim_inputs_func=optim_inputs_func_rprop,
+        optim_error_inputs_func=optim_error_inputs_func_rprop,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        SGD,
+        optim_inputs_func=optim_inputs_func_sgd,
+        scheduler_inputs=(
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
+            [
+                lambda opt: LinearLR(
+                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
+                )
+            ],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: LinearLR(
+                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
+                ),
+            ],
+            [
+                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
+                lambda opt: ExponentialLR(opt, gamma=0.99),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        optim_error_inputs_func=optim_error_inputs_func_sgd,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_sparse=True,
+        metadata_for_sparse=(
+            {
+                "lr": 4.8e-3,
+                "maximize": False,
+                "momentum": 0,
+                "nesterov": False,
+                "weight_decay": 0,
+            },
+            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
+        ),
+        supports_fused_on=(
+            "cpu",
+            "cuda",
+            "xpu",
+            "mps",
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        SparseAdam,
+        optim_inputs_func=optim_inputs_func_sparseadam,
+        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
+        supported_impls=(),
+        only_supports_sparse_grads=True,
+        metadata_for_sparse=({"lr": 4e-2}, []),
+        supports_complex=False,  # Missing complex support, see #118153
+        skips=(
+            DecorateInfo(
+                skipIfMPS,  # SparseAdam does not support MPS
+                "TestOptimRenewed",
+                device_type="mps",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_tensor_lr",
+            ),
+            DecorateInfo(
+                unittest.skip(
+                    "SparseAdam does not support dense gradients, see #116507"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction_multigpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_param_group_with_lrscheduler_goes_right_direction",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_state_dict_with_cuda_params",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+            ),
+        ),
+    ),
+]
+
+
+class TensorTracker:
+    """
+    A utility to track tensor clones in a list, with the expectation of popping them later (in
+    order) to make fair comparisons between two multi-step computation. The intended use case is
+    usually when comparing two supposed equal computations, such as an optimizer step that each
+    individually consists of multiple steps, where numerical deviation could multiply.
+
+    The goal is to be able to compare and align numbers at every milestone so as to minimize
+    numerical discrepancies, and so when the test fails, it is likely a real problem.
+    """
+
+    def __init__(self, assert_eq_kwargs=None):
+        if assert_eq_kwargs is None:
+            assert_eq_kwargs = {}
+        self.assert_eq_kwargs = assert_eq_kwargs
+        self.tensors = []
+
+    def add(self, tensor):
+        """
+        Add a detach().clone()'d version of the tensor
+        """
+        self.tensors.append(tensor.detach().clone())
+
+    # pops from beginning, like a queue and not a stack!
+    def pop_check_set(self, tensor_to_set, testcase):
+        """
+        Pop the first element in the tensor tracker, assert equality between the popped tensor and
+        the input tensor, and then set the input tensor to have the same values as the popped tensor
+        (with copy_).
+        """
+        testcase.assertGreater(len(self.tensors), 0, "no tensors to pop")
+        ref = self.tensors.pop(0)
+
+        testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
+        testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)
+
+        with torch.no_grad():
+            tensor_to_set.copy_(ref)
+
+    def all_popped(self):
+        return len(self.tensors) == 0
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bd57fa976ebc671e0184cc1a32128a3aed5b6bf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py
@@ -0,0 +1,675 @@
+# mypy: ignore-errors
+
+r"""Importing this file includes common utility methods for checking quantized
+tensors and modules.
+"""
+import numpy as np
+import torch
+from torch import Tensor
+from contextlib import contextmanager
+from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS
+
+supported_qengines = torch.backends.quantized.supported_engines
+# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
+# QNNPACK is not supported on PPC
+if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]):
+    supported_qengines.remove('qnnpack')
+
+def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
+                       output_padding=0):
+    """Computes the output shape given convolution parameters."""
+    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
+                     * (dilation - 1)) / stride) + 2 * output_padding + 1
+
+# Quantization references
+def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
+    """Quantizes a numpy array."""
+    if qmin is None:
+        qmin = np.iinfo(dtype).min
+    if qmax is None:
+        qmax = np.iinfo(dtype).max
+    qx = np.round(x / scale + zero_point).astype(np.int64)
+    qx = np.clip(qx, qmin, qmax)
+    qx = qx.astype(dtype)
+    return qx
+
+
+def _dequantize(qx, scale, zero_point):
+    """Dequantizes a numpy array."""
+    x = (qx.astype(float) - zero_point) * scale
+    return x
+
+
+def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
+    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
+    converted back to given type"""
+    qx = (x * multiplier).round() + zero_point
+    qx = np.clip(qx, qmin, qmax).astype(qtype)
+    return qx
+
+def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
+    if qscheme == torch.per_tensor_symmetric:
+        assert dtype == torch.qint8
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    if dtype == torch.qint8:
+        if reduce_range:
+            qmin, qmax = -64, 63
+        else:
+            qmin, qmax = -128, 127
+    else:  # dtype == torch.quint8
+        if reduce_range:
+            qmin, qmax = 0, 127
+        else:
+            qmin, qmax = 0, 255
+    min_val = X.min()
+    max_val = X.max()
+    is_symmetric = (qscheme == torch.per_tensor_symmetric)
+    if min_val == max_val:
+        scale = 1.0
+        zero_point = 0
+    else:
+        if is_symmetric:
+            max_val = max(max_val, -min_val)
+            min_val = -max_val
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = qmin - round(min_val / scale)
+            zero_point = max(qmin, zero_point)
+            zero_point = min(qmax, zero_point)
+    return [float(scale), int(zero_point)]
+
+def _calculate_dynamic_per_channel_qparams(X, dtype):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+    n_levels = qmax - qmin
+    scale = np.zeros(X.shape[0], dtype=np.float64)
+    zero_point = np.zeros(X.shape[0], dtype=np.int64)
+    for i in range(zero_point.shape[0]):
+        min_val = X.min()
+        max_val = X.max()
+        if min_val == max_val:
+            scale[i] = 1.0
+            zero_point[i] = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale[i] = (max_val - min_val) / n_levels
+            scale[i] = max(scale[i], np.finfo(np.float32).eps)
+            zero_point[i] = qmin - round(min_val / scale[i])
+            zero_point[i] = max(qmin, zero_point[i])
+            zero_point[i] = min(qmax, zero_point[i])
+
+    return scale, zero_point
+
+def _snr(x, x_hat):
+    """Calculates the signal to noise ratio and returns the signal and noise
+    power, as well as the SNR in dB.
+    If the input is a list/tuple this function is called recursively on each
+    element. The result will have the same nested structure as the inputs.
+
+    Args:
+        x, x_hat: Either a tensor or a nested list/tuple of tensors.
+    Returns:
+        signal, noise, SNR(in dB): Either floats or a nested list of floats
+    """
+    if isinstance(x, (list, tuple)):
+        assert len(x) == len(x_hat)
+        res = [_snr(x[idx], x_hat[idx]) for idx in range(len(x))]
+        return res
+    if x_hat.is_quantized:
+        x_hat = x_hat.dequantize()
+    if x.is_quantized:
+        x = x.dequantize()
+    noise = (x - x_hat).norm()
+    if noise == 0:
+        return 0.0, float('inf'), float('inf')
+    signal = x.norm()
+    snr = signal / noise
+    snr_db = 20 * snr.log10()
+    return signal, noise, snr_db
+
+@contextmanager
+def override_quantized_engine(qengine):
+    previous = torch.backends.quantized.engine
+    torch.backends.quantized.engine = qengine
+    try:
+        yield
+    finally:
+        torch.backends.quantized.engine = previous
+
+@contextmanager
+def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
+    try:
+        if qengine_is_qnnpack:
+            torch._C._set_default_mobile_cpu_allocator()
+        yield
+    finally:
+        if qengine_is_qnnpack:
+            torch._C._unset_default_mobile_cpu_allocator()
+
+# TODO: Update all quantization tests to use this decorator.
+# Currently for some of the tests it seems to have inconsistent params
+# for fbgemm vs qnnpack.
+def override_qengines(qfunction):
+    def test_fn(*args, **kwargs):
+        for qengine in supported_qengines:
+            with override_quantized_engine(qengine):
+                # qfunction should not return anything.
+                qfunction(*args, **kwargs)
+    return test_fn
+
+def qengine_is_fbgemm():
+    return torch.backends.quantized.engine == 'fbgemm'
+def qengine_is_qnnpack():
+    return torch.backends.quantized.engine == 'qnnpack'
+def qengine_is_onednn():
+    return torch.backends.quantized.engine == 'onednn'
+def qengine_is_x86():
+    return torch.backends.quantized.engine == 'x86'
+
+# Helper function used to simulate per-channel fake-quant against any axis
+def _permute_to_axis_zero(X, axis):
+    new_axis_list = list(range(X.dim()))
+    new_axis_list[axis] = 0
+    new_axis_list[0] = axis
+    y = X.permute(tuple(new_axis_list))
+    return y, new_axis_list
+
+# Reference method for fake quantize
+# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
+def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
+    dtype = X.dtype
+    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
+    res = torch.zeros_like(X)
+
+    for i in range(X.size()[0]):
+        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
+                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
+
+    out = res.permute(tuple(permute_axis_list))
+    return out.to(dtype)
+
+# Reference method for the gradient of the fake quantize operator
+# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
+def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
+    dtype = X.dtype
+    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
+    Xq = torch.zeros_like(X)
+    for i in range(X.size()[0]):
+        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
+    Xq = Xq.permute(tuple(permute_axis_list))
+    mask = (Xq >= quant_min) * (Xq <= quant_max)
+    res = torch.zeros_like(dY)
+    res[mask] = dY[mask]
+    return res.to(dtype)
+
+def to_tensor(X, device):
+    if not isinstance(X, torch.Tensor):
+        X = torch.tensor(X)
+    else:
+        X = X.detach().clone()
+    return X.to(device=torch.device(device), dtype=torch.float32)
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
+def _n_ones(n: int) -> int:
+    return (1 << n) - 1
+
+EBITS_F32, MBITS_F32 = 8, 23
+F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
+def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+
+    Note: there are no special values (NaN, inf) support in this code. Values
+    outside the representable range of Floatx after rounding are clamped to the
+    maximum Floatx magnitude (sign is preserved).
+
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+
+    Background 1: last answer in https://stackoverflow.com/q/8981913
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+    """
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    max_int = _n_ones(ebits + mbits)
+    sign_mask = 1 << (ebits + mbits)
+
+    # TODO document this better
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    x = x ^ sign
+
+    # TODO: can the branch floating point comparisons below be done without
+    # converting to float? probably but need to verify
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    #
+    # branch 1: saturate to max val - handled later in the code which combines
+    #   the branches
+    #
+
+    #
+    # branch 2: to conversion to denormal as well as rounding up to normal
+    #
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    #
+    # branch 3: stay in normal range, adjust the exponent and round
+    #
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    #
+    # combine the branches
+    #
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/29488018d99af7f7339f06353c6b5bbeae8a1493/torchao/prototype/custom_fp_utils.py#L147
+def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
+    """
+    assert x.dtype == torch.uint8
+    assert 1 + ebits + mbits <= 8
+
+    sign_mask = 1 << (ebits + mbits)
+    exp_bias = _n_ones(ebits - 1)
+    mantissa_mask = _n_ones(mbits)
+
+    # save the sign
+    sign_lp = x & sign_mask
+
+    # set everything to positive, will add sign back at the end
+    x_pos = x ^ sign_lp
+
+    #
+    # 1. Calculate zero mask
+    #
+    zero_mask = x_pos == 0
+
+    #
+    # 2. Calculate the denormal path mask
+    #
+    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
+
+    #
+    # 3. Calculate the normal path
+    #
+
+    # calculate the new exponent and shift it to bits 2:9 of the result
+    exp_biased_lp = x_pos >> mbits
+    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
+    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
+
+    # shift the mantissa to bits 10:32 of the result
+    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
+    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
+    result = exp_biased_f32 | mantissa_f32
+
+    #
+    # 4. Add the zero and denormal casts to the already casted normal path
+    #
+    result[zero_mask] = 0
+
+    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
+
+    # fast path.
+    # without this, performance for FP4_E2M1 is slower by 2x
+    if mbits == 1:
+        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
+
+    else:
+        # iterate over all possible values of mantissa
+        # i=0, j=1
+        # i=1, j=10,11
+        # i=2, j=100,101,110,111
+        # and so on
+        for i in range(mbits):
+            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
+                # left shift mantissa until it overflows (create an implicit 1)
+                # subtract exponent by the same amount
+                left_shift = mbits - i
+                mantissa_f32 = (mantissa_cmp - (1 << i)) << (
+                    left_shift + MBITS_F32 - mbits
+                )
+                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
+
+                # we can update this in-place since the values won't overlap
+                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
+                # thus we use + instead of | here
+                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = (
+                    exp_biased_f32 + mantissa_f32
+                )
+
+        result = torch.where(denormal_mask, mantissa_lp_int32, result)
+
+    # add sign back
+    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
+    result = result | sign_f32
+
+    return result.view(torch.float)
+
+# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
+# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
+# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
+# more naturally.
+def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
+    # Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
+    # Output should be [input.size(0, input.size(1) // blocksize] scales
+    output_scales = torch.zeros(
+        (input.size(0), input.size(1) // blocksize),
+        device=input.device,
+        dtype=input_scales.dtype,
+    )
+
+    # Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
+    # happened for offset purposes.
+    # There are K//blocksize scales, padded to groups of 4.
+    num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)
+
+    # (Very) slow reference implementation using horrifying loops.
+    for i in range(input.size(0)):
+        for j in range(input.size(1) // blocksize):
+            # which 128x4 tile of scaling factors am I in
+            scale_tile_h = i // 128
+            scale_tile_w = j // 4
+
+            # There are (padded) input_scales.size(1) // 4 tiles along the w dim.
+            # So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
+            tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)
+
+            # indices within the tile - use nomenclature directly from cublas docs
+            outer = i % 128  # "outer" in cublas docs
+            inner = j % 4    # "inner" in cublas docs
+
+            # Note: "offset" is given in terms of bytes, in cublas docs, but our scales are e8m0,
+            #       anyway, and so 1B == 1 value => use offset directly.
+            # Formula directly from cublas docs in 3.1.4.3.2
+            offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner
+
+            output_scales[i, j] = input_scales[offset]
+
+    return output_scales
+
+def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
+    # expand scales
+    scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)
+
+    # de-scale and convert
+    x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
+    return x_f32.to(torch.bfloat16)
+
+def to_blocked(input_matrix) -> torch.Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    # Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
+        padded[:rows, :cols] = input_matrix
+
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    return rearranged.flatten()
+
+
+def down_size(size):
+    assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
+    return (*size[:-1], size[-1] // 2)
+
+
+def pack_uint4(uint8_data) -> torch.Tensor:
+    # converting to uint8 for operations
+    shape = uint8_data.shape
+    assert shape[-1] % 2 == 0
+    uint8_data = uint8_data.contiguous().view(-1)
+    return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))
+
+
+# exponent and mantissa bits of `torch.float4_e2m1fn_x2`
+FP4_EBITS, FP4_MBITS = 2, 1
+
+
+def _bfloat16_to_float4_e2m1fn_x2(x):
+    assert x.dtype == torch.bfloat16
+    x = _f32_to_floatx_unpacked(x.float(), FP4_EBITS, FP4_MBITS)
+    x = pack_uint4(x)
+    x = x.view(torch.float4_e2m1fn_x2)
+    return x
+
+
+# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
+def to_mxfp(
+    data_hp: torch.Tensor,
+    block_size: int = 32,
+    format: str = "mxfp8",
+):
+    assert data_hp.dtype in (
+        torch.bfloat16,
+        torch.float,
+    ), f"{data_hp.dtype} is not supported yet"
+    assert (
+        data_hp.shape[-1] % block_size == 0
+    ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    assert data_hp.is_contiguous(), "unsupported"
+
+    orig_shape = data_hp.shape
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )
+
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    data_hp = data_hp.to(torch.float32)
+    max_abs = max_abs.to(torch.float32)
+
+    if format == "mxfp8":
+        F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
+        max_pos = F8E4M3_MAX
+    elif format == "mxfp4":
+        F4E2M1_MAX = 6.
+        max_pos = F4E2M1_MAX
+
+    # RCEIL
+    def _to_mx_rceil(
+        data_hp: torch.Tensor,
+        max_abs: torch.Tensor,
+        max_pos: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        E8M0_EXPONENT_BIAS = 127
+        descale = max_abs / max_pos
+        exponent = torch.where(
+            torch.isnan(descale),
+            0xFF,  # Handle biased exponent for nan
+            # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
+            (
+                torch.clamp(
+                    torch.ceil(torch.log2(descale)),
+                    min=-E8M0_EXPONENT_BIAS,
+                    max=E8M0_EXPONENT_BIAS,
+                )
+                + E8M0_EXPONENT_BIAS
+            ).to(torch.uint8),
+        )
+
+        descale_fp = torch.where(
+            exponent == 0,
+            1.0,
+            torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
+        )
+
+        # scale and saturated cast the data elements to max of target dtype
+        data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
+        return exponent, data_lp
+
+    scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
+
+    # cast to target dtype
+    if format == "mxfp8":
+        data_lp = data_lp.to(torch.float8_e4m3fn)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
+    elif format == "mxfp4":
+        data_lp = _bfloat16_to_float4_e2m1fn_x2(data_lp.to(torch.bfloat16))
+        final_shape = list(orig_shape)
+        final_shape[-1] //= 2
+        data_lp = data_lp.reshape(final_shape)
+
+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
+    return scale_e8m0_biased, data_lp
+
+# Source: https://github.com/pytorch/ao/blob/568c1932a16ae9f30d48da214a88dc0013e98ed8/torchao/prototype/moe_training/utils.py#L310
+def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
+    """
+    Utility function for tests and benchmarks.
+
+    Generates a tensor of length E, containing random values divisible by `multiple_of`,
+    from 0 to M, in sorted order, and where the final value in the tensor is always M.
+    Args:
+        E (int): The length of the tensor.
+        M (int): The maximum value in the tensor.
+    Returns:
+        torch.Tensor: A tensor of length E with the specified properties.
+    """
+    import random
+
+    # Ensure M is divisible by 16
+    if M % multiple_of != 0:
+        raise ValueError(f"M must be divisible by {multiple_of}")
+
+    # Generate a list of possible values
+    possible_values = list(range(multiple_of, M + 1, multiple_of))
+
+    # If E is larger than the number of possible values, raise an error
+    if E > len(possible_values):
+        raise ValueError("E cannot be larger than the number of possible values")
+
+    # Randomly select E - 1 values from the possible values (excluding M)
+    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
+
+    # Append M to the selected values
+    selected_values = torch.cat((selected_values, torch.tensor([M])))
+
+    # Sort the selected values
+    selected_values, _ = torch.sort(selected_values)
+
+    return selected_values.to(dtype).to(device)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py
new file mode 100644
index 0000000000000000000000000000000000000000..cca291133d3e945c6b42054577a711d781857cac
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py
@@ -0,0 +1,343 @@
+# mypy: ignore-errors
+
+import torch
+from copy import deepcopy
+from torch.utils._pytree import tree_map
+import torch.utils._pytree as pytree
+
+
+# TODO: Move LoggingTensor here.
+from torch.testing._internal.logging_tensor import LoggingTensor
+
+
+# Base class for wrapper-style tensors.
+class WrapperTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, *args, **kwargs):
+        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
+        if "size" not in kwargs:
+            size = t.size()
+        else:
+            size = kwargs["size"]
+            del kwargs["size"]
+        if "dtype" not in kwargs:
+            kwargs["dtype"] = t.dtype
+        if "layout" not in kwargs:
+            kwargs["layout"] = t.layout
+        if "device" not in kwargs:
+            kwargs["device"] = t.device
+        if "requires_grad" not in kwargs:
+            kwargs["requires_grad"] = False
+        # Ignore memory_format and pin memory for now as I don't know how to
+        # safely access them on a Tensor (if possible??)
+
+        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
+        wrapper._validate_methods()
+        return wrapper
+
+    @classmethod
+    def get_wrapper_properties(cls, *args, **kwargs):
+        # Should return both an example Tensor and a dictionary of kwargs
+        # to override any of that example Tensor's properly.
+        # This is very similar to the `t.new_*(args)` API
+        raise NotImplementedError("You need to implement get_wrapper_properties")
+
+    def _validate_methods(self):
+        # Skip this if not in debug mode?
+        # Changing these on the python side is wrong as it would not be properly reflected
+        # on the c++ side
+        # This doesn't catch attributes set in the __init__
+        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
+        for el in forbidden_overrides:
+            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
+                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
+                                   f"property {el} but this is not allowed as such change would "
+                                   "not be reflected to c++ callers.")
+
+
+class WrapperTensorWithCustomSizes(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, t, requires_grad=False):
+        return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "sizes"}
+
+    def __init__(self, t, requires_grad=False):
+        self.t = t
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+
+        if kwargs is None:
+            kwargs = {}
+
+        def unwrap(e):
+            return e.t if isinstance(e, WrapperTensorWithCustomSizes) else e
+
+        def wrap(e):
+            return WrapperTensorWithCustomSizes(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"t={self.t}")
+
+
+class WrapperTensorWithCustomStrides(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, t, requires_grad=False):
+        return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "strides"}
+
+    def __init__(self, t, requires_grad=False):
+        self.t = t
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+
+        if kwargs is None:
+            kwargs = {}
+
+        def unwrap(e):
+            return e.t if isinstance(e, WrapperTensorWithCustomStrides) else e
+
+        def wrap(e):
+            return WrapperTensorWithCustomStrides(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"t={self.t}")
+
+
+class DiagTensorBelow(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, diag, requires_grad=False):
+        assert diag.ndim == 1
+        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}
+
+    def __init__(self, diag, requires_grad=False):
+        self.diag = diag
+
+    handled_ops = {}
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+
+        # For everything else, call the handler:
+        fn = cls.handled_ops.get(func.__name__, None)
+        if fn:
+            return fn(*args, **(kwargs or {}))
+        else:
+            # Note that here, because we don't need to provide the autograd formulas
+            # we can have a default "fallback" that creates a plain Tensor based
+            # on the diag elements and calls the func again.
+
+            def unwrap(e):
+                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e
+
+            def wrap(e):
+                if isinstance(e, torch.Tensor) and e.ndim == 1:
+                    return DiagTensorBelow(e)
+                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
+                    return DiagTensorBelow(e.diag())
+                return e
+
+            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+            return rs
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"diag={self.diag}")
+
+
+class SparseTensor(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
+        assert values.device == indices.device
+        return values, {"size": size, "requires_grad": requires_grad}
+
+    def __init__(self, size, values, indices, requires_grad=False):
+        self.values = values
+        self.indices = indices
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")
+
+    def sparse_to_dense(self):
+        res = torch.zeros(self.size(), dtype=self.values.dtype)
+        res[self.indices.unbind(1)] = self.values
+        return res
+
+    @staticmethod
+    def from_dense(t):
+        indices = t.nonzero()
+        values = t[indices.unbind(1)]
+        return SparseTensor(t.size(), values, indices)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        func_name = f"{func.__module__}.{func.__name__}"
+
+        res = cls._try_call_special_impl(func_name, args, kwargs)
+        if res is not NotImplemented:
+            return res
+
+        # Otherwise, use a default implementation that construct dense
+        # tensors and use that to compute values
+        def unwrap(e):
+            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e
+
+        # Wrap back all Tensors into our custom class
+        def wrap(e):
+            # Check for zeros and use that to get indices
+            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+
+    _SPECIAL_IMPLS = {}
+
+    @classmethod
+    def _try_call_special_impl(cls, func, args, kwargs):
+        if func not in cls._SPECIAL_IMPLS:
+            return NotImplemented
+        return cls._SPECIAL_IMPLS[func](args, kwargs)
+
+
+# Example non-wrapper subclass that stores extra state.
+class NonWrapperTensor(torch.Tensor):
+    def __new__(cls, data):
+        t = torch.Tensor._make_subclass(cls, data)
+        t.extra_state = {
+            'last_func_called': None
+        }
+        return t
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        result = super().__torch_function__(func, types, args, kwargs)
+
+        if isinstance(result, cls):
+            # Do something with the extra state. For the example here, just store the name of the
+            # last function called (skip for deepcopy so the copy has the same extra state).
+            if func is torch.Tensor.__deepcopy__:
+                result.extra_state = deepcopy(args[0].extra_state)
+            else:
+                result.extra_state = {
+                    'last_func_called': func.__name__,
+                }
+
+        return result
+
+    # new_empty() must be defined for deepcopy to work
+    def new_empty(self, shape):
+        return type(self)(torch.empty(shape))
+
+
+# Class used to store info about subclass tensors used in testing.
+class SubclassInfo:
+
+    __slots__ = ['name', 'create_fn', 'closed_under_ops']
+
+    def __init__(self, name, create_fn, closed_under_ops=True):
+        self.name = name
+        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
+        self.closed_under_ops = closed_under_ops
+
+
+# Helper function to create a subclass of the given class and possibly cache sizes / strides.
+def _create_and_access_shape(cls, shape):
+    sub = cls(torch.randn(shape))
+    # NB: Wrapper subclasses with custom dispatched sizes / strides cache this info
+    # on the first call via non-serializable PyCapsules. We purposefully trigger cache
+    # population here for serialization / deepcopy tests to verify that the presence of this
+    # cache info doesn't cause problems.
+    sub.size()
+    sub.stride()
+    return sub
+
+
+subclass_db = {
+    torch.Tensor: SubclassInfo(
+        'base_tensor', create_fn=torch.randn
+    ),
+    NonWrapperTensor: SubclassInfo(
+        'non_wrapper_tensor',
+        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
+    ),
+    LoggingTensor: SubclassInfo(
+        'logging_tensor',
+        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
+    ),
+    SparseTensor: SubclassInfo(
+        'sparse_tensor',
+        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
+    ),
+    DiagTensorBelow: SubclassInfo(
+        'diag_tensor_below',
+        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
+        closed_under_ops=False  # sparse semantics
+    ),
+    WrapperTensorWithCustomSizes: SubclassInfo(
+        'wrapper_with_custom_sizes',
+        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomSizes, shape),
+        closed_under_ops=False,
+    ),
+    WrapperTensorWithCustomStrides: SubclassInfo(
+        'wrapper_with_custom_strides',
+        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomStrides, shape),
+        closed_under_ops=False,
+    ),
+}
+
+class SubclassWithTensorFactory(torch.Tensor):
+    @staticmethod
+    def __new__(cls, src):
+        shape = src.shape
+        kwargs = {}
+        kwargs["strides"] = src.stride()
+        kwargs["storage_offset"] = src.storage_offset()
+        kwargs["device"] = src.device
+        kwargs["layout"] = src.layout
+        kwargs["requires_grad"] = src.requires_grad
+        kwargs["dtype"] = src.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+        return out
+
+    def __init__(self, src):
+        self.src = src
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}"
+
+    def __tensor_flatten__(self):
+        return ["src"], None
+
+    @classmethod
+    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
+        src = inner_tensors["src"]
+        return cls(src)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+
+        def _fn(x):
+            return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src
+
+        _args = pytree.tree_map_only(cls, _fn, args)
+        _kwargs = pytree.tree_map_only(cls, _fn, kwargs)
+
+        _out = func(*_args, **_kwargs)
+
+        _out_flat, _out_spec = pytree.tree_flatten(_out)
+
+        out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
+        return pytree.tree_unflatten(out_flat, _out_spec)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a0b3c3a537116daa3be625a7dea1d6f60acd647
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py
@@ -0,0 +1,5882 @@
+# mypy: allow-untyped-defs
+
+r"""Importing this file must **not** initialize CUDA context. test_distributed
+relies on this assumption to properly run. This means that when this is imported
+no CUDA calls shall be made, including torch.cuda.device_count(), etc.
+
+torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported.
+"""
+
+import argparse
+import contextlib
+import copy
+import ctypes
+import errno
+import functools
+import gc
+import hashlib
+import inspect
+import io
+import json
+import logging
+import math
+import operator
+import os
+import pathlib
+import platform
+import random
+import re
+import shutil
+import signal
+import socket
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import types
+import unittest
+import warnings
+from collections.abc import Mapping, Sequence
+from contextlib import closing, contextmanager
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial, wraps
+from itertools import product, chain
+from pathlib import Path
+from statistics import mean
+from typing import (
+    Any,
+    Optional,
+    TypeVar,
+    Union,
+)
+from collections.abc import Callable
+from collections.abc import Iterable, Iterator
+from unittest.mock import MagicMock
+
+import expecttest
+import numpy as np
+
+import __main__  # type: ignore[import]
+import torch
+import torch.backends.cudnn
+import torch.backends.mkl
+import torch.backends.mps
+import torch.backends.xnnpack
+import torch.cuda
+from torch import Tensor
+from torch._C import ScriptDict, ScriptList  # type: ignore[attr-defined]
+from torch._utils_internal import get_writable_path
+from torch._logging.scribe import open_source_signpost
+from torch.nn import (
+    ModuleDict,
+    ModuleList,
+    ParameterDict,
+    ParameterList,
+    Sequential,
+)
+from torch.onnx import (
+    register_custom_op_symbolic,
+    unregister_custom_op_symbolic,
+)
+from torch.testing import make_tensor
+from torch.testing._comparison import (
+    BooleanPair,
+    NonePair,
+    NumberPair,
+    Pair,
+    TensorLikePair,
+)
+from torch.testing._comparison import not_close_error_metas
+from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.utils._import_utils import _check_module_exists
+import torch.utils._pytree as pytree
+from torch.utils import cpp_extension
+try:
+    import pytest  # type: ignore[import-not-found]
+    has_pytest = True
+except ImportError:
+    has_pytest = False
+
+SEED = 1234
+MI350_ARCH = ("gfx950",)
+MI300_ARCH = ("gfx942",)
+MI200_ARCH = ("gfx90a")
+NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")
+
+class ProfilingMode(Enum):
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
+# Set by parse_cmd_line_args() if called
+DISABLED_TESTS_FILE = ""
+GRAPH_EXECUTOR : Optional[ProfilingMode] = None
+LOG_SUFFIX = ""
+PYTEST_SINGLE_TEST = ""
+REPEAT_COUNT = 0
+RERUN_DISABLED_TESTS = False
+RUN_PARALLEL = 0
+SHOWLOCALS = False
+SLOW_TESTS_FILE = ""
+TEST_BAILOUTS = False
+TEST_DISCOVER = False
+TEST_IN_SUBPROCESS = False
+TEST_SAVE_XML = ""
+UNITTEST_ARGS : list[str] = []
+USE_PYTEST = False
+
+def is_navi3_arch():
+    if torch.cuda.is_available():
+        prop = torch.cuda.get_device_properties(0)
+        gfx_arch = prop.gcnArchName.split(":")[0]
+        if gfx_arch in NAVI3_ARCH:
+            return True
+    return False
+
+def freeze_rng_state(*args, **kwargs):
+    return torch.testing._utils.freeze_rng_state(*args, **kwargs)
+
+
+# Class to keep track of test flags configurable by environment variables.
+# Flags set here are intended to be read-only and should not be modified after
+# definition.
+# TODO: Expand this class to handle arbitrary settings in addition to boolean flags?
+class TestEnvironment:
+    # Set of env vars to set for the repro command that is output on test failure.
+    # Specifically, this includes env vars that are set to non-default values and
+    # are not implied. Maps from env var name -> value (int)
+    repro_env_vars: dict = {}
+
+    # Defines a flag usable throughout the test suite, determining its value by querying
+    # the specified environment variable.
+    #
+    # Args:
+    #     name (str): The name of the flag. A global variable with this name will be set
+    #         for convenient access throughout the test suite.
+    #     env_var (str): The name of the primary environment variable from which to
+    #         determine the value of this flag. If this is None or the environment variable
+    #         is unset, the default value will be used unless otherwise implied (see
+    #         implied_by_fn). Default: None
+    #     default (bool): The default value to use for the flag if unset by the environment
+    #         variable and unimplied. Default: False
+    #     include_in_repro (bool): Indicates whether this flag should be included in the
+    #         repro command that is output on test failure (i.e. whether it is possibly
+    #         relevant to reproducing the test failure). Default: True
+    #     enabled_fn (Callable): Callable returning whether the flag should be enabled
+    #         given the environment variable value and the default value. Default: Lambda
+    #         requiring "0" to disable if on by default OR "1" to enable if off by default.
+    #     implied_by_fn (Callable): Thunk returning a bool to imply this flag as enabled
+    #         by something outside of its primary environment variable setting. For example,
+    #         this can be useful if the value of another environment variable implies the flag
+    #         as enabled. Default: Lambda returning False to indicate no implications.
+    @staticmethod
+    def def_flag(
+        name,
+        env_var=None,
+        default=False,
+        include_in_repro=True,
+        enabled_fn=lambda env_var_val, default: (
+            (env_var_val != "0") if default else (env_var_val == "1")),
+        implied_by_fn=lambda: False,
+    ):
+        enabled = default
+        env_var_val = None
+        if env_var is not None:
+            env_var_val = os.getenv(env_var)
+            enabled = enabled_fn(env_var_val, default)
+        implied = implied_by_fn()
+        enabled = enabled or implied
+        if include_in_repro and (env_var is not None) and (enabled != default) and not implied:
+            TestEnvironment.repro_env_vars[env_var] = env_var_val
+
+        # export flag globally for convenience
+        assert name not in globals(), f"duplicate definition of flag '{name}'"
+        globals()[name] = enabled
+        return enabled
+
+    # Defines a setting usable throughout the test suite, determining its value by querying
+    # the specified environment variable. This differs from a flag in that it's not restricted
+    # to a boolean value.
+    #
+    # Args:
+    #     name (str): The name of the setting. A global variable with this name will be set
+    #         for convenient access throughout the test suite.
+    #     env_var (str): The name of the primary environment variable from which to
+    #         determine the value of this setting. If this is None or the environment variable
+    #         is unset, the default value will be used. Default: None
+    #     default (Any): The default value to use for the setting if unset by the environment
+    #         variable. Default: None
+    #     include_in_repro (bool): Indicates whether this setting should be included in the
+    #         repro command that is output on test failure (i.e. whether it is possibly
+    #         relevant to reproducing the test failure). Default: True
+    #     parse_fn (Callable): Callable parsing the env var string. Default value just uses
+    #         the string itself.
+    @staticmethod
+    def def_setting(
+        name,
+        env_var=None,
+        default=None,
+        include_in_repro=True,
+        parse_fn=lambda maybe_val_str: maybe_val_str,
+    ):
+        value = default if env_var is None else os.getenv(env_var)
+        value = parse_fn(value)
+        if include_in_repro and (value != default):
+            TestEnvironment.repro_env_vars[env_var] = value
+
+        # export setting globally for convenience
+        assert name not in globals(), f"duplicate definition of setting '{name}'"
+        globals()[name] = value
+        return value
+
+    # Returns a string prefix usable to set environment variables for any test
+    # settings that should be explicitly set to match this instantiation of the
+    # test suite.
+    # Example: "PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_ROCM=1"
+    @staticmethod
+    def repro_env_var_prefix() -> str:
+        return " ".join([f"{env_var}={value}"
+                         for env_var, value in TestEnvironment.repro_env_vars.items()])
+
+
+log = logging.getLogger(__name__)
+torch.backends.disable_global_flags()
+
+FILE_SCHEMA = "file://"
+if sys.platform == 'win32':
+    FILE_SCHEMA = "file:///"
+
+# NB: This flag differs semantically from others in that setting the env var to any
+# non-empty value will cause it to be true:
+#   CI=1, CI="true", CI=0, etc. all set the flag to be true.
+#   CI= and an unset CI set the flag to be false.
+# GitHub sets the value to CI="true" to enable it.
+IS_CI: bool = TestEnvironment.def_flag(
+    "IS_CI",
+    env_var="CI",
+    include_in_repro=False,
+    enabled_fn=lambda env_var_value, _: bool(env_var_value),
+)
+IS_SANDCASTLE: bool = TestEnvironment.def_flag(
+    "IS_SANDCASTLE",
+    env_var="SANDCASTLE",
+    implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle",
+    include_in_repro=False,
+)
+IN_RE_WORKER: bool = os.environ.get("INSIDE_RE_WORKER") is not None
+
+_is_fbcode_default = (
+    hasattr(torch._utils_internal, "IS_FBSOURCE") and
+    torch._utils_internal.IS_FBSOURCE
+)
+
+IS_FBCODE: bool = TestEnvironment.def_flag(
+    "IS_FBCODE",
+    env_var="PYTORCH_TEST_FBCODE",
+    default=_is_fbcode_default,
+    include_in_repro=False,
+)
+IS_REMOTE_GPU: bool = TestEnvironment.def_flag(
+    "IS_REMOTE_GPU",
+    env_var="PYTORCH_TEST_REMOTE_GPU",
+    include_in_repro=False,
+)
+
+DISABLE_RUNNING_SCRIPT_CHK: bool = TestEnvironment.def_flag(
+    "DISABLE_RUNNING_SCRIPT_CHK",
+    env_var="PYTORCH_DISABLE_RUNNING_SCRIPT_CHK",
+    include_in_repro=False,
+)
+# NB: enabled by default unless in an fbcode context.
+PRINT_REPRO_ON_FAILURE: bool = TestEnvironment.def_flag(
+    "PRINT_REPRO_ON_FAILURE",
+    env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
+    default=(not IS_FBCODE),
+    include_in_repro=False,
+)
+
+# possibly restrict OpInfo tests to a single sample input
+OPINFO_SAMPLE_INPUT_INDEX: Optional[int] = TestEnvironment.def_setting(
+    "OPINFO_SAMPLE_INPUT_INDEX",
+    env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX",
+    default=None,
+    # Don't include the env var value in the repro command because the info will
+    # be queried from the tracked sample input instead
+    include_in_repro=False,
+    parse_fn=lambda val: None if val is None else int(val),
+)
+
+DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
+DEFAULT_SLOW_TESTS_FILE = 'slow_tests.json'
+
+disabled_tests_dict = {}
+slow_tests_dict = {}
+
+def maybe_load_json(filename):
+    if os.path.isfile(filename):
+        with open(filename) as fp:
+            return json.load(fp)
+    log.warning("Attempted to load json file '%s' but it does not exist.", filename)
+    return {}
+
+# set them here in case the tests are running in a subprocess that doesn't call run_tests
+if os.getenv("SLOW_TESTS_FILE", ""):
+    slow_tests_dict = maybe_load_json(os.getenv("SLOW_TESTS_FILE", ""))
+if os.getenv("DISABLED_TESTS_FILE", ""):
+    disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))
+
+NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', 'mtia', torch._C._get_privateuse1_backend_name())
+
+# used for managing devices testing for torch profiler UTs
+# for now cpu, cuda and xpu are added for testing torch profiler UTs
+DEVICE_LIST_SUPPORT_PROFILING_TEST = ('cpu', 'cuda', 'xpu')
+ALLOW_XPU_PROFILING_TEST = True
+
+check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra', 'thor']
+IS_JETSON = any(name in platform.platform() for name in check_names)
+
+def gcIfJetson(fn):
+    # Irregular Jetson host/device memory setup requires cleanup to avoid tests being killed
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if IS_JETSON:
+            gc.collect()
+            torch.cuda.empty_cache()
+        fn(*args, **kwargs)
+    return wrapper
+
+# Tries to extract the current test function by crawling the stack.
+# If unsuccessful, return None.
+def extract_test_fn() -> Optional[Callable]:
+    try:
+        stack = inspect.stack()
+        for frame_info in stack:
+            frame = frame_info.frame
+            if "self" not in frame.f_locals:
+                continue
+            self_val = frame.f_locals["self"]
+            if isinstance(self_val, unittest.TestCase):
+                test_id = self_val.id()
+                *_, cls_name, test_name = test_id.rsplit('.', 2)
+                if cls_name == type(self_val).__name__ and test_name.startswith("test"):
+                    test_fn = getattr(self_val, test_name).__func__
+                    return test_fn
+    except Exception:
+        pass
+    return None
+
+# Contains tracked input data useful for debugging purposes
+@dataclass
+class TrackedInput:
+    index: int
+    val: Any
+    type_desc: str
+
+# Attempt to pull out tracked input information from the test function.
+# A TrackedInputIter is used to insert this information.
+def get_tracked_input() -> Optional[TrackedInput]:
+    test_fn = extract_test_fn()
+    if test_fn is None:
+        return None
+    return getattr(test_fn, "tracked_input", None)
+
+def clear_tracked_input() -> None:
+    test_fn = extract_test_fn()
+    if test_fn is None:
+        return
+    if not hasattr(test_fn, "tracked_input"):
+        return
+    test_fn.tracked_input = None  # type: ignore[attr-defined]
+
+# Wraps an iterator and tracks the most recent value the iterator produces
+# for debugging purposes. Tracked values are stored on the test function.
+class TrackedInputIter:
+    def __init__(
+        self,
+        child_iter,
+        input_type_desc,
+        item_callback=None,
+        track_callback=None,
+        set_seed=True,
+        restrict_to_index=None
+    ):
+        self.child_iter = enumerate(child_iter)
+        # Input type describes the things we're tracking (e.g. "sample input", "error input").
+        self.input_type_desc = input_type_desc
+        # NB: The two types of callbacks below exist because the thing we want to track isn't
+        # always the same as the thing we want returned from the iterator. An example of this
+        # is ErrorInput, which we want returned from the iterator, but which contains a
+        # SampleInput that we want to track.
+        # Item callback is run on each (iterated thing, index) to get the thing to return.
+        self.item_callback = item_callback
+        if self.item_callback is None:
+            self.item_callback = lambda x, i: x
+        # Track callback is run on each iterated thing to get the thing to track.
+        self.track_callback = track_callback
+        if self.track_callback is None:
+            self.track_callback = lambda x: x
+        self.test_fn = extract_test_fn()
+        # Indicates whether the random seed should be set before each call to the iterator
+        self.set_seed = set_seed
+        # Indicates that iteration should be restricted to only the provided index.
+        # If None, no restriction is done
+        self.restrict_to_index = restrict_to_index
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            if self.set_seed:
+                # use a test-name-specific hash for the seed if possible
+                seed = (
+                    int.from_bytes(hashlib.sha256(
+                        self.test_fn.__qualname__.encode("utf-8")).digest()[:4], 'little')
+                    if self.test_fn is not None else SEED
+                )
+                set_rng_seed(seed)
+
+            # allow StopIteration to bubble up
+            input_idx, input_val = next(self.child_iter)
+            if (self.restrict_to_index is None) or (input_idx == self.restrict_to_index):
+                break
+
+        self._set_tracked_input(
+            TrackedInput(
+                index=input_idx, val=self.track_callback(input_val), type_desc=self.input_type_desc
+            )
+        )
+        return self.item_callback(input_val, input_idx)
+
+    def _set_tracked_input(self, tracked_input: TrackedInput):
+        if self.test_fn is None:
+            return
+        if not hasattr(self.test_fn, "tracked_input"):
+            return
+        self.test_fn.tracked_input = tracked_input  # type: ignore[attr-defined]
+
+class _TestParametrizer:
+    """
+    Decorator class for parametrizing a test function, yielding a set of new tests spawned
+    from the original generic test, each specialized for a specific set of test inputs. For
+    example, parametrizing a test across the set of ops will result in a test function per op.
+
+    The decision of how to parametrize / what to parametrize over is intended to be implemented
+    by each derived class.
+
+    In the details, the decorator adds a 'parametrize_fn' property to the test function. This function
+    is intended to be called later by one of:
+      * Device-specific test instantiation via instantiate_device_type_tests(). Note that for this
+        case there is no need to explicitly parametrize over device type, as that is handled separately.
+      * Device-agnostic parametrized test instantiation via instantiate_parametrized_tests().
+
+    If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new
+    composite 'parametrize_fn' will be created that generates tests with the product of the parameters
+    generated by the old and new parametrize_fns. This allows for convenient composability of decorators.
+    """
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """
+        Parametrizes the given test function across whatever dimension is specified by the derived class.
+        Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
+        ops, all modules, or all ops + their associated dtypes.
+
+        Args:
+            test (fn): Test function to parametrize over
+            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+            device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
+                if the tests are not part of a device-specific set
+
+        Returns:
+            Generator object returning 4-tuples of:
+                test (fn): Parametrized test function; must support a device arg and args for any params
+                test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
+                    the base name of the test
+                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
+                decorator_fn (callable): Callable[[Dict], List] for list of decorators to apply given param_kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, fn):
+        if hasattr(fn, 'parametrize_fn'):
+            # Do composition with the product of args.
+            old_parametrize_fn = fn.parametrize_fn
+            new_parametrize_fn = self._parametrize_test
+            fn.parametrize_fn = compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn)
+        else:
+            fn.parametrize_fn = self._parametrize_test
+        return fn
+
+
+def compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn):
+    """
+    Returns a parametrize_fn that parametrizes over the product of the parameters handled
+    by the given parametrize_fns. Each given parametrize_fn should each have the signature
+    f(test, generic_cls, device_cls).
+
+    The test names will be a combination of the names produced by the parametrize_fns in
+    "_" order. This order is done to match intuition for constructed names
+    when composing multiple decorators; the names will be built in top to bottom order when stacking
+    parametrization decorators.
+
+    Args:
+        old_parametrize_fn (callable) - First parametrize_fn to compose.
+        new_parametrize_fn (callable) - Second parametrize_fn to compose.
+    """
+
+    def composite_fn(test, generic_cls, device_cls,
+                     old_parametrize_fn=old_parametrize_fn,
+                     new_parametrize_fn=new_parametrize_fn):
+        old_tests = list(old_parametrize_fn(test, generic_cls, device_cls))
+        for (old_test, old_test_name, old_param_kwargs, old_dec_fn) in old_tests:
+            for (new_test, new_test_name, new_param_kwargs, new_dec_fn) in \
+                    new_parametrize_fn(old_test, generic_cls, device_cls):
+                redundant_params = set(old_param_kwargs.keys()).intersection(new_param_kwargs.keys())
+                if redundant_params:
+                    raise RuntimeError('Parametrization over the same parameter by multiple parametrization '
+                                       f'decorators is not supported. For test "{test.__name__}", the following parameters '
+                                       f'are handled multiple times: {redundant_params}')
+                full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
+                merged_test_name = '{}{}{}'.format(new_test_name,
+                                                   '_' if old_test_name != '' and new_test_name != '' else '',
+                                                   old_test_name)
+
+                def merged_decorator_fn(param_kwargs, old_dec_fn=old_dec_fn, new_dec_fn=new_dec_fn):
+                    return list(old_dec_fn(param_kwargs)) + list(new_dec_fn(param_kwargs))
+
+                yield (new_test, merged_test_name, full_param_kwargs, merged_decorator_fn)
+
+    return composite_fn
+
+
+def instantiate_parametrized_tests(generic_cls):
+    """
+    Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a
+    decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by
+    parametrized tests with specialized names. This should be used instead of
+    instantiate_device_type_tests() if the test class contains device-agnostic tests.
+
+    You can also use it as a class decorator. E.g.
+
+    ```
+    @instantiate_parametrized_tests
+    class TestFoo(TestCase):
+        ...
+    ```
+
+    Args:
+        generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+    """
+    for attr_name in tuple(dir(generic_cls)):
+        class_attr = getattr(generic_cls, attr_name)
+        if not hasattr(class_attr, 'parametrize_fn'):
+            continue
+
+        # Remove the generic test from the test class.
+        delattr(generic_cls, attr_name)
+
+        # Add parametrized tests to the test class.
+        def instantiate_test_helper(cls, name, test, param_kwargs):
+            @wraps(test)
+            def instantiated_test(self, param_kwargs=param_kwargs):
+                test(self, **param_kwargs)
+
+            assert not hasattr(generic_cls, name), f"Redefinition of test {name}"
+            setattr(generic_cls, name, instantiated_test)
+
+        for (test, test_suffix, param_kwargs, decorator_fn) in class_attr.parametrize_fn(
+                class_attr, generic_cls=generic_cls, device_cls=None):
+            full_name = f'{test.__name__}_{test_suffix}'
+
+            # Apply decorators based on full param kwargs.
+            for decorator in decorator_fn(param_kwargs):
+                test = decorator(test)
+
+            instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs)
+    return generic_cls
+
+
+class subtest:
+    """
+    Explicit subtest case for use with test parametrization.
+    Allows for explicit naming of individual subtest cases as well as applying
+    decorators to the parametrized test.
+
+    Args:
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name (str): Optional name to use for the test.
+        decorators (iterable): Iterable of decorators to apply to the generated test.
+    """
+    __slots__ = ['arg_values', 'name', 'decorators']
+
+    def __init__(self, arg_values, name=None, decorators=None):
+        self.arg_values = arg_values
+        self.name = name
+        self.decorators = decorators if decorators else []
+
+
+class parametrize(_TestParametrizer):
+    """
+    Decorator for applying generic test parametrizations.
+
+    The interface for this decorator is modeled after `@pytest.mark.parametrize`.
+    Basic usage between this decorator and pytest's is identical. The first argument
+    should be a string containing comma-separated names of parameters for the test, and
+    the second argument should be an iterable returning values or tuples of values for
+    the case of multiple parameters.
+
+    Beyond this basic usage, the decorator provides some additional functionality that
+    pytest does not.
+
+    1. Parametrized tests end up as generated test functions on unittest test classes.
+    Since this differs from how pytest works, this decorator takes on the additional
+    responsibility of naming these test functions. The default test names consists of
+    the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"),
+    but custom names can be defined using `name_fn` or the `subtest` structure (see below).
+
+    2. The decorator specially handles parameter values of type `subtest`, which allows for
+    more fine-grained control over both test naming and test execution. In particular, it can
+    be used to tag subtests with explicit test names or apply arbitrary decorators (see examples
+    below).
+
+    Examples::
+
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        def test_bar(self, x, y):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
+                     name_fn=lambda x, y: '{}_{}'.format(x, y))
+        def test_bar_custom_names(self, x, y):
+            ...
+
+        @parametrize("x, y", [subtest((1, 2), name='double'),
+                              subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
+                              subtest((1, 4), name='quadruple')])
+        def test_baz(self, x, y):
+            ...
+
+    To actually instantiate the parametrized tests, one of instantiate_parametrized_tests() or
+    instantiate_device_type_tests() should be called. The former is intended for test classes
+    that contain device-agnostic tests, while the latter should be used for test classes that
+    contain device-specific tests. Both support arbitrary parametrizations using the decorator.
+
+    Args:
+        arg_str (str): String of arg names separate by commas (e.g. "x,y").
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name_fn (Callable): Optional function that takes in parameters and returns subtest name.
+    """
+    def __init__(self, arg_str, arg_values, name_fn=None):
+        self.arg_names: list[str] = [s.strip() for s in arg_str.split(',') if s != '']
+        self.arg_values = arg_values
+        self.name_fn = name_fn
+
+    def _formatted_str_repr(self, idx, name, value):
+        """ Returns a string representation for the given arg that is suitable for use in test function names. """
+        if isinstance(value, torch.dtype):
+            return dtype_name(value)
+        elif isinstance(value, torch.device):
+            return str(value)
+        # Can't use isinstance as it would cause a circular import
+        elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}:
+            return value.formatted_name
+        elif isinstance(value, (int, float, str)):
+            return f"{name}_{str(value).replace('.', '_')}"
+        else:
+            return f"{name}{idx}"
+
+    def _default_subtest_name(self, idx, values):
+        return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values, strict=True)])
+
+    def _get_subtest_name(self, idx, values, explicit_name=None):
+        if explicit_name:
+            subtest_name = explicit_name
+        elif self.name_fn:
+            subtest_name = self.name_fn(*values)
+        else:
+            subtest_name = self._default_subtest_name(idx, values)
+        return subtest_name
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if len(self.arg_names) == 0:
+            # No additional parameters needed for the test.
+            test_name = ''
+            yield (test, test_name, {}, lambda _: [])
+        else:
+            # Each "values" item is expected to be either:
+            # * A tuple of values with one for each arg. For a single arg, a single item is expected.
+            # * A subtest instance with arg_values matching the previous.
+            values = check_exhausted_iterator = object()
+            for idx, values in enumerate(self.arg_values):
+                maybe_name = None
+
+                decorators: list[Any] = []
+                if isinstance(values, subtest):
+                    sub = values
+                    values = sub.arg_values
+                    maybe_name = sub.name
+
+                    @wraps(test)
+                    def test_wrapper(*args, **kwargs):
+                        return test(*args, **kwargs)
+
+                    decorators = sub.decorators
+                    gen_test = test_wrapper
+                else:
+                    gen_test = test
+
+                values = list(values) if len(self.arg_names) > 1 else [values]  # type: ignore[call-overload]
+                if len(values) != len(self.arg_names):
+                    raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
+                                       f'values and {len(self.arg_names)} names for test "{test.__name__}"')
+
+                param_kwargs = dict(zip(self.arg_names, values, strict=True))
+
+                test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name)
+
+                def decorator_fn(_, decorators=decorators):
+                    return decorators
+
+                yield (gen_test, test_name, param_kwargs, decorator_fn)
+
+            if values is check_exhausted_iterator:
+                raise ValueError(f'{test}: An empty arg_values was passed to @parametrize. '
+                                 'Note that this may result from reuse of a generator.')
+
+
+class reparametrize(_TestParametrizer):
+    """
+    Decorator for adjusting the way an existing parametrizer operates. This class runs
+    the given adapter_fn on each parametrization produced by the given parametrizer,
+    allowing for on-the-fly parametrization more flexible than the default,
+    product-based composition that occurs when stacking parametrization decorators.
+
+    If the adapter_fn returns None for a given test parametrization, that parametrization
+    will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of
+    modified parametrizations, with tweaked test names and parameter kwargs.
+
+    Examples::
+
+        def include_is_even_arg(test_name, param_kwargs):
+            x = param_kwargs["x"]
+            is_even = x % 2 == 0
+            new_param_kwargs = dict(param_kwargs)
+            new_param_kwargs["is_even"] = is_even
+            is_even_suffix = "_even" if is_even else "_odd"
+            new_test_name = f"{test_name}{is_even_suffix}"
+            yield (new_test_name, new_param_kwargs)
+
+        ...
+
+        @reparametrize(parametrize("x", range(5)), include_is_even_arg)
+        def test_foo(self, x, is_even):
+            ...
+
+        def exclude_odds(test_name, param_kwargs):
+            x = param_kwargs["x"]
+            is_even = x % 2 == 0
+            yield None if not is_even else (test_name, param_kwargs)
+
+        ...
+
+        @reparametrize(parametrize("x", range(5)), exclude_odds)
+        def test_bar(self, x):
+            ...
+
+    """
+    def __init__(self, parametrizer, adapter_fn):
+        self.parametrizer = parametrizer
+        self.adapter_fn = adapter_fn
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        for (gen_test, test_name, param_kwargs, decorator_fn) in \
+                self.parametrizer._parametrize_test(test, generic_cls, device_cls):
+            adapted = self.adapter_fn(test_name, param_kwargs)
+            if adapted is not None:
+                for adapted_item in adapted:
+                    if adapted_item is not None:
+                        new_test_name, new_param_kwargs = adapted_item
+                        yield (gen_test, new_test_name, new_param_kwargs, decorator_fn)
+
+
+class decorateIf(_TestParametrizer):
+    """
+    Decorator for applying parameter-specific conditional decoration.
+    Composes with other test parametrizers (e.g. @modules, @ops, @parametrize, etc.).
+
+    Examples::
+
+        @decorateIf(unittest.skip, lambda params: params["x"] == 2)
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        @decorateIf(
+            unittest.expectedFailure,
+            lambda params: params["x"] == 3 and params["y"] == "baz"
+        )
+        def test_bar(self, x, y):
+            ...
+
+        @decorateIf(
+            unittest.expectedFailure,
+            lambda params: params["op"].name == "add" and params["dtype"] == torch.float16
+        )
+        @ops(op_db)
+        def test_op_foo(self, device, dtype, op):
+            ...
+
+        @decorateIf(
+            unittest.skip,
+            lambda params: params["module_info"].module_cls is torch.nn.Linear and \
+                params["device"] == "cpu"
+        )
+        @modules(module_db)
+        def test_module_foo(self, device, dtype, module_info):
+            ...
+
+    Args:
+        decorator: Test decorator to apply if the predicate is satisfied.
+        predicate_fn (Callable): Function taking in a dict of params and returning a boolean
+            indicating whether the decorator should be applied or not.
+    """
+    def __init__(self, decorator, predicate_fn):
+        self.decorator = decorator
+        self.predicate_fn = predicate_fn
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+
+        # Leave test as-is and return the appropriate decorator_fn.
+        def decorator_fn(params, decorator=self.decorator, predicate_fn=self.predicate_fn):
+            if predicate_fn(params):
+                return [decorator]
+            else:
+                return []
+
+        @wraps(test)
+        def test_wrapper(*args, **kwargs):
+            return test(*args, **kwargs)
+
+        test_name = ''
+        yield (test_wrapper, test_name, {}, decorator_fn)
+
+
+def cppProfilingFlagsToProfilingMode():
+    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    torch._C._jit_set_profiling_executor(old_prof_exec_state)
+    torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+    if old_prof_exec_state:
+        if old_prof_mode_state:
+            return ProfilingMode.PROFILING
+        else:
+            return ProfilingMode.SIMPLE
+    else:
+        return ProfilingMode.LEGACY
+
+@contextmanager
+def enable_profiling_mode_for_profiling_tests():
+    old_prof_exec_state = False
+    old_prof_mode_state = False
+    assert GRAPH_EXECUTOR
+    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+        old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+        old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    try:
+        yield
+    finally:
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            torch._C._jit_set_profiling_executor(old_prof_exec_state)
+            torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+@contextmanager
+def enable_profiling_mode():
+    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_profiling_executor(old_prof_exec_state)
+        torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+@contextmanager
+def num_profiled_runs(num_runs):
+    old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_num_profiled_runs(old_num_runs)
+
+func_call = torch._C.ScriptFunction.__call__
+meth_call = torch._C.ScriptMethod.__call__
+
+def prof_callable(callable, *args, **kwargs):
+    if 'profile_and_replay' in kwargs:
+        del kwargs['profile_and_replay']
+        assert GRAPH_EXECUTOR
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            with enable_profiling_mode_for_profiling_tests():
+                callable(*args, **kwargs)
+                return callable(*args, **kwargs)
+
+    return callable(*args, **kwargs)
+
+def raise_on_run_directly(file_to_call):
+    raise RuntimeError("This test file is not meant to be run directly, "
+                       f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
+                       "instead.")
+
+def prof_func_call(*args, **kwargs):
+    return prof_callable(func_call, *args, **kwargs)
+
+def prof_meth_call(*args, **kwargs):
+    return prof_callable(meth_call, *args, **kwargs)
+
+torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[method-assign]
+torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[method-assign]
+
+def _get_test_report_path():
+    # allow users to override the test file location. We need this
+    # because the distributed tests run the same test file multiple
+    # times with different configurations.
+    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
+    test_source = override if override is not None else 'python-unittest'
+    return os.path.join('test-reports', test_source)
+
+def parse_cmd_line_args():
+    global DISABLED_TESTS_FILE
+    global GRAPH_EXECUTOR
+    global LOG_SUFFIX
+    global PYTEST_SINGLE_TEST
+    global REPEAT_COUNT
+    global RERUN_DISABLED_TESTS
+    global RUN_PARALLEL
+    global SHOWLOCALS
+    global SLOW_TESTS_FILE
+    global TEST_BAILOUTS
+    global TEST_DISCOVER
+    global TEST_IN_SUBPROCESS
+    global TEST_SAVE_XML
+    global UNITTEST_ARGS
+    global USE_PYTEST
+
+    is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
+    parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
+    parser.add_argument('--subprocess', action='store_true',
+                        help='whether to run each test in a subprocess')
+    parser.add_argument('--accept', action='store_true')
+    parser.add_argument('--jit-executor', '--jit_executor', type=str)
+    parser.add_argument('--repeat', type=int, default=1)
+    parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
+    parser.add_argument('--use-pytest', action='store_true')
+    parser.add_argument('--save-xml', nargs='?', type=str,
+                        const=_get_test_report_path(),
+                        default=_get_test_report_path() if IS_CI else None)
+    parser.add_argument('--discover-tests', action='store_true')
+    parser.add_argument('--log-suffix', type=str, default="")
+    parser.add_argument('--run-parallel', type=int, default=1)
+    parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
+    parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
+    parser.add_argument('--rerun-disabled-tests', action='store_true')
+    parser.add_argument('--pytest-single-test', type=str, nargs=1)
+    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
+
+# Only run when -h or --help flag is active to display both unittest and parser help messages.
+    def run_unittest_help(argv):
+        unittest.main(argv=argv)
+
+    if '-h' in sys.argv or '--help' in sys.argv:
+        help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
+        help_thread.start()
+        help_thread.join()
+
+    args, remaining = parser.parse_known_args()
+    if args.jit_executor == 'legacy':
+        GRAPH_EXECUTOR = ProfilingMode.LEGACY
+    elif args.jit_executor == 'profiling':
+        GRAPH_EXECUTOR = ProfilingMode.PROFILING
+    elif args.jit_executor == 'simple':
+        GRAPH_EXECUTOR = ProfilingMode.SIMPLE
+    else:
+        # infer flags based on the default settings
+        GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
+
+    RERUN_DISABLED_TESTS = args.rerun_disabled_tests
+
+    SLOW_TESTS_FILE = args.import_slow_tests
+    DISABLED_TESTS_FILE = args.import_disabled_tests
+    LOG_SUFFIX = args.log_suffix
+    RUN_PARALLEL = args.run_parallel
+    TEST_BAILOUTS = args.test_bailouts
+    USE_PYTEST = args.use_pytest
+    PYTEST_SINGLE_TEST = args.pytest_single_test
+    TEST_DISCOVER = args.discover_tests
+    TEST_IN_SUBPROCESS = args.subprocess
+    TEST_SAVE_XML = args.save_xml
+    REPEAT_COUNT = args.repeat
+    SHOWLOCALS = args.showlocals
+    if not getattr(expecttest, "ACCEPT", False):
+        expecttest.ACCEPT = args.accept
+    UNITTEST_ARGS = [sys.argv[0]] + remaining
+
+    set_rng_seed()
+
+
+def wait_for_process(p, timeout=None):
+    try:
+        return p.wait(timeout=timeout)
+    except KeyboardInterrupt:
+        # Give `p` a chance to handle KeyboardInterrupt. Without this,
+        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
+        exit_status = p.wait(timeout=5)
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+            raise
+    except subprocess.TimeoutExpired:
+        # send SIGINT to give pytest a chance to make xml
+        p.send_signal(signal.SIGINT)
+        exit_status = None
+        try:
+            exit_status = p.wait(timeout=5)
+        # try to handle the case where p.wait(timeout=5) times out as well as
+        # otherwise the wait() call in the finally block can potentially hang
+        except subprocess.TimeoutExpired:
+            pass
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+        raise
+    except:  # noqa: B001,E722, copied from python core library
+        p.kill()
+        raise
+    finally:
+        # Always call p.wait() to ensure exit
+        p.wait()
+
+def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    # The following cool snippet is copied from Py3 core library subprocess.call
+    # only the with
+    #   1. `except KeyboardInterrupt` block added for SIGINT handling.
+    #   2. In Py2, subprocess.Popen doesn't return a context manager, so we do
+    #      `p.wait()` in a `final` block for the code to be portable.
+    #
+    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+    assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
+    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
+    return wait_for_process(p, timeout=timeout)
+
+
+def retry_shell(
+    command,
+    cwd=None,
+    env=None,
+    stdout=None,
+    stderr=None,
+    timeout=None,
+    retries=1,
+    was_rerun=False,
+) -> tuple[int, bool]:
+    # Returns exicode + whether it was rerun
+    assert (
+        retries >= 0
+    ), f"Expecting non negative number for number of retries, got {retries}"
+    try:
+        exit_code = shell(
+            command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout
+        )
+        if exit_code == 0 or retries == 0:
+            return exit_code, was_rerun
+        print(
+            f"Got exit code {exit_code}, retrying (retries left={retries})",
+            file=stdout,
+            flush=True,
+        )
+    except subprocess.TimeoutExpired:
+        if retries == 0:
+            print(
+                f"Command took >{timeout // 60}min, returning 124",
+                file=stdout,
+                flush=True,
+            )
+            return 124, was_rerun
+        print(
+            f"Command took >{timeout // 60}min, retrying (retries left={retries})",
+            file=stdout,
+            flush=True,
+        )
+    return retry_shell(
+        command,
+        cwd=cwd,
+        env=env,
+        stdout=stdout,
+        stderr=stderr,
+        timeout=timeout,
+        retries=retries - 1,
+        was_rerun=True,
+    )
+
+
+def discover_test_cases_recursively(suite_or_case):
+    if isinstance(suite_or_case, unittest.TestCase):
+        return [suite_or_case]
+    rc = []
+    for element in suite_or_case:
+        print(element)
+        rc.extend(discover_test_cases_recursively(element))
+    return rc
+
+def get_test_names(test_cases):
+    return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
+
+def _print_test_names():
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    test_cases = discover_test_cases_recursively(suite)
+    for name in get_test_names(test_cases):
+        print(name)
+
+def chunk_list(lst, nchunks):
+    return [lst[i::nchunks] for i in range(nchunks)]
+
+# sanitize filename e.g., distributed/pipeline/sync/skip/test_api.py -> distributed.pipeline.sync.skip.test_api
+def sanitize_test_filename(filename):
+    strip_py = re.sub(r'.py$', '', filename)
+    return re.sub('/', r'.', strip_py)
+
+def lint_test_case_extension(suite):
+    succeed = True
+    for test_case_or_suite in suite:
+        test_case = test_case_or_suite
+        if isinstance(test_case_or_suite, unittest.TestSuite):
+            first_test = test_case_or_suite._tests[0] if len(test_case_or_suite._tests) > 0 else None
+            if first_test is not None and isinstance(first_test, unittest.TestSuite):
+                return succeed and lint_test_case_extension(test_case_or_suite)
+            test_case = first_test
+
+        if test_case is not None:
+            if not isinstance(test_case, TestCase):
+                test_class = test_case.id().split('.', 1)[1].split('.')[0]
+                err = "This test class should extend from torch.testing._internal.common_utils.TestCase but it doesn't."
+                print(f"{test_class} - failed. {err}")
+                succeed = False
+    return succeed
+
+
+def get_report_path(argv=None, pytest=False):
+    if argv is None:
+        argv = UNITTEST_ARGS
+    test_filename = sanitize_test_filename(argv[0])
+    test_report_path = TEST_SAVE_XML + LOG_SUFFIX
+    test_report_path = os.path.join(test_report_path, test_filename)
+    if pytest:
+        test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
+        os.makedirs(test_report_path, exist_ok=True)
+        test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
+        return test_report_path
+    os.makedirs(test_report_path, exist_ok=True)
+    return test_report_path
+
+
+def sanitize_pytest_xml(xml_file: str):
+    # pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml
+    # consider somehow modifying the XML logger in conftest to do this instead
+    import xml.etree.ElementTree as ET
+    tree = ET.parse(xml_file)
+    for testcase in tree.iter('testcase'):
+        full_classname = testcase.attrib.get("classname")
+        if full_classname is None:
+            continue
+        # The test prefix is optional
+        regex_result = re.search(r"^(test\.)?(?P.*)\.(?P[^\.]*)$", full_classname)
+        if regex_result is None:
+            continue
+        classname = regex_result.group("classname")
+        file = regex_result.group("file").replace(".", "/")
+        testcase.set("classname", classname)
+        testcase.set("file", f"{file}.py")
+    tree.write(xml_file)
+
+
+def get_pytest_test_cases(argv: list[str]) -> list[str]:
+    class TestCollectorPlugin:
+        def __init__(self) -> None:
+            self.tests: list[Any] = []
+
+        def pytest_collection_finish(self, session):
+            for item in session.items:
+                self.tests.append(session.config.cwd_relative_nodeid(item.nodeid))
+
+    test_collector_plugin = TestCollectorPlugin()
+    import pytest
+    pytest.main(
+        [arg for arg in argv if arg != '-vv'] + ['--collect-only', '-qq', '--use-main-module'],
+        plugins=[test_collector_plugin]
+    )
+    return test_collector_plugin.tests
+
+
+def run_tests(argv=None):
+    parse_cmd_line_args()
+    if argv is None:
+        argv = UNITTEST_ARGS
+
+    # import test files.
+    if SLOW_TESTS_FILE:
+        if os.path.exists(SLOW_TESTS_FILE):
+            with open(SLOW_TESTS_FILE) as fp:
+                global slow_tests_dict
+                slow_tests_dict = json.load(fp)
+                # use env vars so pytest-xdist subprocesses can still access them
+                os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE
+        else:
+            warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}', stacklevel=2)
+    if DISABLED_TESTS_FILE:
+        if os.path.exists(DISABLED_TESTS_FILE):
+            with open(DISABLED_TESTS_FILE) as fp:
+                global disabled_tests_dict
+                disabled_tests_dict = json.load(fp)
+                os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE
+        else:
+            warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}', stacklevel=2)
+    # Determine the test launch mechanism
+    if TEST_DISCOVER:
+        _print_test_names()
+        return
+
+    # Before running the tests, lint to check that every test class extends from TestCase
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    if not lint_test_case_extension(suite):
+        sys.exit(1)
+
+    if SHOWLOCALS:
+        argv = [
+            argv[0],
+            *(["--showlocals", "--tb=long", "--color=yes"] if USE_PYTEST else ["--locals"]),
+            *argv[1:],
+        ]
+
+    if TEST_IN_SUBPROCESS:
+        other_args = []
+        if DISABLED_TESTS_FILE:
+            other_args.append("--import-disabled-tests")
+        if SLOW_TESTS_FILE:
+            other_args.append("--import-slow-tests")
+        if USE_PYTEST:
+            other_args.append("--use-pytest")
+        if RERUN_DISABLED_TESTS:
+            other_args.append("--rerun-disabled-tests")
+        if TEST_SAVE_XML:
+            other_args += ['--save-xml', TEST_SAVE_XML]
+
+        test_cases = (
+            get_pytest_test_cases(argv) if USE_PYTEST else
+            [case.id().split('.', 1)[1] for case in discover_test_cases_recursively(suite)]
+        )
+
+        failed_tests = []
+
+        for test_case_full_name in test_cases:
+
+            cmd = (
+                [sys.executable] + [argv[0]] + other_args + argv[1:] +
+                (["--pytest-single-test"] if USE_PYTEST else []) +
+                [test_case_full_name]
+            )
+            string_cmd = " ".join(cmd)
+
+            timeout = None if RERUN_DISABLED_TESTS else 15 * 60
+
+            exitcode, _ = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
+
+            if exitcode != 0:
+                # This is sort of hacky, but add on relevant env variables for distributed tests.
+                if 'TestDistBackendWithSpawn' in test_case_full_name:
+                    backend = os.environ.get("BACKEND", "")
+                    world_size = os.environ.get("WORLD_SIZE", "")
+                    env_prefix = f"BACKEND={backend} WORLD_SIZE={world_size}"
+                    string_cmd = env_prefix + " " + string_cmd
+                # Log the command to reproduce the failure.
+                print(f"Test exited with non-zero exitcode {exitcode}. Command to reproduce: {string_cmd}")
+                failed_tests.append(test_case_full_name)
+
+            assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
+                len(failed_tests), '\n\t'.join(failed_tests))
+
+    elif RUN_PARALLEL > 1:
+        test_cases = discover_test_cases_recursively(suite)
+        test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
+        processes = []
+        for i in range(RUN_PARALLEL):
+            command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i]
+            processes.append(subprocess.Popen(command, universal_newlines=True))
+        failed = False
+        for p in processes:
+            failed |= wait_for_process(p) != 0
+        assert not failed, "Some test shards have failed"
+    elif USE_PYTEST:
+        pytest_args = argv + ["--use-main-module"]
+        test_report_path = ""
+        if TEST_SAVE_XML:
+            test_report_path = get_report_path(pytest=True)
+            print(f'Test results will be stored in {test_report_path}')
+            pytest_args.append(f'--junit-xml-reruns={test_report_path}')
+        if PYTEST_SINGLE_TEST:
+            pytest_args = PYTEST_SINGLE_TEST + pytest_args[1:]
+
+        import pytest
+        os.environ["NO_COLOR"] = "1"
+        exit_code = pytest.main(args=pytest_args)
+        if TEST_SAVE_XML:
+            sanitize_pytest_xml(test_report_path)
+
+        # exitcode of 5 means no tests were found, which happens since some test configs don't
+        # run tests from certain files
+        sys.exit(0 if exit_code == 5 else exit_code)
+    elif TEST_SAVE_XML:
+        # import here so that non-CI doesn't need xmlrunner installed
+        import xmlrunner  # type: ignore[import]
+        from xmlrunner.result import _XMLTestResult  # type: ignore[import]
+
+        class XMLTestResultVerbose(_XMLTestResult):
+            """
+            Adding verbosity to test outputs:
+            by default test summary prints 'skip',
+            but we want to also print the skip reason.
+            GH issue: https://github.com/pytorch/pytorch/issues/69014
+
+            This works with unittest_xml_reporting<=3.2.0,>=2.0.0
+            (3.2.0 is latest at the moment)
+            """
+
+            def addSkip(self, test, reason):
+                super().addSkip(test, reason)
+                for c in self.callback.__closure__:
+                    if isinstance(c.cell_contents, str) and c.cell_contents == 'skip':
+                        # this message is printed in test summary;
+                        # it stands for `verbose_str` captured in the closure
+                        c.cell_contents = f"skip: {reason}"
+
+            def printErrors(self) -> None:
+                super().printErrors()
+                self.printErrorList("XPASS", self.unexpectedSuccesses)
+        test_report_path = get_report_path()
+        verbose = '--verbose' in argv or '-v' in argv
+        if verbose:
+            print(f'Test results will be stored in {test_report_path}')
+        unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
+            output=test_report_path,
+            verbosity=2 if verbose else 1,
+            resultclass=XMLTestResultVerbose))
+    elif REPEAT_COUNT > 1:
+        for _ in range(REPEAT_COUNT):
+            if not unittest.main(exit=False, argv=argv).result.wasSuccessful():
+                sys.exit(-1)
+    else:
+        unittest.main(argv=argv)
+
+IS_LINUX = sys.platform == "linux"
+IS_WINDOWS = sys.platform == "win32"
+IS_MACOS = sys.platform == "darwin"
+IS_PPC = platform.machine() == "ppc64le"
+IS_X86 = platform.machine() in ('x86_64', 'i386')
+IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
+IS_S390X = platform.machine() == "s390x"
+
+def is_avx512_vnni_supported():
+    if sys.platform != 'linux':
+        return False
+    with open("/proc/cpuinfo", encoding="ascii") as f:
+        lines = f.read()
+    return "vnni" in lines
+
+IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()
+
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryFileName(*args, **kwargs):
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        if 'delete' in kwargs:
+            if kwargs['delete'] is not False:
+                raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.")
+        else:
+            kwargs['delete'] = False
+        f = tempfile.NamedTemporaryFile(*args, **kwargs)  # noqa:SIM115
+        try:
+            f.close()
+            yield f.name
+        finally:
+            os.unlink(f.name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryFileName(*args, **kwargs):
+        with tempfile.NamedTemporaryFile(*args, **kwargs) as f:
+            yield f.name
+
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryDirectoryName(suffix=None):
+        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
+        # so we first create the directory using mkdtemp and then remove it manually
+        try:
+            dir_name = tempfile.mkdtemp(suffix=suffix)
+            yield dir_name
+        finally:
+            shutil.rmtree(dir_name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryDirectoryName(suffix=None):
+        with tempfile.TemporaryDirectory(suffix=suffix) as d:
+            yield d
+
+
+def is_privateuse1_backend_available():
+    privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
+    privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None)
+    return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()
+
+
+def make_lazy_class(cls):
+
+    def lazy_init(self, cb):
+        self._cb = cb
+        self._value = None
+
+    cls.__init__ = lazy_init
+
+    for basename in [
+        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
+        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
+        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
+    ]:
+        name = f"__{basename}__"
+
+        def inner_wrapper(name):
+            use_operator = basename not in ("bool", "int")
+
+            def wrapped(self, *args, **kwargs):
+                if self._cb is not None:
+                    self._value = self._cb()
+                    self._cb = None
+                if not use_operator:
+                    return getattr(self._value, name)(*args, **kwargs)
+                else:
+                    return getattr(operator, name)(self._value, *args, **kwargs)
+            return wrapped
+
+        setattr(cls, name, inner_wrapper(name))
+
+    return cls
+
+
+@make_lazy_class
+class LazyVal:
+    pass
+
+
+IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
+
+TEST_NUMPY = _check_module_exists('numpy')
+TEST_FAIRSEQ = _check_module_exists('fairseq')
+TEST_SCIPY = _check_module_exists('scipy')
+TEST_MKL = torch.backends.mkl.is_available()
+TEST_ACL = torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_acl_supported()
+TEST_MPS = torch.backends.mps.is_available()
+MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
+TEST_XPU = torch.xpu.is_available()
+TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available())
+TEST_CUDA = torch.cuda.is_available()
+TEST_ACCELERATOR = LazyVal(lambda: torch.accelerator.is_available())  # type: ignore[call-arg]
+TEST_MULTIACCELERATOR = LazyVal(lambda: torch.accelerator.device_count() > 1)  # type: ignore[call-arg]
+custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
+TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
+TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
+TEST_NUMBA = _check_module_exists('numba')
+TEST_TRANSFORMERS = _check_module_exists('transformers')
+TEST_DILL = _check_module_exists('dill')
+
+TEST_LIBROSA = _check_module_exists('librosa') and not IS_ARM64
+
+TEST_OPT_EINSUM = _check_module_exists('opt_einsum')
+
+TEST_Z3 = _check_module_exists('z3')
+
+def split_if_not_empty(x: str):
+    return x.split(",") if len(x) != 0 else []
+
+NOTEST_CPU = "cpu" in split_if_not_empty(os.getenv('PYTORCH_TESTING_DEVICE_EXCEPT_FOR', ''))
+
+skipIfNoDill = unittest.skipIf(not TEST_DILL, "no dill")
+
+
+NO_MULTIPROCESSING_SPAWN: bool = False
+TEST_WITH_ASAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_ASAN",
+    env_var="PYTORCH_TEST_WITH_ASAN",
+)
+TEST_WITH_DEV_DBG_ASAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_DEV_DBG_ASAN",
+    env_var="PYTORCH_TEST_WITH_DEV_DBG_ASAN",
+)
+TEST_WITH_TSAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TSAN",
+    env_var="PYTORCH_TEST_WITH_TSAN",
+)
+TEST_WITH_UBSAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_UBSAN",
+    env_var="PYTORCH_TEST_WITH_UBSAN",
+)
+TEST_WITH_ROCM: bool = TestEnvironment.def_flag(
+    "TEST_WITH_ROCM",
+    env_var="PYTORCH_TEST_WITH_ROCM",
+)
+TEST_WITH_MTIA: bool = TestEnvironment.def_flag(
+    "TEST_WITH_MTIA",
+    env_var="PYTORCH_TEST_WITH_MTIA",
+)
+
+# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+# See #64427
+TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1'
+# Enables tests that are slow to run (disabled by default)
+TEST_WITH_SLOW: bool = TestEnvironment.def_flag(
+    "TEST_WITH_SLOW",
+    env_var="PYTORCH_TEST_WITH_SLOW",
+)
+
+# Disables non-slow tests (these tests enabled by default)
+# This is usually used in conjunction with TEST_WITH_SLOW to
+# run *only* slow tests.  (I could have done an enum, but
+# it felt a little awkward.
+TEST_SKIP_FAST: bool = TestEnvironment.def_flag(
+    "TEST_SKIP_FAST",
+    env_var="PYTORCH_TEST_SKIP_FAST",
+)
+
+# Enables crossref tests, in addition to standard tests which
+# are being run.  crossref tests work by installing a torch
+# function mode that runs extra compute alongside the regular
+# computation that happens with the test.  After both computations
+# are done, we cross-reference them (thus the name) to check for
+# correction, before throwing out the extra compute and proceeding
+# as we had before.  By default, we don't run these tests.
+TEST_WITH_CROSSREF: bool = TestEnvironment.def_flag(
+    "TEST_WITH_CROSSREF",
+    env_var="PYTORCH_TEST_WITH_CROSSREF",
+)
+
+TEST_SKIP_CUDAGRAPH: bool = TestEnvironment.def_flag(
+    "TEST_SKIP_CUDAGRAPH",
+    env_var="PYTORCH_TEST_SKIP_CUDAGRAPH",
+)
+TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
+    torch.version.cuda or
+    (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
+)
+
+TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
+
+TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
+    torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12
+)
+
+if TEST_CUDA_PYTHON_BINDINGS:
+    def cuda_python_error_check(function_call_output):
+        """Makes calls to cuda-python's cuda runtime functions more
+        pythonic by throwing an exception if they return a status
+        which is not cudaSuccess
+        """
+        import cuda.bindings  # type: ignore[import]
+
+        error, *others = function_call_output
+        if error != cuda.bindings.runtime.cudaError_t.cudaSuccess:
+            raise ValueError(f"CUDA failure! {error}")
+        else:
+            return tuple(others)
+else:
+    cuda_python_error_check = None  # type: ignore[assignment]
+
+def allocator_option_enabled_fn(allocator_config, _, option):
+    if allocator_config is None:
+        return False
+    allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config]
+    mapping = dict([var.split(':') for var in allocator_config])
+
+    if option in mapping and mapping[option] == 'True':
+        return True
+    else:
+        return False
+
+EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag(
+    "EXPANDABLE_SEGMENTS",
+    env_var="PYTORCH_CUDA_ALLOC_CONF",
+    enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'),
+)
+
+if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
+    num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2"))
+    gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30
+    # other libraries take up about a little under 1 GB of space per process
+    torch.cuda.set_per_process_memory_fraction(round((gb_available - num_procs * .85) / gb_available / num_procs, 2))
+
+requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "Requires CUDA")
+
+def skipIfCrossRef(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_WITH_CROSSREF:
+            raise unittest.SkipTest("test doesn't currently with crossref")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+class CrossRefMode(torch.overrides.TorchFunctionMode):
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        r = func(*args, **kwargs)
+        return r
+
+# Run PyTorch tests with TorchDynamo
+TEST_WITH_TORCHINDUCTOR: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TORCHINDUCTOR",
+    env_var="PYTORCH_TEST_WITH_INDUCTOR",
+)
+# AOT_EAGER not tested in ci, useful for debugging
+TEST_WITH_AOT_EAGER: bool = TestEnvironment.def_flag(
+    "TEST_WITH_AOT_EAGER",
+    env_var="PYTORCH_TEST_WITH_AOT_EAGER",
+)
+TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TORCHDYNAMO",
+    env_var="PYTORCH_TEST_WITH_DYNAMO",
+    implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER,
+)
+TEST_WITHOUT_COMPILED_AUTOGRAD: bool = TestEnvironment.def_flag(
+    "TEST_WITHOUT_COMPILED_AUTOGRAD",
+    env_var="PYTORCH_TEST_WITHOUT_COMPILED_AUTOGRAD",
+)
+
+if TEST_WITH_TORCHDYNAMO:
+    import torch._dynamo
+    # Do not spend time on helper functions that are called with different inputs
+    torch._dynamo.config.accumulated_recompile_limit = 64
+    # Do not log compilation metrics from unit tests
+    torch._dynamo.config.log_compilation_metrics = False
+    # Silence 3.13.0 guard performance warnings
+    torch._dynamo.config.issue_3_13_0_warning = False
+    if TEST_WITH_TORCHINDUCTOR:
+        import torch._inductor.config
+        torch._inductor.config.fallback_random = True
+    else:
+        # only dynamo for now
+        torch._dynamo.config.compiled_autograd = not TEST_WITHOUT_COMPILED_AUTOGRAD
+
+
+# seems like this is only used in test/torch_np
+def xpassIfTorchDynamo_np(func):
+    # numpy 2.0+ is causing issues
+    if TEST_WITH_TORCHDYNAMO and np.__version__[0] == '2':
+        return unittest.skip("skipping numpy 2.0+ dynamo-wrapped test")(func)
+    return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)
+
+
+def xfailIfACL(func):
+    return unittest.expectedFailure(func) if TEST_ACL else func
+
+
+def xfailIfTorchDynamo(func):
+    return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func
+
+
+def xfailIfPy312Plus(func):
+    return unittest.expectedFailure(func) if sys.version_info >= (3, 12) else func
+
+
+def xfailIfLinux(func):
+    return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func
+
+
+def xfailIfWindows(func):
+    return unittest.expectedFailure(func) if IS_WINDOWS else func
+
+
+def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
+    """
+    Usage:
+    @skipIfTorchDynamo(msg)
+    def test_blah(self):
+        ...
+    """
+    assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?"
+
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if TEST_WITH_TORCHDYNAMO:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert isinstance(fn, type)
+        if TEST_WITH_TORCHDYNAMO:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+    return decorator
+
+def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
+                        condition=TEST_WITH_TORCHINDUCTOR):
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if condition:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert isinstance(fn, type)
+        if condition:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+    return decorator
+
+def runWithoutCompiledAutograd(msg="test doesn't currently work with compiled autograd"):
+    """
+    Usage:
+    @runWithoutCompiledAutograd(msg)
+    def test_blah(self):
+        ...
+    """
+    assert isinstance(msg, str)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with torch._dynamo.compiled_autograd._disable():
+                func(*args, **kwargs)
+        return wrapper
+
+    return decorator
+
+def serialTest(condition=True):
+    """
+    Decorator for running tests serially.  Requires pytest
+    """
+    # If one apply decorator directly condition will be callable
+    # And test will essentially be essentially skipped, which is undesirable
+    assert type(condition) is bool
+
+    def decorator(fn):
+        if has_pytest and condition:
+            return pytest.mark.serial(fn)
+        return fn
+    return decorator
+
+def unMarkDynamoStrictTest(cls=None):
+    def decorator(cls):
+        cls.dynamo_strict = False
+        return cls
+
+    if cls is None:
+        return decorator
+    else:
+        return decorator(cls)
+
+
+def markDynamoStrictTest(cls_or_func=None, nopython=False):
+    """
+    Marks the test as 'strict'. In strict mode, we reset before and after the
+    test, and run without suppress errors.
+
+    Args:
+    - nopython: if we should run torch._dynamo.optimize with nopython={True/False}.
+    """
+    def decorator(cls_or_func):
+        if inspect.isclass(cls_or_func):
+            cls_or_func.dynamo_strict = True
+            cls_or_func.dynamo_strict_nopython = nopython
+            return cls_or_func
+
+        fn = cls_or_func
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            torch._dynamo.reset()
+            with unittest.mock.patch("torch._dynamo.config.suppress_errors", False):
+                fn(*args, **kwargs)
+            torch._dynamo.reset()
+        return wrapper
+
+    if cls_or_func is None:
+        return decorator
+    else:
+        return decorator(cls_or_func)
+
+
+def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
+    return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
+
+def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                assert GRAPH_EXECUTOR
+                if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert isinstance(fn, type)
+        if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+
+    return decorator
+
+
+def make_dynamo_test(
+    fn: Optional[Callable[..., Any]] = None
+) -> Callable[..., Any]:
+    """
+    Decorator function to create a dynamo test case. A function annotate with
+    this decorator takes as input a unittest object.
+    """
+    from torch._dynamo.testing import CompileCounter, reset, optimize_assert
+    if fn is None:
+        return lambda fn: make_dynamo_test(fn)
+
+    def standard_test(
+        self: Any,
+        fn: Callable[..., Any],
+        kwargs,
+    ) -> None:
+        def dummy() -> None:
+            fn(self, **kwargs)
+
+        actual = CompileCounter()
+
+        dummy()
+        reset()
+        opt_fn = optimize_assert(actual)(dummy)
+        opt_fn()
+        reset()
+
+    @functools.wraps(fn)
+    def test_fn(self: Any, **kwargs) -> None:
+        return standard_test(
+            self,
+            fn=fn,
+            kwargs=kwargs,
+        )
+
+    return test_fn
+
+
+# Run PyTorch tests with translation validation on.
+TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'
+
+if TEST_WITH_TV:
+    torch.fx.experimental._config.translation_validation = True
+
+# Determine whether to enable cuda memory leak check.
+# CUDA mem leak check is expensive and thus we don't want to execute it on every
+# test case / configuration.
+# If this is True then CUDA memory leak checks are skipped. If this is false
+#   then CUDA memory leak checks are performed.
+# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135
+TEST_CUDA_MEM_LEAK_CHECK: bool = TestEnvironment.def_flag(
+    "TEST_CUDA_MEM_LEAK_CHECK",
+    env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK",
+)
+
+
+# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
+numpy_to_torch_dtype_dict = {
+    np.bool_      : torch.bool,
+    np.uint8      : torch.uint8,
+    np.uint16     : torch.uint16,
+    np.uint32     : torch.uint32,
+    np.uint64     : torch.uint64,
+    np.int8       : torch.int8,
+    np.int16      : torch.int16,
+    np.int32      : torch.int32,
+    np.int64      : torch.int64,
+    np.float16    : torch.float16,
+    np.float32    : torch.float32,
+    np.float64    : torch.float64,
+    np.complex64  : torch.complex64,
+    np.complex128 : torch.complex128
+}
+
+
+# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
+# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
+# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
+def numpy_to_torch_dtype(np_dtype):
+    try:
+        return numpy_to_torch_dtype_dict[np_dtype]
+    except KeyError:
+        return numpy_to_torch_dtype_dict[np_dtype.type]
+
+
+def has_corresponding_torch_dtype(np_dtype):
+    try:
+        numpy_to_torch_dtype(np_dtype)
+        return True
+    except KeyError:
+        return False
+
+
+if IS_WINDOWS:
+    # Size of `np.intc` is platform defined.
+    # It is returned by functions like `bitwise_not`.
+    # On Windows `int` is 32-bit
+    # https://docs.microsoft.com/en-us/cpp/cpp/data-type-ranges?view=msvc-160
+    numpy_to_torch_dtype_dict[np.intc] = torch.int
+
+# Dict of torch dtype -> NumPy dtype
+torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
+torch_to_numpy_dtype_dict.update({
+    torch.bfloat16: np.float32,
+    torch.complex32: np.complex64
+})
+
+def skipIfNNModuleInlined(
+    msg="test doesn't currently work with nn module inlining",
+    condition=torch._dynamo.config.inline_inbuilt_nn_modules,
+):
+    def decorator(fn):
+        if not isinstance(fn, type):
+
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if condition:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+
+            return wrapper
+
+        assert isinstance(fn, type)
+        if condition:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+    return decorator
+
+def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
+    def dec_fn(fn):
+        reason = f"skipIfRocm: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if TEST_WITH_ROCM:
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+def getRocmArchName(device_index: int = 0):
+    return torch.cuda.get_device_properties(device_index).gcnArchName
+
+def isRocmArchAnyOf(arch: tuple[str, ...]):
+    rocmArch = getRocmArchName()
+    return any(x in rocmArch for x in arch)
+
+def skipIfRocmArch(arch: tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test skipped on {arch}"
+                raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def runOnRocm(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_WITH_ROCM:
+            fn(*args, **kwargs)
+        else:
+            raise unittest.SkipTest("test currently only works on the ROCm stack")
+    return wrapper
+
+def runOnRocmArch(arch: tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test only runs on {arch}"
+                raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def xfailIfS390X(func):
+    return unittest.expectedFailure(func) if IS_S390X else func
+
+def xfailIf(condition):
+    def wrapper(func):
+        if condition:
+            return unittest.expectedFailure(func)
+        else:
+            return func
+    return wrapper
+
+def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
+    def dec_fn(fn):
+        reason = f"skipIfXpu: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if TEST_XPU:
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+def skipIfMPS(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_MPS:
+            raise unittest.SkipTest("test doesn't currently work with MPS")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+
+def skipIfHpu(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_HPU:
+            raise unittest.SkipTest("test doesn't currently work with HPU")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def getRocmVersion() -> tuple[int, int]:
+    from torch.testing._internal.common_cuda import _get_torch_rocm_version
+    rocm_version = _get_torch_rocm_version()
+    return (rocm_version[0], rocm_version[1])
+
+# Skips a test on CUDA if ROCm is available and its version is lower than requested.
+def skipIfRocmVersionLessThan(version=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                rocm_version_tuple = getRocmVersion()
+                if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
+                    reason = f"ROCm {rocm_version_tuple} is available but {version} required"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def skipIfNotMiopenSuggestNHWC(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_MIOPEN_SUGGEST_NHWC:
+            raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skipIfWindows(func=None, *, msg="test doesn't currently work on the Windows stack"):
+    def dec_fn(fn):
+        reason = f"skipIfWindows: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if IS_WINDOWS:  # noqa: F821
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+def skipIfWindowsXPU(func=None, *, msg="test doesn't currently work on the Windows stack"):
+    def dec_fn(fn):
+        reason = f"skipIfWindowsXPU: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if IS_WINDOWS and torch.xpu.is_available():  # noqa: F821
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+def requires_cuda_p2p_access():
+    cuda_p2p_access_available = (
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (8, 0)
+        and torch.cuda.device_count() >= 2
+    )
+    num_devices = torch.cuda.device_count()
+    for i in range(num_devices - 1):
+        for j in range(i + 1, num_devices):
+            if not torch.cuda.can_device_access_peer(i, j):
+                cuda_p2p_access_available = False
+                break
+        if not cuda_p2p_access_available:
+            break
+
+    return skip_but_pass_in_sandcastle_if(
+        not cuda_p2p_access_available,
+        "cuda p2p access is not available",
+    )
+
+# Reverts the linalg backend back to default to make sure potential failures in one
+# test do not affect other tests
+def setLinalgBackendsToDefaultFinally(fn):
+    @wraps(fn)
+    def _fn(*args, **kwargs):
+        _preferred_backend = torch.backends.cuda.preferred_linalg_library()
+        try:
+            fn(*args, **kwargs)
+        finally:
+            torch.backends.cuda.preferred_linalg_library(_preferred_backend)
+    return _fn
+
+
+# Reverts the blas backend back to default to make sure potential failures in one
+# test do not affect other tests
+def setBlasBackendsToDefaultFinally(fn):
+    @wraps(fn)
+    def _fn(*args, **kwargs):
+        _preferred_backend = torch.backends.cuda.preferred_blas_library()
+        try:
+            fn(*args, **kwargs)
+        finally:
+            torch.backends.cuda.preferred_blas_library(_preferred_backend)
+    return _fn
+
+
+# Context manager for setting deterministic flag and automatically
+# resetting it to its original value
+class DeterministicGuard:
+    def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
+        self.deterministic = deterministic
+        self.warn_only = warn_only
+        self.fill_uninitialized_memory = fill_uninitialized_memory
+
+    @classmethod
+    def _current_state(cls):
+        return cls(
+            torch.are_deterministic_algorithms_enabled(),
+            warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
+            fill_uninitialized_memory=torch.utils.deterministic.fill_uninitialized_memory,  # type: ignore[attr-defined]
+        )
+
+    def _update(self):
+        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
+        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory  # type: ignore[attr-defined]
+
+    def __enter__(self):
+        self._restore = self._current_state()
+        self._update()
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        self._restore._update()
+
+class AlwaysWarnTypedStorageRemoval:
+    def __init__(self, always_warn):
+        assert isinstance(always_warn, bool)
+        self.always_warn = always_warn
+
+    def __enter__(self):
+        self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal()
+        torch.storage._set_always_warn_typed_storage_removal(self.always_warn)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore)
+
+# Context manager for setting cuda sync debug mode and reset it
+# to original value
+# we are not exposing it to the core because sync debug mode is
+# global and thus not thread safe
+class CudaSyncGuard:
+    def __init__(self, sync_debug_mode):
+        self.mode = sync_debug_mode
+
+    def __enter__(self):
+        self.debug_mode_restore = torch.cuda.get_sync_debug_mode()
+        torch.cuda.set_sync_debug_mode(self.mode)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.cuda.set_sync_debug_mode(self.debug_mode_restore)
+
+# Context manager for setting torch.__future__.set_swap_module_params_on_conversion
+# and automatically resetting it to its original value
+class SwapTensorsGuard:
+    def __init__(self, use_swap_tensors):
+        self.use_swap_tensors = use_swap_tensors
+
+    def __enter__(self):
+        self.swap_tensors_restore = torch.__future__.get_swap_module_params_on_conversion()
+        if self.use_swap_tensors is not None:
+            torch.__future__.set_swap_module_params_on_conversion(self.use_swap_tensors)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.__future__.set_swap_module_params_on_conversion(self.swap_tensors_restore)
+
+# This decorator can be used for API tests that call
+# torch.use_deterministic_algorithms().  When the test is finished, it will
+# restore the previous deterministic flag setting.
+#
+# If CUDA >= 10.2, this will set the environment variable
+# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that
+# setting is not thrown during the test unless the test changes that variable
+# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be
+# restored once the test is finished.
+#
+# Note that if a test requires CUDA to actually register the changed
+# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because
+# CUDA only checks the variable when the runtime initializes. Tests can be
+# run inside a subprocess like so:
+#
+#   import subprocess, sys, os
+#   script = '''
+#   # Test code should go here
+#   '''
+#   try:
+#       subprocess.check_output(
+#           [sys.executable, '-c', script],
+#           stderr=subprocess.STDOUT,
+#           cwd=os.path.dirname(os.path.realpath(__file__)),
+#           env=os.environ.copy())
+#   except subprocess.CalledProcessError as e:
+#       error_message = e.output.decode('utf-8')
+#       # Handle exceptions raised by the subprocess here
+#
+def wrapDeterministicFlagAPITest(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        with DeterministicGuard(
+                torch.are_deterministic_algorithms_enabled(),
+                warn_only=torch.is_deterministic_algorithms_warn_only_enabled()):
+            class CuBLASConfigGuard:
+                cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
+
+                def __enter__(self):
+                    self.cublas_config_restore = os.environ.get(self.cublas_var_name)
+                    os.environ[self.cublas_var_name] = ':4096:8'
+
+                def __exit__(self, exception_type, exception_value, traceback):
+                    cur_cublas_config = os.environ.get(self.cublas_var_name)
+                    if self.cublas_config_restore is None:
+                        if cur_cublas_config is not None:
+                            del os.environ[self.cublas_var_name]
+                    else:
+                        os.environ[self.cublas_var_name] = self.cublas_config_restore
+            with CuBLASConfigGuard():
+                fn(*args, **kwargs)
+    return wrapper
+
+# This decorator can be used for API tests that want to safely call
+# torch.__future__.set_swap_module_params_on_conversion.  `swap` can be set to
+# True, False or None where None indicates that the context manager does not
+# set the flag. When the test is finished, it will restore the previous swap
+# flag setting.
+def wrapSwapTensorsTest(swap=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            with SwapTensorsGuard(swap):
+                fn(*args, **kwargs)
+        return wrapper
+    return dec_fn
+
+# test parametrizer for swapping
+class swap(_TestParametrizer):
+    def __init__(self, swap_values):
+        super().__init__()
+        self.swap_values = swap_values
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        for swap in self.swap_values:
+            yield wrapSwapTensorsTest(swap)(test), f'swap_{swap}', {}, lambda _: []
+
+def skipIfCompiledWithoutNumpy(fn):
+    # Even if the numpy module is present, if `USE_NUMPY=0` is used during the
+    # build, numpy tests will fail
+    numpy_support = TEST_NUMPY
+    if numpy_support:
+        try:
+            # The numpy module is present, verify that PyTorch is compiled with
+            # numpy support
+            torch.from_numpy(np.array([2, 2]))
+        except RuntimeError:
+            numpy_support = False
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not numpy_support:
+            raise unittest.SkipTest("PyTorch was compiled without numpy support")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def _test_function(fn, device):
+    def run_test_function(self):
+        return fn(self, device)
+    return run_test_function
+
+def skipIfNoXNNPACK(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch.backends.xnnpack.enabled:  # type: ignore[attr-defined]
+            raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skipIfNoLapack(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch._C.has_lapack:
+            raise unittest.SkipTest('PyTorch compiled without Lapack')
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skipIfNotRegistered(op_name, message):
+    """Wraps the decorator to hide the import of the `core`.
+
+    Args:
+        op_name: Check if this op is registered in `core._REGISTERED_OPERATORS`.
+        message: message to fail with.
+
+    Usage:
+        @skipIfNotRegistered('MyOp', 'MyOp is not linked!')
+            This will check if 'MyOp' is in the caffe2.python.core
+    """
+    return unittest.skip("Pytorch is compiled without Caffe2")
+
+def skipIfNoSciPy(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_SCIPY:
+            raise unittest.SkipTest("test require SciPy, but SciPy not found")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skip_if_pytest(fn):
+    @wraps(fn)
+    def wrapped(*args, **kwargs):
+        if "PYTEST_CURRENT_TEST" in os.environ:
+            raise unittest.SkipTest("does not work under pytest")
+        return fn(*args, **kwargs)
+
+    return wrapped
+
+def skipIfNoXPU(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_XPU:
+            raise unittest.SkipTest("test required PyTorched compiled with XPU")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def slowTest(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_SLOW:
+            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+        else:
+            fn(*args, **kwargs)
+    wrapper.__dict__['slow_test'] = True
+    return wrapper
+
+
+def slowTestIf(condition):
+    return slowTest if condition else lambda fn: fn
+
+
+def skipCUDAMemoryLeakCheckIf(condition):
+    def dec(fn):
+        if getattr(fn, '_do_cuda_memory_leak_check', True):  # if current True
+            fn._do_cuda_memory_leak_check = not condition
+        return fn
+    return dec
+
+def skipCUDANonDefaultStreamIf(condition):
+    def dec(fn):
+        if getattr(fn, '_do_cuda_non_default_stream', True):  # if current True
+            fn._do_cuda_non_default_stream = not condition
+        return fn
+    return dec
+
+def suppress_warnings(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            fn(*args, **kwargs)
+    return wrapper
+
+
+def to_gpu(obj, type_map=None):
+    if type_map is None:
+        type_map = {}
+    if isinstance(obj, torch.Tensor):
+        assert obj.is_leaf
+        t = type_map.get(obj.dtype, obj.dtype)
+        with torch.no_grad():
+            res = obj.to(dtype=t, device="cuda", copy=True)
+            res.requires_grad = obj.requires_grad
+        return res
+    elif torch.is_storage(obj):
+        return obj.new().resize_(obj.size()).copy_(obj)  # type: ignore[attr-defined, union-attr]
+    elif isinstance(obj, list):
+        return [to_gpu(o, type_map) for o in obj]
+    elif isinstance(obj, tuple):
+        return tuple(to_gpu(o, type_map) for o in obj)
+    else:
+        return deepcopy(obj)
+
+
+def get_function_arglist(func):
+    return inspect.getfullargspec(func).args
+
+
+def set_rng_seed(seed=None):
+    if seed is None:
+        seed = SEED
+    torch.manual_seed(seed)
+    random.seed(seed)
+    if TEST_NUMPY:
+        np.random.seed(seed)
+
+
+@contextlib.contextmanager
+def set_default_dtype(dtype):
+    saved_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(saved_dtype)
+
+@contextlib.contextmanager
+def set_default_tensor_type(tensor_type):
+    saved_tensor_type = torch.tensor([]).type()
+    torch.set_default_tensor_type(tensor_type)
+    try:
+        yield
+    finally:
+        torch.set_default_tensor_type(saved_tensor_type)
+
+def iter_indices(tensor):
+    if tensor.dim() == 0:
+        return range(0)
+    if tensor.dim() == 1:
+        return range(tensor.size(0))
+    return product(*(range(s) for s in tensor.size()))
+
+
+def is_iterable(obj):
+    try:
+        iter(obj)
+        return True
+    except TypeError:
+        return False
+
+
+def is_iterable_of_tensors(iterable, include_empty=False):
+    """ Returns True if iterable is an iterable of tensors and False o.w.
+
+        If the iterable is empty, the return value is :attr:`include_empty`
+    """
+    # Tensor itself is iterable so we check this first
+    if isinstance(iterable, torch.Tensor):
+        return False
+
+    try:
+        if len(iterable) == 0:
+            return include_empty
+
+        for t in iter(iterable):
+            if not isinstance(t, torch.Tensor):
+                return False
+
+    except TypeError:
+        return False
+
+    return True
+
+
+class CudaNonDefaultStream:
+    def __enter__(self):
+        # Before starting CUDA test save currently active streams on all
+        # CUDA devices and set new non default streams to all CUDA devices
+        # to ensure CUDA tests do not use default stream by mistake.
+        beforeDevice = torch.cuda.current_device()
+        self.beforeStreams = []
+        for d in range(torch.cuda.device_count()):
+            self.beforeStreams.append(torch.cuda.current_stream(d))
+            deviceStream = torch.cuda.Stream(device=d)
+            self.beforeStreams[-1].synchronize()
+            torch._C._cuda_setStream(stream_id=deviceStream.stream_id,
+                                     device_index=deviceStream.device_index,
+                                     device_type=deviceStream.device_type)
+        torch._C._cuda_setDevice(beforeDevice)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # After completing CUDA test load previously active streams on all
+        # CUDA devices.
+        beforeDevice = torch.cuda.current_device()
+        for d in range(torch.cuda.device_count()):
+            torch._C._cuda_setStream(stream_id=self.beforeStreams[d].stream_id,
+                                     device_index=self.beforeStreams[d].device_index,
+                                     device_type=self.beforeStreams[d].device_type)
+        torch._C._cuda_setDevice(beforeDevice)
+
+class CudaMemoryLeakCheck:
+    def __init__(self, testcase, name=None):
+        self.name = testcase.id() if name is None else name
+        self.testcase = testcase
+
+        # initialize context & RNG to prevent false positive detections
+        # when the test is the first to initialize those
+        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
+        initialize_cuda_context_rng()
+
+    # Stores CUDA memory data provided by PyTorch's caching allocator and
+    #   the CUDA driver.
+    #
+    # NOTE: The undocumented torch.cuda.mem_get_info() returns
+    #   (#free bytes, #total bytes available) on the GPU
+    def __enter__(self):
+        self.caching_allocator_befores = []
+        self.driver_befores = []
+
+        # Performs a gc if required (required if any CUDA memory is held)
+        num_devices = torch.cuda.device_count()
+        for i in range(num_devices):
+            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+            # NOTE: gc is based exclusively on caching allocator memory
+            #   because the driver will always have some bytes in use (context size?)
+            if caching_allocator_mem_allocated > 0:
+                gc.collect()
+                torch._C._cuda_clearCublasWorkspaces()
+                torch.cuda.empty_cache()
+                break
+
+        # Acquires caching allocator and driver statistics before the test is run
+        for i in range(num_devices):
+            self.caching_allocator_befores.append(torch.cuda.memory_allocated(i))
+            bytes_free, bytes_total = torch.cuda.mem_get_info(i)
+            driver_mem_allocated = bytes_total - bytes_free
+            self.driver_befores.append(driver_mem_allocated)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Don't check for leaks if an exception was thrown
+        if exc_type is not None:
+            return
+
+        # Compares caching allocator before/after statistics
+        # An increase in allocated memory is a discrepancy indicating a possible
+        #   memory leak
+        discrepancy_detected = False
+        num_devices = torch.cuda.device_count()
+        for i in range(num_devices):
+            # avoid counting cublasWorkspace allocations
+            torch._C._cuda_clearCublasWorkspaces()
+            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+
+            if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
+                discrepancy_detected = True
+                break
+
+        # Short-circuits if no discrepancy detected
+        if not discrepancy_detected:
+            return
+
+        # Validates the discrepancy persists after garbage collection and
+        #   is confirmed by the driver API
+
+        # NOTE: driver API iscrepancies alone are ignored because with the jiterator
+        #   some tests may permanently increase the CUDA context size and
+        #   that will appear as a driver memory leak but is the expected behavior.
+
+        # GCs and clears the cache
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        for i in range(num_devices):
+
+            discrepancy_detected = True
+
+            # Query memory multiple items to ensure leak was not transient
+            for _ in range(3):
+                caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+                bytes_free, bytes_total = torch.cuda.mem_get_info(i)
+                driver_mem_allocated = bytes_total - bytes_free
+
+                caching_allocator_discrepancy = False
+                driver_discrepancy = False
+
+                if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
+                    caching_allocator_discrepancy = True
+
+                if driver_mem_allocated > self.driver_befores[i]:
+                    driver_discrepancy = True
+
+                if not (caching_allocator_discrepancy or driver_discrepancy):
+                    # Leak was false positive, exit loop
+                    discrepancy_detected = False
+                    break
+
+            if not discrepancy_detected:
+                continue
+
+            if caching_allocator_discrepancy and not driver_discrepancy:  # type: ignore[possibly-undefined]
+                # Just raises a warning if the leak is not validated by the
+                #   driver API
+                # NOTE: this may be a problem with how the caching allocator collects its
+                #   statistics or a leak too small to trigger the allocation of an
+                #   additional block of memory by the CUDA driver
+                msg = ("CUDA caching allocator reports a memory leak not "  # type: ignore[possibly-undefined]
+                       f"verified by the driver API in {self.name}! "
+                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "  # type: ignore[possibly-undefined]
+                       f"on device {i}. "
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")  # type: ignore[possibly-undefined]
+                warnings.warn(msg, stacklevel=2)
+            elif caching_allocator_discrepancy and driver_discrepancy:  # type: ignore[possibly-undefined]
+                # A caching allocator discrepancy validated by the driver API is a
+                #   failure (except on ROCm, see below)
+                msg = (f"CUDA driver API confirmed a leak in {self.name}! "  # type: ignore[possibly-undefined]
+                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "  # type: ignore[possibly-undefined]
+                       f"on device {i}. "
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")  # type: ignore[possibly-undefined]
+
+                raise RuntimeError(msg)
+
+@contextmanager
+def skip_exception_type(exc_type):
+    try:
+        yield
+    except exc_type as e:
+        raise unittest.SkipTest(f"not implemented: {e}") from e
+
+@contextmanager
+def print_repro_on_failure(repro_parts):
+    try:
+        yield
+    except unittest.SkipTest:
+        raise
+    except Exception as e:
+        # Get the index of the sample input that failed the test if possible.
+        sample_isolation_prefix = ""
+        tracked_input = getattr(e, "_tracked_input", None)
+        if tracked_input is not None:
+            sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}"
+
+        repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts)))
+
+        open_source_signpost(
+            subsystem="test_repros",
+            name="test_failure",
+            parameters=json.dumps(
+                {
+                    "repro": " ".join(filter(None, (sample_isolation_prefix, *repro_parts))),
+                }
+            ),
+        )
+
+        repro_msg = f"""
+To execute this test, run the following from the base repo dir:
+    {repro_str}
+
+This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
+
+        # NB: Hacking the exception args is the cleanest way I've found to append
+        # failure reproduction info without poisoning the stack trace.
+        if len(e.args) >= 1:
+            e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:])
+        raise
+
+#  "min_satisfying_examples" setting has been deprecated in hypothesis
+#  3.56.0 and removed in hypothesis 4.x
+try:
+    import hypothesis
+
+    def settings(*args, **kwargs):
+        if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
+            kwargs.pop('min_satisfying_examples')
+        return hypothesis.settings(*args, **kwargs)
+
+
+    hypothesis.settings.register_profile(
+        "pytorch_ci",
+        settings(
+            derandomize=True,
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=50,
+            verbosity=hypothesis.Verbosity.normal))
+    hypothesis.settings.register_profile(
+        "dev",
+        settings(
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=10,
+            verbosity=hypothesis.Verbosity.normal))
+    hypothesis.settings.register_profile(
+        "debug",
+        settings(
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=1000,
+            verbosity=hypothesis.Verbosity.verbose))
+
+    hypothesis.settings.load_profile(
+        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
+    )
+except ImportError:
+    warnings.warn('Fail to import hypothesis in common_utils, tests are not derandomized', ImportWarning, stacklevel=2)
+
+# Used in check_if_enable to see if a test method should be disabled by an issue,
+# sanitizes a test method name from appended suffixes by @dtypes parametrization.
+# e.g., an issue with title "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should
+# disabled ALL parametrized test_bitwise_ops tests, such test_bitwise_ops_cuda_int32
+def remove_device_and_dtype_suffixes(test_name: str) -> str:
+    # import statement is localized to avoid circular dependency issues with common_device_type.py
+    from torch.testing._internal.common_device_type import get_device_type_test_bases
+    device_suffixes = [x.device_type for x in get_device_type_test_bases()]
+    dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]
+
+    test_name_chunks = test_name.split("_")
+    if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
+        if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
+            return "_".join(test_name_chunks[0:-2])
+        return "_".join(test_name_chunks[0:-1])
+    return test_name
+
+
+def check_if_enable(test: unittest.TestCase):
+    classname = str(test.__class__).split("'")[1].split(".")[-1]
+    sanitized_testname = remove_device_and_dtype_suffixes(test._testMethodName)
+
+    def matches_test(target: str):
+        target_test_parts = target.split()
+        if len(target_test_parts) < 2:
+            # poorly formed target test name
+            return False
+        target_testname = target_test_parts[0]
+        target_classname = target_test_parts[1][1:-1].split(".")[-1]
+        # if test method name or its sanitized version exactly matches the disabled
+        # test method name AND allow non-parametrized suite names to disable
+        # parametrized ones (TestSuite disables TestSuiteCPU)
+        return classname.startswith(target_classname) and (target_testname in (test._testMethodName, sanitized_testname))
+
+    if any(matches_test(x) for x in slow_tests_dict):
+        getattr(test, test._testMethodName).__dict__['slow_test'] = True
+        if not TEST_WITH_SLOW:
+            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+
+    if not IS_SANDCASTLE:
+        should_skip = False
+        skip_msg = ""
+
+        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
+            if matches_test(disabled_test):
+                platform_to_conditional: dict = {
+                    "mac": IS_MACOS,
+                    "macos": IS_MACOS,
+                    "win": IS_WINDOWS,
+                    "windows": IS_WINDOWS,
+                    "linux": IS_LINUX,
+                    "rocm": TEST_WITH_ROCM,
+                    "xpu": TEST_XPU,
+                    "asan": TEST_WITH_ASAN,
+                    "dynamo": TEST_WITH_TORCHDYNAMO,
+                    "dynamo_wrapped": TEST_WITH_TORCHDYNAMO,
+                    "inductor": TEST_WITH_TORCHINDUCTOR,
+                    "slow": TEST_WITH_SLOW,
+                }
+
+                invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))
+                if len(invalid_platforms) > 0:
+                    invalid_plats_str = ", ".join(invalid_platforms)
+                    valid_plats = ", ".join(platform_to_conditional.keys())
+
+                    print(f"Test {disabled_test} is disabled for some unrecognized ",
+                          f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ",
+                          'assigned to this flaky test, changing "Platforms: ..." to a comma separated ',
+                          f"subset of the following (or leave it blank to match all platforms): {valid_plats}")
+
+                    # Sanitize the platforms list so that we continue to disable the test for any valid platforms given
+                    platforms = list(filter(lambda p: p in platform_to_conditional, platforms))
+
+                if platforms == [] or any(platform_to_conditional[platform] for platform in platforms):
+                    should_skip = True
+                    skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
+                        f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
+                        "If you're seeing this on your local machine and would like to enable this test, " \
+                        "please make sure CI is not set and you are not using the flag --import-disabled-tests."
+                    break
+
+        if should_skip and not RERUN_DISABLED_TESTS:
+            # Skip the disabled test when not running under --rerun-disabled-tests verification mode
+            raise unittest.SkipTest(skip_msg)
+
+        if not should_skip and RERUN_DISABLED_TESTS:
+            # Probably test has disable issue but not for this platform
+            skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \
+                " disabled tests are run"
+            raise unittest.SkipTest(skip_msg)
+
+    if TEST_SKIP_FAST:
+        if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
+            raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
+
+
+# `TestCase.assertEqual` is very permissive and coerced the inputs into a format that could be compared. This is very
+# convenient when writing tests, but not so much while reviewing them. By default, the comparison `Pair` framework of
+# `torch.testing._comparison.are_equal`, used for example by the public testing function
+# `torch.testing.assert_close`, is more strict. In order to use the same framework and thus reduce the divergence
+# between internal and external comparison logic as much as possible, we define some "relaxed" pairs here. They only
+# change the supported inputs, but the comparison logic is the same.
+# TODO: Revisit the relaxed pairs and check how much work it is to fix the tests that would fail without the relaxation.
+
+class RelaxedBooleanPair(BooleanPair):
+    """Pair for boolean-like inputs.
+
+    In contrast to the builtin :class:`BooleanPair`, this class also supports one input being a number or a single
+    element tensor-like.
+    """
+    _supported_number_types = NumberPair(0, 0)._supported_types
+
+    def _process_inputs(self, actual, expected, *, id):
+        # We require only one of the inputs of the inputs to be a boolean and the other can also be a boolean, a
+        # number, or a single element tensor or array, whereas in default BooleanPair both inputs have to be booleans.
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
+        other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
+        if not (
+            (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
+            or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
+        ):
+            self._inputs_not_supported()
+
+        return [self._to_bool(input, id=id) for input in (actual, expected)]
+
+    def _to_bool(self, bool_like, *, id):
+        if isinstance(bool_like, np.number):
+            return bool(bool_like.item())
+        elif type(bool_like) in self._supported_number_types:
+            return bool(bool_like)
+        elif isinstance(bool_like, (torch.Tensor, np.ndarray)):
+            numel = bool_like.numel() if isinstance(bool_like, torch.Tensor) else bool_like.size
+            if numel > 1:
+                self._fail(
+                    ValueError,
+                    f"Only single element tensor-likes can be compared against a boolean. "
+                    f"Got {numel} elements instead.",
+                    id=id
+                )
+
+            return bool(bool_like.item())
+        else:
+            return super()._to_bool(bool_like, id=id)
+
+
+class RelaxedNumberPair(NumberPair):
+    """Pair for number-like inputs.
+
+    In contrast to the builtin :class:`NumberPair`, this class also supports one input being a single element
+    tensor-like or a :class:`enum.Enum`. (D)Type checks are disabled, meaning comparing 1 to 1.0 succeeds even when
+    ``check_dtype=True`` is passed.
+
+    In addition, this class uses looser default tolerances for :class:`float` and :class:`complex` inputs. Also
+    supports overriding the absolute and relative tolerance through the ``@precisionOverride`` and
+    ``@toleranceOverride`` decorators.
+    """
+    _TYPE_TO_DTYPE = {
+        int: torch.int64,
+        float: torch.float32,
+        complex: torch.complex64,
+    }
+
+    def __init__(
+            self, actual, expected, *, rtol_override=0.0, atol_override=0.0, check_dtype=None, **other_parameters
+    ) -> None:
+        super().__init__(actual, expected, check_dtype=False, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _process_inputs(self, actual, expected, *, id):
+        # We require only one of the inputs of the inputs to be a number and the other can also be a number or a single
+        # element tensor or array, whereas in default NumberPair both inputs have to be numbers.
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
+        other_supported_types = (*self._supported_types, *tensor_or_array_types)
+        if not (
+                (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
+                or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
+        ):
+            self._inputs_not_supported()
+
+        return [self._to_number(input, id=id) for input in (actual, expected)]
+
+    def _to_number(self, number_like, *, id):
+        if isinstance(number_like, (torch.Tensor, np.ndarray)):
+            numel = number_like.numel() if isinstance(number_like, torch.Tensor) else number_like.size
+            if numel > 1:
+                self._fail(
+                    ValueError,
+                    f"Only single element tensor-likes can be compared against a number. "
+                    f"Got {numel} elements instead.",
+                    id=id
+                )
+            number = number_like.item()
+            if isinstance(number, bool):
+                number = int(number)
+
+            return number
+        elif isinstance(number_like, Enum):
+            return int(number_like)  # type: ignore[call-overload]
+        else:
+            number = super()._to_number(number_like, id=id)
+            if type(number) not in self._TYPE_TO_DTYPE:
+                self._inputs_not_supported()
+            return number
+
+
+class TensorOrArrayPair(TensorLikePair):
+    """Pair for tensor-like inputs.
+
+    On the one hand this class is stricter than the builtin :class:`TensorLikePair` since it only allows instances of
+    :class:`torch.Tensor` and :class:`numpy.ndarray` rather than allowing any tensor-like than can be converted into a
+    tensor. On the other hand this class is looser since it converts all inputs into tensors with no regard of their
+    relationship, e.g. comparing a :class:`torch.Tensor` to :class:`numpy.ndarray` is fine.
+
+    In addition, this class supports overriding the absolute and relative tolerance through the ``@precisionOverride``
+    and ``@toleranceOverride`` decorators.
+    """
+    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
+        super().__init__(actual, expected, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
+        self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray))
+
+        actual, expected = (self._to_tensor(input) for input in (actual, expected))
+        for tensor in (actual, expected):
+            self._check_supported(tensor, id=id)
+        return actual, expected
+
+
+class TypedStoragePair(TensorLikePair):
+    """Pair for :class:`torch.storage.TypedStorage` inputs."""
+    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
+        self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage)
+        super().__init__(actual, expected, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _to_tensor(self, typed_storage):
+        return torch.tensor(
+            typed_storage._untyped_storage,
+            dtype={
+                torch.quint8: torch.uint8,
+                torch.quint4x2: torch.uint8,
+                torch.quint2x4: torch.uint8,
+                torch.qint32: torch.int32,
+                torch.qint8: torch.int8
+            }.get(typed_storage.dtype, typed_storage.dtype),
+            device=typed_storage.device,
+        )
+
+
+class UnittestPair(Pair):
+    """Fallback ABC pair that handles non-numeric inputs.
+
+    To avoid recreating the mismatch messages of :meth:`unittest.TestCase.assertEqual`, this pair simply wraps it in
+    order to use it with the :class:`Pair` "framework" from :func:`are_equal`.
+
+    Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support.
+    """
+    CLS: Union[type, tuple[type, ...]]
+    TYPE_NAME: Optional[str] = None
+
+    def __init__(self, actual, expected, **other_parameters):
+        self._check_inputs_isinstance(actual, expected, cls=self.CLS)
+        super().__init__(actual, expected, **other_parameters)
+
+    def compare(self):
+        test_case = unittest.TestCase()
+
+        try:
+            return test_case.assertEqual(self.actual, self.expected)
+        except test_case.failureException as error:
+            msg = str(error)
+
+        type_name = self.TYPE_NAME or (self.CLS if isinstance(self.CLS, type) else self.CLS[0]).__name__
+        self._fail(AssertionError, f"{type_name.title()} comparison failed: {msg}")
+
+
+class StringPair(UnittestPair):
+    CLS = (str, bytes)
+    TYPE_NAME = "string"
+
+
+class SetPair(UnittestPair):
+    CLS = set
+
+
+class TypePair(UnittestPair):
+    CLS = type
+
+
+class ObjectPair(UnittestPair):
+    CLS = object
+
+
+# This implements a variant of assertRaises/assertRaisesRegex where we first test
+# if the exception is NotImplementedError, and if so just skip the test instead
+# of failing it.
+#
+# This is implemented by inheriting from the (private) implementation of
+# assertRaises from unittest.case, and slightly tweaking it for this new
+# behavior.  The year is 2021: this private class hierarchy hasn't changed since
+# 2010, seems low risk to inherit from.
+class AssertRaisesContextIgnoreNotImplementedError(unittest.case._AssertRaisesContext):
+    def __exit__(self, exc_type, exc_value, tb):
+        if exc_type is not None and issubclass(exc_type, NotImplementedError):
+            self.test_case.skipTest(f"not_implemented: {exc_value}")  # type: ignore[attr-defined]
+        return super().__exit__(exc_type, exc_value, tb)
+
+
+@contextmanager
+def set_warn_always_context(new_val: bool):
+    old_val = torch.is_warn_always_enabled()
+    torch.set_warn_always(new_val)
+    try:
+        yield
+    finally:
+        torch.set_warn_always(old_val)
+
+
+class NoTest:
+    # causes pytest to not recognize this class as a test
+    __test__ = False
+
+
+class TestCase(expecttest.TestCase):
+    # NOTE: "precision" lets classes and generated tests set minimum
+    # atol values when comparing tensors. Used by @precisionOverride and @toleranceOverride, for
+    # example.
+    # NOTE: "rel_tol" lets classes and generated tests set minimum
+    # rtol values when comparing tensors. Used by @toleranceOverride, for example.
+    _precision: float = 0
+    _rel_tol: float = 0
+
+    # Toggles whether to assert that `torch.get_default_dtype()` returns
+    # `torch.float` when `setUp` and `tearDown` are called.
+    _default_dtype_check_enabled: bool = False
+
+    # Always use difflib to print diffs on multi line equality.
+    # Undocumented feature in unittest
+    _diffThreshold = sys.maxsize
+    maxDiff = None
+
+    # checker to early terminate test suite if unrecoverable failure occurs.
+    def _should_stop_test_suite(self):
+        if torch.cuda.is_initialized():
+            # CUDA device side error will cause subsequence test cases to fail.
+            # stop entire test suite if catches RuntimeError during torch.cuda.synchronize().
+            try:
+                torch.cuda.synchronize()
+            except RuntimeError as rte:
+                print("TEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure", file=sys.stderr)
+                print(str(rte), file=sys.stderr)
+                return True
+            return False
+        else:
+            return False
+
+    @property
+    def precision(self) -> float:
+        return self._precision
+
+    @precision.setter
+    def precision(self, prec: float) -> None:
+        self._precision = prec
+
+    @property
+    def rel_tol(self) -> float:
+        return self._rel_tol
+
+    @rel_tol.setter
+    def rel_tol(self, prec: float) -> None:
+        self._rel_tol = prec
+
+    _do_cuda_memory_leak_check = False
+    _do_cuda_non_default_stream = False
+
+    # When True, if a test case raises a NotImplementedError, instead of failing
+    # the test, skip it instead.
+    _ignore_not_implemented_error = False
+
+    def __init__(self, method_name='runTest', methodName='runTest'):
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+
+        test_method = getattr(self, method_name, None)
+        if test_method is not None:
+            # Wraps the tested method if we should do CUDA memory check.
+            if TEST_CUDA_MEM_LEAK_CHECK:
+                self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
+                # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
+                if self._do_cuda_memory_leak_check and not IS_WINDOWS:
+                    self.wrap_with_cuda_policy(method_name, self.assertLeaksNoCudaTensors)
+
+            # Wraps the tested method if we should enforce non default CUDA stream.
+            self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True)
+            if self._do_cuda_non_default_stream and not IS_WINDOWS:
+                self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream)
+
+            if self._ignore_not_implemented_error:
+                self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))
+
+            if PRINT_REPRO_ON_FAILURE:
+                try:
+                    def _get_rel_test_path(abs_test_path):
+                        # Attempt to get relative path based on the "test" dir.
+                        # In CI, the working dir is not guaranteed to be the base repo dir so
+                        # we can't just compute relative path from that.
+                        parts = Path(abs_test_path).parts
+                        for i, part in enumerate(parts):
+                            if part == "test":
+                                base_dir = os.path.join(*parts[:i]) if i > 0 else ''
+                                return os.path.relpath(abs_test_path, start=base_dir)
+
+                        # Can't determine containing dir; just return the test filename.
+                        # The path isn't strictly correct but it's arguably better than nothing.
+                        return os.path.split(abs_test_path)[1]
+
+                    abs_test_path = inspect.getfile(type(self))
+                    test_filename = _get_rel_test_path(abs_test_path)
+                    class_name = type(self).__name__
+                    test_run_cmd = f"python {test_filename} {class_name}.{method_name}"
+                    env_var_prefix = TestEnvironment.repro_env_var_prefix()
+                    repro_parts = [env_var_prefix, test_run_cmd]
+                    self.wrap_with_policy(
+                        method_name,
+                        lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
+                except Exception as e:
+                    # Don't fail entirely if we can't get the test filename
+                    log.info("could not print repro string", extra=str(e))  # type: ignore[arg-type]
+
+    def assertLeaksNoCudaTensors(self, name=None):
+        name = self.id() if name is None else name
+        return CudaMemoryLeakCheck(self, name)
+
+    def enforceNonDefaultStream(self):
+        return CudaNonDefaultStream()
+
+    def _remove_ansi_escape(self, input):
+        # 7-bit C1 ANSI sequences
+        ansi_escape = re.compile(r'''
+            \x1B  # ESC
+            (?:   # 7-bit C1 Fe (except CSI)
+                [@-Z\\-_]
+            |     # or [ for CSI, followed by a control sequence
+                \[
+                [0-?]*  # Parameter bytes
+                [ -/]*  # Intermediate bytes
+                [@-~]   # Final byte
+            )
+        ''', re.VERBOSE)
+        return ansi_escape.sub('', input)
+
+    def remove_comment_lines(self, input_string):
+        lines = input_string.split('\n')
+        filtered_lines = [line for line in lines if not line.strip().startswith('#')]
+        return '\n'.join(filtered_lines)
+
+    def remove_empty_lines(self, input_string):
+        lines = input_string.split('\n')
+        filtered_lines = [line for line in lines if line.strip() != '']
+        return '\n'.join(filtered_lines)
+
+    # ignore comments will ignore lines that starts with # after being stripped
+    def assertExpectedInline(self, actual, expect, skip=0, ignore_comments=False, ignore_empty_lines=False):
+        actual = actual if isinstance(actual, str) else str(actual)
+        actual = self._remove_ansi_escape(actual)
+        expect = self._remove_ansi_escape(expect)
+        if ignore_comments:
+            actual = self.remove_comment_lines(actual)
+            expect = self.remove_comment_lines(expect)
+
+        if ignore_empty_lines:
+            actual = self.remove_empty_lines(actual)
+            expect = self.remove_empty_lines(expect)
+
+        return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1)
+
+    # Munges exceptions that internally contain stack traces, using munge_exc
+    def assertExpectedInlineMunged(
+        self, exc_type, callable, expect, *, skip=0, suppress_suffix=True, post_munge=None,
+    ):
+        try:
+            callable()
+        except exc_type as e:
+            munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=skip + 1)
+            if post_munge:
+                munged = post_munge(munged)
+            self.assertExpectedInline(
+                munged, expect, skip=skip + 1
+            )
+            return
+        self.fail(msg="Did not raise when expected to")
+
+    def assertLogs(self, logger=None, level=None):
+        if logger is None:
+            logger = logging.getLogger("torch")
+        return super().assertLogs(logger, level)
+
+    def assertNoLogs(self, logger=None, level=None):
+        if logger is None:
+            logger = logging.getLogger("torch")
+        return super().assertNoLogs(logger, level)
+
+    def wrap_with_cuda_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        # the import below may initialize CUDA context, so we do it only if
+        # self._do_cuda_memory_leak_check or self._do_cuda_non_default_stream
+        # is True.
+        # TODO: sure looks like we unconditionally initialize the context here
+        # -- ezyang
+        from torch.testing._internal.common_cuda import TEST_CUDA
+        fullname = self.id().lower()  # class_name.method_name
+        if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
+            setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
+
+    def wrap_with_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
+
+    # A policy is a zero-argument function that returns a context manager.
+    # We don't take the context manager directly as it may be necessary to
+    # construct it once per test method
+    def wrap_method_with_policy(self, method, policy):
+        # Assumes that `method` is the tested function in `self`.
+        # NOTE: Python Exceptions (e.g., unittest.Skip) keeps objects in scope
+        #       alive, so this cannot be done in setUp and tearDown because
+        #       tearDown is run unconditionally no matter whether the test
+        #       passes or not. For the same reason, we can't wrap the `method`
+        #       call in try-finally and always do the check.
+        @wraps(method)
+        def wrapper(self, *args, **kwargs):
+            with policy():
+                method(*args, **kwargs)
+        return types.MethodType(wrapper, self)
+
+    def wrap_with_cuda_memory_check(self, method):
+        return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
+
+    def _dynamo_test_key(self):
+        return f"{self.__class__.__name__}.{self._testMethodName}"
+
+    def compile_fn(self, fn, backend, nopython):
+        # Allows subclasses to control compilation
+        return torch._dynamo.optimize(backend, nopython=nopython)(fn)
+
+    def _run_custom(self, result=None):
+        using_unittest = isinstance(result, unittest.TestResult)
+
+        super_run = super().run
+        test_cls = super_run.__self__  # type: ignore[attr-defined]
+
+        # Are we compiling?
+        compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
+        # Is the class strict and compiling?
+        strict_default = False
+        should_reset_dynamo = False
+
+        # We disable size_asserts for test_ops since some tests fail
+        # due to mismatch of strides returned from eager v.s. meta kernels
+        # Only some of the ops has this problem, but since tests in
+        # test_op.py are parametrized, it's hard to do this specifically
+        # for the affected ops.
+        # It's not a big deal since these problems are captured by
+        # test_torchinductor_opinfo.py as well.
+        should_disable_size_asserts = False
+        if compiled:
+            try:
+                path = inspect.getfile(type(test_cls))
+                full_path = os.path.abspath(path)
+                match = re.match(r".*/test/(.*).py", full_path)
+                if match is not None:
+                    filename = match.group(1)
+                    if TEST_WITH_TORCHINDUCTOR:
+                        from .dynamo_test_failures import FIXME_inductor_non_strict
+                        strict_default = filename not in FIXME_inductor_non_strict
+                        should_reset_dynamo = True
+
+                        if filename == "test_ops":
+                            should_disable_size_asserts = True
+                    else:
+                        strict_default = True
+            # inspect.getfile can fail with these
+            except (OSError, TypeError):
+                pass
+            if "STRICT_DEFAULT" in os.environ:
+                if os.environ["STRICT_DEFAULT"] == "1":
+                    strict_default = True
+
+        strict_mode = False
+        if compiled:
+            test_method = getattr(self, self._testMethodName)
+            if hasattr(test_method, "dynamo_strict"):
+                strict_mode = test_method.dynamo_strict
+            elif hasattr(test_cls, "dynamo_strict"):
+                strict_mode = test_cls.dynamo_strict
+            else:
+                strict_mode = strict_default
+        nopython = getattr(test_cls, "dynamo_strict_nopython", False) and compiled
+
+        if strict_mode or should_reset_dynamo:
+            torch._dynamo.reset()
+
+        torch.compiler.set_stance("default")
+
+        # TODO: Remove this; this is grandfathered in because we suppressed errors
+        # on test suite previously
+        # When strict mode is False, suppress_errors is True
+        if compiled:
+            suppress_errors = not strict_mode
+        else:
+            suppress_errors = torch._dynamo.config.suppress_errors
+
+        maybe_disable_size_asserts = (
+            torch._inductor.config.patch(size_asserts=False)
+            if should_disable_size_asserts
+            else contextlib.nullcontext()
+        )
+
+        with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
+            if TEST_WITH_AOT_EAGER:
+                super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
+            elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
+                if TEST_WITH_TORCHINDUCTOR:
+                    super_run = self.compile_fn(super_run, "inductor", nopython)
+                else:
+                    # Assume eager-generated GraphModules will not error out.
+                    # If we do, this is probably a Dynamo bug!
+                    super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
+
+                key = self._dynamo_test_key()
+
+                def expect_failure(f, file_name):
+                    @wraps(f)
+                    def wrapper(*args, **kwargs):
+                        try:
+                            f(*args, **kwargs)
+                        except BaseException as e:  # noqa: B036
+                            self.skipTest(e)
+                        raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
+                    return wrapper
+
+                if TEST_WITH_TORCHINDUCTOR:
+                    subdir = "test/inductor_expected_failures"
+                    from .dynamo_test_failures import inductor_expected_failures as expected_failures
+                else:
+                    subdir = "test/dynamo_expected_failures"
+                    from .dynamo_test_failures import dynamo_expected_failures as expected_failures
+
+                if key in expected_failures:
+                    method = getattr(self, self._testMethodName)
+                    file_name = os.path.join(subdir, key)
+                    setattr(self, self._testMethodName, expect_failure(method, file_name))
+
+                def ignore_failure(f, file_name):
+                    @wraps(f)
+                    def wrapper(*args, **kwargs):
+                        try:
+                            f(*args, **kwargs)
+                        except BaseException as e:  # noqa: B036
+                            self.skipTest(e)
+                        method = getattr(self, self._testMethodName)
+                        if getattr(method, "__unittest_expecting_failure__", False):
+                            self.skipTest("unexpected success")
+                        else:
+                            self.skipTest(f"This test passed, maybe we can remove `{file_name}`")
+                    return wrapper
+
+                if TEST_WITH_TORCHINDUCTOR:
+                    subdir = "test/inductor_skips"
+                    from .dynamo_test_failures import inductor_skips as skips
+                else:
+                    subdir = "test/dynamo_skips"
+                    from .dynamo_test_failures import dynamo_skips as skips
+
+                if key in skips:
+                    method = getattr(self, self._testMethodName)
+                    file_name = os.path.join(subdir, key)
+                    setattr(self, self._testMethodName, ignore_failure(method, file_name))
+
+                from .dynamo_test_failures import compiled_autograd_skips
+                if torch._dynamo.config.compiled_autograd and key in compiled_autograd_skips:
+                    # Still run the test, but with compiled autograd disabled
+                    super_run = runWithoutCompiledAutograd()(super_run)
+
+            super_run(result=result)
+
+        if strict_mode or should_reset_dynamo:
+            torch._dynamo.reset()
+        elif torch._dynamo.config.compiled_autograd:
+            torch._dynamo.compiled_autograd.reset()
+
+        # Early terminate test if necessary.  If using pytest, use the -x flag instead
+        if using_unittest and self._should_stop_test_suite():
+            if result.wasSuccessful():
+                case = TestCase()
+                if TEST_SAVE_XML is not None:
+                    # This is a big hacky, XMLRunner modifies expected type from TestCase to TestInfo
+                    # Create dummy TestInfo to record results correctly
+                    from xmlrunner.result import _TestInfo  # type: ignore[import]
+                    case = _TestInfo(result, case)
+                    case.output = _TestInfo.ERROR  # type: ignore[attr-defined]
+                    case.elapsed_time = 0.0  # type: ignore[attr-defined]
+                    case.test_description = "TestSuiteEarlyFailure"  # type: ignore[attr-defined]
+                # This shouldn't really happen, but if does add fake failure
+                # For more details see https://github.com/pytorch/pytorch/issues/71973
+                result.failures.append((case, "TestSuite execution was aborted early"))
+                assert result.wasSuccessful() is False
+            result.stop()
+
+
+    def run(self, result=None):
+        with contextlib.ExitStack() as stack:
+            if TEST_WITH_CROSSREF:
+                stack.enter_context(CrossRefMode())
+            self._run_custom(
+                result=result,
+            )
+
+    def setUp(self):
+        check_if_enable(self)
+        set_rng_seed()
+
+        # Save global check sparse tensor invariants state that can be
+        # restored from tearDown:
+        self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled()
+
+        # Enable invariant checks for all sparse tensors constructions
+        # including the unsafe ones. If this is not desired for some
+        # test case, use check_invariants=False optional argument to
+        # sparse tensor constructors or
+        # @torch.sparse.check_sparse_tensor_invariants(False)
+        # decorator to disable the invariant checks.
+        torch.sparse.check_sparse_tensor_invariants.enable()
+
+        if self._default_dtype_check_enabled:
+            assert torch.get_default_dtype() == torch.float
+
+        # attempt to reset some global state at the end of the test
+        self._prev_grad_state = torch.is_grad_enabled()
+
+    def tearDown(self):
+        # There exists test cases that override TestCase.setUp
+        # definition, so we cannot assume that _check_invariants
+        # attribute is defined in general.
+        if hasattr(self, '_check_invariants'):
+            # Restore the global check sparse tensor invariants state
+            if self._check_invariants:
+                torch.sparse.check_sparse_tensor_invariants.enable()
+            else:
+                torch.sparse.check_sparse_tensor_invariants.disable()
+
+        if self._default_dtype_check_enabled:
+            assert torch.get_default_dtype() == torch.float
+
+        # attribute may not be defined, per above
+        if hasattr(self, '_prev_grad_state'):
+            torch.set_grad_enabled(self._prev_grad_state)
+
+    @staticmethod
+    def _make_crow_indices(n_rows, n_cols, nnz,
+                           *, device, dtype, random=True):
+        """Return crow_indices of a CSR tensor with size (n_rows, n_cols) and
+        the number of specified elements nnz.
+
+        If random is True, the column counts of rows are in random
+        order. Otherwise, the column counts of rows are defined by the
+        used sampling method.
+
+        Sampling method
+        ---------------
+
+        The used sampling method was introduced in
+        https://pearu.github.io/csr_sampling.html, and here we give
+        only an overall description of the method.
+
+        Notice that crow_indices can be defined as cumsum(counts)
+        where counts is a sequence of non-negative integers satisfying
+        the following conditions:
+
+          len(counts) == n_rows + 1
+          counts.max() <= n_cols
+
+        while counts[i + 1] is interpreted as the number of specified
+        elements in the i-th row.
+
+        The used sampling method aims at increasing the diversity of
+        CSR samples, that is, a CSR sample should contain (i) rows
+        that are all filled, (ii) rows with no elements at all, and
+        (iii) rows that are partially filled. At the same time and for
+        the given total number of specified elements (nnz), there
+        should be minimal preference to rows with a given number of
+        elements.  To achieve this, the sampling method is built-up on
+        using a sawteeth model for counts. In the simplest case, we
+        would have
+
+          counts = arange(n_rows + 1) % (n_cols + 1)
+
+        that has equal number of all possible column counts per row.
+        This formula can be used only for specific input values of
+        n_rows, n_cols, and nnz. To generalize this model to any
+        combinations of inputs, the counts model above is extended
+        with an incomplete sawtooth, and the right and lower
+        rectangular parts that will guarantee that
+
+          counts.sum() == nnz
+
+        for any combination of n_rows, n_cols, and nnz. Basically,
+        we'll find a maximal window in (n_rows + 1, n_cols + 1)-grid
+        that is able to hold a sequence of sawteeth and so-called
+        final correction, while the external part of the window is
+        filled with counts to meet the nnz constraint exactly.
+        """
+        assert 0 <= nnz <= n_rows * n_cols, (nnz, n_rows, n_cols)
+
+        def sawteeth(n, m):
+            # return the total number of counts in the sequence of
+            # sawteeth where n and m define a window in (n_rows+1,
+            # n_cols+1) rectangle where the sequence of sawteeth
+            # perfectly fit.
+            M = (n_cols - m) * (n_cols - m + 1) // 2
+            K = (n_rows - n) % (n_cols - m + 1)
+            return M * ((n_rows - n) // (n_cols - m + 1)) + K * (K - 1) // 2
+
+        # Different from the original method description, here counts
+        # has leading 0 required by crow_indices:
+        counts = torch.zeros(n_rows + 1, dtype=dtype, device=torch.device('cpu'))
+
+        n = m = 0
+        N = sawteeth(n, m)
+        if N and nnz >= max(N, n_cols):
+            # determine the width of the sawteeth window. We use bisection to solve
+            #   N(n, 0) == 0 or nnz - n * n_cols < max(N(n, 0), n_cols)
+            # for n
+            n_left = n
+            n_right = n_rows - 1
+            N_right = sawteeth(n_right, m)
+            while n_right - n_left > 1:
+                n_middle = (n_left + n_right) // 2
+                N_middle = sawteeth(n_middle, m)
+                if N_middle == 0 or nnz - n_middle * n_cols < max(N_middle, n_cols):
+                    n_right, N_right = n_middle, N_middle
+                else:
+                    n_left = n_middle
+            n, N = n_right, N_right
+            # fill the right rectangle with counts:
+            assert n
+            counts[-n:].fill_(n_cols)
+
+        if N and nnz - n * n_cols >= max(N, n_rows - n):
+            # determine the height of the sawteeth window. We use bisection to solve
+            #   N(n, m) == 0 or nnz - n * n_cols - m * (n_rows - n) < max(N(n, m), n_rows - n)
+            # for m.
+            m_left = m
+            m_right = n_cols - 1
+            N_right = sawteeth(n, m_right)
+            while m_right - m_left > 1:
+                m_middle = (m_left + m_right) // 2
+                N_middle = sawteeth(n, m_middle)
+                if N_middle == 0 or nnz - n * n_cols - m_middle * (n_rows - n) < max(N_middle, n_rows - n):
+                    m_right, N_right = m_middle, N_middle
+                else:
+                    m_left = m_middle
+            m, N = m_right, N_right
+            # fill the bottom rectangle with counts:
+            assert m
+            counts[1:n_rows - n + 1].fill_(m)
+
+        if N:
+            # fill the sawteeth window with counts
+            q, r = divmod(nnz - n * n_cols - m * (n_rows - n),
+                          (n_cols - m) * (n_cols - m + 1) // 2)
+            p = 1 + q * (n_cols - m + 1)
+            k = math.isqrt(2 * r)
+            if k * (k + 1) > 2 * r:
+                k -= 1
+            corr = r - k * (k + 1) // 2
+            assert not ((p > 1) and (m > 0))  # full sawteeth are never on top of a bottom rectangle
+            # sequence of full sawteeth:
+            counts[1:p] = torch.arange(p - 1, dtype=dtype, device=counts.device) % (n_cols - m + 1)
+            # incomplete sawtooth:
+            counts[p:p + k + 1] += torch.arange(k + 1, dtype=dtype, device=counts.device)
+        else:
+            # given input does not support sawteeth
+            p = 1
+            corr = nnz - n * n_cols - m * (n_rows - n)
+
+        # correction that will guarantee counts.sum() == nnz:
+        counts[p] += corr
+
+        if random:
+            # randomize crow_indices by shuffling the sawteeth
+            # sequence:
+            perm = torch.randperm(n_rows, device=counts.device)
+            counts[1:] = counts[1:][perm]
+
+        # compute crow_indices:
+        crow_indices = counts
+        crow_indices.cumsum_(dim=0)
+        return crow_indices.to(device=device)
+
+    def genSparseCompressedTensor(self, size, nnz, *, layout, device, dtype, index_dtype, blocksize=(), dense_dims=0):
+        from operator import mul
+        from functools import reduce
+        sparse_dim = 2
+        assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
+        assert len(size) >= sparse_dim
+        if blocksize:
+            assert len(blocksize) == 2, (size, blocksize)
+            assert size[-2 - dense_dims] % blocksize[0] == 0, (size, blocksize)
+            assert size[-1 - dense_dims] % blocksize[1] == 0, (size, blocksize)
+            blocksize0, blocksize1 = blocksize
+        else:
+            blocksize0 = blocksize1 = 1
+
+        size = tuple(size)
+        dense_size = size[(len(size) - dense_dims):]
+
+        def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
+            compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype)
+            plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device)
+            for i in range(n_compressed_dims):
+                count = compressed_indices[i + 1] - compressed_indices[i]
+                plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort(
+                    torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count])
+            low = -1 if dtype != torch.uint8 else 0
+            high = 1 if dtype != torch.uint8 else 2
+            values = make_tensor((nnz,) + blocksize + dense_size, device=device, dtype=dtype, low=low, high=high)
+            return values, compressed_indices, plain_indices
+
+        batch_shape = size[:-2 - dense_dims]
+        n_batch = reduce(mul, batch_shape, 1)
+
+        if layout in {torch.sparse_csr, torch.sparse_bsr}:
+            n_compressed_dims, n_plain_dims = size[-2 - dense_dims] // blocksize0, size[-1 - dense_dims] // blocksize1
+        else:
+            n_compressed_dims, n_plain_dims = size[-1 - dense_dims] // blocksize1, size[-2 - dense_dims] // blocksize0
+        blocknnz = nnz // (blocksize0 * blocksize1)
+        sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)]
+        sparse_tensors_it = map(list, zip(*sparse_tensors, strict=True))
+
+        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize, *dense_size)
+        compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+        plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+        return torch.sparse_compressed_tensor(compressed_indices, plain_indices,
+                                              values, size=size, dtype=dtype, layout=layout, device=device)
+
+    def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims)
+
+    def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=0)
+
+    def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        assert len(blocksize) == 2
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
+
+    def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        assert len(blocksize) == 2
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
+
+    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype):
+        # Assert not given impossible combination, where the sparse dims have
+        # empty numel, but nnz > 0 makes the indices containing values.
+        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
+
+        v_size = [nnz] + list(size[sparse_dim:])
+        v = make_tensor(v_size, device=device, dtype=dtype, low=-1, high=1)
+        i = torch.rand(sparse_dim, nnz, device=device)
+        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
+        i = i.to(torch.long)
+        if is_uncoalesced:
+            i1 = i[:, :(nnz // 2), ...]
+            i2 = i[:, :((nnz + 1) // 2), ...]
+            i = torch.cat([i1, i2], 1)
+        x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)
+
+        if not is_uncoalesced:
+            x = x.coalesce()
+        else:
+            # FIXME: `x` is a sparse view of `v`. Currently rebase_history for
+            #        sparse views is not implemented, so this workaround is
+            #        needed for inplace operations done on `x`, e.g., copy_().
+            #        Remove after implementing something equivalent to CopySlice
+            #        for sparse views.
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
+            x = x.detach().clone()._coalesced_(False)
+        return x, x._indices().clone(), x._values().clone()
+
+    def generate_simple_inputs(self, layout,
+                               device=None,
+                               dtype=None,
+                               index_dtype=None,
+                               pin_memory=None,
+                               members_pin_memory=None,
+                               enable_batch=True,
+                               enable_hybrid=True,
+                               enable_zero_sized=True,
+                               enable_non_contiguous_indices=True,
+                               enable_non_contiguous_values=True,
+                               enable_batch_variable_nse=False,
+                               output_tensor=True,
+                               patterns=None):
+        """Generator of simple inputs for tensor constructors of the given layout.
+
+        The generated tensor inputs have the following properties:
+
+        - tensor shapes are minimal but not trivial
+        - tensor values are sorted sequences for COO and CSR formats, e.g. [1, 2, 3, 4]
+        - the generated tensors represent the same mathematical tensor for all layouts
+        - the generated tensors include regular, zero-sized, and optionally, batched or/and hybrid tensors.
+        - the generated tensors include contiguous or non-contiguous tensors both in indices and values
+
+        If output_tensor is True, yield tensors with the given
+        layout. Otherwise, yield inputs to the corresponding tensor
+        constructors:
+
+          - sparse compressed input is defined as
+            (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype,
+                                                              pin_memory=pin_memory)
+
+          - sparse COO input is defined as
+            (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory)
+
+          - strided input is defined as
+            (values,), dict(device=device, dtype=dtype)
+        """
+        if index_dtype is None:
+            index_dtype = torch.int64
+
+        is_compressed_sparse_layout = layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+
+        if output_tensor:
+            for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype,
+                                                            pin_memory=pin_memory,
+                                                            enable_batch=enable_batch, enable_hybrid=enable_hybrid,
+                                                            enable_zero_sized=enable_zero_sized,
+                                                            enable_non_contiguous_indices=enable_non_contiguous_indices,
+                                                            enable_non_contiguous_values=enable_non_contiguous_values,
+                                                            enable_batch_variable_nse=enable_batch_variable_nse,
+                                                            output_tensor=False):
+                if members_pin_memory:
+                    args = tuple(a.pin_memory() for a in args)
+                if layout is torch.strided:
+                    assert len(args) == 1
+                    size = kwargs.pop('size', None)  # to ensure that a zero-sized tensor has the desired shape
+                    assert size is not None
+                    if pin_memory:
+                        yield args[0].reshape(size).pin_memory()
+                    else:
+                        yield args[0].reshape(size)
+                elif layout is torch.sparse_coo:
+                    yield torch.sparse_coo_tensor(*args, **kwargs)
+                elif is_compressed_sparse_layout:
+                    kwargs.update(layout=layout)
+                    yield torch.sparse_compressed_tensor(*args, **kwargs)
+                else:
+                    assert 0  # unreachable
+            return
+
+        def get_blockpattern(pattern, blocksize):
+            basesize = pattern.shape
+            assert basesize[0] % blocksize[0] == 0, (basesize, blocksize)
+            assert basesize[1] % blocksize[1] == 0, (basesize, blocksize)
+            blockpattern = pattern.reshape(-1,
+                                           blocksize[0],
+                                           basesize[1] // blocksize[1],
+                                           blocksize[1]).transpose(-3, -2).any(-1).any(-1)
+            block_ids = torch.arange(1, blockpattern.numel() + 1).reshape(blockpattern.shape)
+            return (blockpattern != 0) * block_ids
+
+        def get_sparse_data(pattern):
+            basesize = pattern.shape
+            assert len(basesize) == 2, basesize  # pattern is expected to be a matrix
+
+            # We cannot use `torch.sparse_xyz_tensor(pattern)` to
+            # compute the sparse layout indices and values because
+            # generate_simple_inputs is used to generate the inputs to
+            # test `torch.sparse_xyz_tensor` factory functions, so
+            # we'll compute the indices and values independently of
+            # the factory functions.
+
+            indices = torch.where(pattern != 0)
+            coo_indices = torch.stack(indices)
+            crow_indices = torch.zeros(basesize[0] + 1, dtype=torch.int64)
+            crow_indices[1:] = torch.cumsum(coo_indices[0].bincount(minlength=basesize[0]), 0)
+            col_indices = coo_indices[1]
+            strided_values = torch.zeros(basesize, dtype=torch.int64)
+
+            # the property of `values == range(1, 1+nnz)` is used in
+            # get_sparse_data_with_block to relate BSR and BSC values,
+            # so, don't change the following line:
+            values = torch.arange(1, 1 + len(indices[0]), dtype=torch.int64)
+            strided_values[indices] = values
+
+            indices_T = torch.where(pattern.transpose(0, 1) != 0)
+            coo_indices_T = torch.stack(indices_T)
+            ccol_indices = torch.zeros(basesize[1] + 1, dtype=torch.int64)
+            ccol_indices[1:] = torch.cumsum(coo_indices_T[0].bincount(minlength=basesize[1]), 0)
+            row_indices = coo_indices_T[1]
+            csc_values = strided_values.transpose(0, 1)[indices_T]
+
+            return {torch.sparse_coo: (coo_indices, values),
+                    torch.sparse_csr: (crow_indices, col_indices, values),
+                    torch.sparse_csc: (ccol_indices, row_indices, csc_values),
+                    torch.strided: (strided_values,)}
+
+        def get_sparse_data_with_block(pattern, blocksize):
+            nonblock_data = get_sparse_data(pattern)
+            blockpattern = get_blockpattern(pattern, blocksize)
+            block_data = get_sparse_data(blockpattern)
+
+            strided_values = nonblock_data[torch.strided][0]
+            block_indices = block_data[torch.sparse_coo][0]
+            bsr_values = torch.stack([strided_values[bi * blocksize[0]:(bi + 1) * blocksize[0],
+                                                     bj * blocksize[1]:(bj + 1) * blocksize[1]]
+                                      for bi, bj in block_indices.transpose(0, 1)])
+
+            # here we use the property `values == range(1, 1+nnz)` and
+            # `values` relation to `csc_values` (see get_sparse_data)
+            # to get BSC blocks via reordering the BSR blocks:
+            bsc_values = bsr_values[block_data[torch.sparse_csc][2] - 1]
+
+            return {torch.sparse_bsr: (*block_data[torch.sparse_csr][:2], bsr_values),
+                    torch.sparse_bsc: (*block_data[torch.sparse_csc][:2], bsc_values),
+                    **nonblock_data}
+
+        def get_batch_sparse_data(pattern, blocksize):
+            size = pattern.shape
+            if len(size) <= 2:  # non-batch
+                return get_sparse_data_with_block(pattern, blocksize)
+
+            # batch data is created recursively:
+            batch_data = {}  # type: ignore[var-annotated]
+            for i, item in enumerate(pattern):
+                for layout, d in get_batch_sparse_data(item, blocksize).items():
+                    target = batch_data.get(layout)
+                    if layout is torch.sparse_coo:
+                        # a "batch COO" means a COO with the leading
+                        # sparse dimensions interpreted as batch
+                        # dimensions
+                        ext_coo_indices1 = torch.cat((torch.full((1, len(d[1])), i, dtype=torch.int64), d[0]))
+                        if target is None:
+                            target = batch_data[layout] = (ext_coo_indices1, d[1])
+                        else:
+                            target[0].set_(torch.cat((target[0], ext_coo_indices1), 1))  # type: ignore[call-overload]
+                            target[1].set_(torch.cat((target[1], d[1])))
+                    else:
+                        if target is None:
+                            target = batch_data[layout] = tuple(d[j].unsqueeze(0) for j in range(len(d)))
+                        else:
+                            for j in range(len(d)):
+                                target[j].set_(torch.cat((target[j], d[j].unsqueeze(0))))  # type: ignore[call-overload]
+            return batch_data
+
+        def generate_values(base, densesize):
+            """Generates a tensor of shape densesize with values equal to
+
+              base + i_1 * 10^0 + ... + i_d * 10^{d - 1}
+
+            at indices i_1, ..., i_d (with 0 <= i_j < densesize[j] for any 1 <= j <=
+            len(densesize))
+
+            This mapping produces unique values as long as
+            densesize[i] < 10 for all i in range(len(densesize)).
+            """
+
+            if not densesize:
+                return base
+            if not isinstance(base, int) and base.ndim > 0:
+                return torch.stack([generate_values(b, densesize) for b in base])
+            if base == 0:
+                return torch.zeros(densesize, dtype=torch.int64)
+            r = torch.arange(densesize[0], dtype=torch.int64)
+            for i, d in enumerate(densesize[1:]):
+                y = torch.arange(d, dtype=torch.int64) * (10 ** (i + 1))
+                r = r[..., None] + y[None, ...]
+            r.add_(base)
+            return r
+
+        if patterns is None:
+            # A pattern is a 3-tuple with the following items:
+            #
+            # - a list of integers with the depth of two or more. The
+            #   integers define the sparsity patterns of the generated
+            #   inputs: zero values correspond to unspecified
+            #   elements/blocks, and non-zero values to the specified
+            #   elements.
+            #
+            #   For debugging convenience, the elements with the same
+            #   value typically belong to the same block. However, it
+            #   is not a hard requirement: as long as the shape of a
+            #   pattern divides with block sizes, the pattern will be
+            #   a valid one.
+            #
+            #   If the depth of the list is larger than two, inputs
+            #   with batch dimensions will be generated.
+            #
+            # - a list of 2-tuples of block sizes, used to generate
+            #   BSR/BSC tensors with various block size parameters
+            #
+            # - a list of tuples of dense dimensions, used to generate
+            #   hybrid tensors with various dense dimensions
+            #
+            patterns = [
+                # a simple 3 x 2 tensor: non-hybrid, hybrid with 1 and 2 dense dimensions
+                ([[1, 2, 0],
+                  [1, 0, 3]], [(2, 1), (1, 3)], [(), (2,), (4, 5)]),
+                # 2 x 3 batch of 3 x 2 tensors: non-hybrid and hybrid with 2 dense dimensions
+                ([[[[1, 2, 0],
+                    [1, 0, 3]],
+                   [[1, 2, 3],
+                    [1, 0, 0]],
+                   [[1, 0, 0],
+                    [1, 2, 3]]],
+                  [[[0, 2, 0],
+                    [1, 2, 3]],
+                   [[1, 0, 3],
+                    [1, 2, 0]],
+                   [[1, 2, 3],
+                    [0, 2, 0]]]], [(2, 1), (2, 3)], [(), (2,)]),
+                # tensor with non-trivial blocksize
+                ([[0, 1, 0, 2, 0, 2],
+                  [0, 1, 0, 0, 2, 0],
+                  [3, 3, 3, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 5, 0, 6, 6, 6],
+                  [5, 0, 5, 6, 6, 6],
+                  [0, 0, 0, 0, 8, 8],
+                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 5)]),
+                # batch tensor with variable NSE
+                # Requires https://github.com/pytorch/pytorch/pull/84843 or similar.
+                ([[[1, 2],
+                   [3, 4]],
+                  [[1, 0],
+                   [0, 0]]], [(1, 1)], ([()] if enable_batch_variable_nse else []))]
+
+        def non_contiguous_copy(t, dim=-1, offset=0):
+            # return a copy of t that is non-contiguous along the
+            # given dimension and with the given storage offset
+            self.assertTrue(t.is_contiguous())
+            if dim < 0:
+                dim = dim + t.ndim
+            assert dim >= 0 and dim < t.ndim
+            step = max(2, offset + 1)
+            tmp = torch.zeros((*t.shape[:dim], t.shape[dim] * step, *t.shape[dim + 1:]), dtype=t.dtype, device=t.device)
+            dim_slices = (*((slice(None),) * dim), slice(offset, None, step))
+            r = tmp[dim_slices].copy_(t)
+            self.assertFalse(r.is_contiguous())
+            self.assertEqual(t, r)
+            return r
+
+        # the main loop of the method:
+        for pattern, blocksizes, densesizes in patterns:
+            if not enable_hybrid:
+                densesizes = [s for s in densesizes if not s]
+            if not (densesizes and blocksizes):
+                continue
+            pattern = torch.tensor(pattern, dtype=torch.int64)
+            if not enable_batch and pattern.ndim > 2:
+                continue
+            for blocksize in blocksizes:
+                data = get_batch_sparse_data(pattern, blocksize)[layout]
+                for densesize in densesizes:
+                    indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]]
+                    values = generate_values(data[-1], densesize).to(device=device, dtype=dtype)
+                    kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize)
+                    if pin_memory is not None:
+                        kwargs.update(pin_memory=pin_memory)
+
+                    yield (*indices, values), kwargs.copy()
+                    if enable_non_contiguous_indices and pattern.ndim > 2:
+                        # sparse compressed indices can be sliced only along batch dimensions
+                        for (dim, offset) in {(0, 1), (-2, 0)}:
+                            indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices]
+                            yield (*indices_copy, values), kwargs.copy()
+
+                            if enable_non_contiguous_values:
+                                values_copy = non_contiguous_copy(values, dim=-1, offset=1)
+                                yield (*indices_copy, values_copy), kwargs.copy()
+
+                    if enable_non_contiguous_values:
+                        values_copy = non_contiguous_copy(values, dim=-1, offset=1)
+                        yield (*indices, values_copy), kwargs.copy()
+
+        # zero-sized tensor inputs, non-batch, non-hybrid/hybrid
+        if enable_zero_sized:
+            for basesize, blocksizes, densesizes in [
+                    ((2, 0), [(1, 2)], [(), (2,), (2, 3)] if enable_hybrid else [()]),
+                    ((0, 2), [(1, 2), (2, 1), (3, 2)], [()]),
+                    ((0, 0), [(1, 2)], [()]),
+            ]:
+                for blocksize in blocksizes:
+                    for densesize in densesizes:  # type: ignore[attr-defined]
+                        if layout == torch.strided:
+                            indices = ()  # type: ignore[assignment]
+                            values = torch.empty((basesize + densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_coo:
+                            indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_csr:
+                            crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype)
+                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (crow_indices, col_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_csc:
+                            ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype)
+                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (ccol_indices, row_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_bsr:
+                            crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype)
+                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (crow_indices, col_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_bsc:
+                            ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype)
+                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (ccol_indices, row_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
+                        else:
+                            assert 0  # unreachable
+                        kwargs = dict(device=device, dtype=dtype, size=basesize + densesize)
+                        if pin_memory is not None:
+                            kwargs.update(pin_memory=pin_memory)
+                        yield (*indices, values), kwargs
+
+    def safeToDense(self, t):
+        # coalesce is only implemented for COO
+        if t.layout == torch.sparse_coo:
+            t = t.coalesce()
+        return t.to_dense()
+
+    # Compares a torch function with a reference function for a given sample input (object of SampleInput)
+    # Note: only values are compared, type comparison is not done here
+    def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs):
+        numpy_sample = sample_input.numpy()
+        n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs
+        t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+
+        actual = torch_fn(t_inp, *t_args, **t_kwargs)
+        expected = ref_fn(n_inp, *n_args, **n_kwargs)
+
+        self.assertEqual(actual, expected, exact_device=False, **kwargs)
+
+    # Compares the given Torch and NumPy functions on the given tensor-like object.
+    # NOTE: both torch_fn and np_fn should be functions that take a single
+    #   tensor (array). If the torch and/or NumPy function require additional
+    #   arguments then wrap the function in a lambda or pass a partial function.
+    # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol)
+    def compare_with_numpy(self, torch_fn, np_fn, tensor_like,
+                           device=None, dtype=None, **kwargs):
+        assert TEST_NUMPY
+
+        if isinstance(tensor_like, torch.Tensor):
+            assert device is None
+            assert dtype is None
+            t_cpu = tensor_like.detach().cpu()
+            if t_cpu.dtype is torch.bfloat16:
+                t_cpu = t_cpu.float()
+            a = t_cpu.numpy()
+            t = tensor_like
+        else:
+            d = copy.copy(torch_to_numpy_dtype_dict)
+            d[torch.bfloat16] = np.float32
+            a = np.array(tensor_like, dtype=d[dtype])
+            t = torch.tensor(tensor_like, device=device, dtype=dtype)
+
+        np_result = np_fn(a)
+        torch_result = torch_fn(t).cpu()
+
+        # Converts arrays to tensors
+        if isinstance(np_result, np.ndarray):
+            try:
+                np_result = torch.from_numpy(np_result)
+            except Exception:
+                # NOTE: copying an array before conversion is necessary when,
+                #   for example, the array has negative strides.
+                np_result = torch.from_numpy(np_result.copy())
+            if t.dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float:
+                torch_result = torch_result.to(torch.float)
+
+        self.assertEqual(np_result, torch_result, **kwargs)
+
+    def assertEqualIgnoreType(self, *args, **kwargs) -> None:
+        # If you are seeing this function used, that means test is written wrongly
+        # and deserves detailed investigation
+        return self.assertEqual(*args, exact_dtype=False, **kwargs)
+
+    def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None:
+        r"""Tests if tensor x equals to y, if y to be broadcast to x.shape.
+        """
+        if not isinstance(y, Iterable):
+            # int, float, etc. or different shape tensors
+            y = torch.ones_like(x) * y
+        if not isinstance(y, torch.Tensor):
+            # iterable, but not a tensor
+            y = torch.ones_like(x) * torch.tensor(y)
+        return self.assertEqual(x, y, *args, **kwargs)
+
+    def assertEqual(
+            self,
+            x,
+            y,
+            msg: Optional[Union[str, Callable[[str], str]]] = None,
+            *,
+            atol: Optional[float] = None,
+            rtol: Optional[float] = None,
+            equal_nan=True,
+            exact_dtype=True,
+            # TODO: default this to True
+            exact_device=False,
+            exact_layout=False,
+            exact_stride=False,
+            exact_is_coalesced=False
+    ):
+        # Hide this function from `pytest`'s traceback
+        __tracebackhide__ = True
+
+        # numpy's dtypes are a superset of what PyTorch supports. In case we encounter an unsupported dtype, we fall
+        # back to an elementwise comparison. Note that this has to happen here and not for example in
+        # `TensorOrArrayPair`, since at that stage we can no longer split the array into its elements and perform
+        # multiple comparisons.
+        if any(
+            isinstance(input, np.ndarray) and not has_corresponding_torch_dtype(input.dtype) for input in (x, y)
+        ):
+            def to_list(input):
+                return input.tolist() if isinstance(input, (torch.Tensor, np.ndarray)) else list(input)
+
+            x = to_list(x)
+            y = to_list(y)
+        # When comparing a sequence of numbers to a tensor, we need to convert the sequence to a tensor here.
+        # Otherwise, the pair origination of `are_equal` will fail, because the sequence is recognized as container
+        # that should be checked elementwise while the tensor is not.
+        elif isinstance(x, torch.Tensor) and isinstance(y, Sequence):
+            y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
+        elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
+            x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
+
+        # unbind NSTs to compare them; don't do this for NJTs
+        if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
+            x = x.unbind()
+        if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
+            y = y.unbind()
+
+        error_metas = not_close_error_metas(
+            x,
+            y,
+            pair_types=(
+                NonePair,
+                RelaxedBooleanPair,
+                RelaxedNumberPair,
+                TensorOrArrayPair,
+                TypedStoragePair,
+                StringPair,
+                SetPair,
+                TypePair,
+                ObjectPair,
+            ),
+            sequence_types=(
+                Sequence,
+                Sequential,
+                ModuleList,
+                ParameterList,
+                ScriptList,
+                torch.utils.data.dataset.Subset,
+            ),
+            mapping_types=(Mapping, ModuleDict, ParameterDict, ScriptDict),
+            rtol=rtol,
+            rtol_override=self.rel_tol,
+            atol=atol,
+            atol_override=self.precision,
+            equal_nan=equal_nan,
+            check_device=exact_device,
+            check_dtype=exact_dtype,
+            check_layout=exact_layout,
+            check_stride=exact_stride,
+            check_is_coalesced=exact_is_coalesced,
+        )
+
+        if error_metas:
+            # See [ErrorMeta Cycles]
+            error_metas = [error_metas]  # type: ignore[list-item]
+            # TODO: compose all metas into one AssertionError
+            raise error_metas.pop()[0].to_error(  # type: ignore[index]
+                # This emulates unittest.TestCase's behavior if a custom message passed and
+                # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
+                # is True (default)
+                (lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg
+            )
+
+    def assertNotEqual(self, x, y, msg: Optional[str] = None, *,                                       # type: ignore[override]
+                       atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None:
+        with self.assertRaises(AssertionError, msg=msg):
+            self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs)
+
+    def assertEqualTypeString(self, x, y) -> None:
+        # This API is used simulate deprecated x.type() is y.type()
+        self.assertEqual(x.device, y.device)
+        self.assertEqual(x.dtype, y.dtype)
+        self.assertEqual(x.is_sparse, y.is_sparse)
+
+    def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
+        for elem in iterable:
+            if id(obj) == id(elem):
+                return
+        raise AssertionError("object not found in iterable")
+
+    # Reimplemented to provide special behavior when
+    # _ignore_not_implemented_error is True
+    def assertRaises(self, expected_exception, *args, **kwargs):
+        if self._ignore_not_implemented_error:
+            context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
+                AssertRaisesContextIgnoreNotImplementedError(expected_exception, self)  # type: ignore[call-arg]
+            try:
+                return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr, arg-type]
+            finally:
+                # see https://bugs.python.org/issue23890
+                context = None
+        else:
+            return super().assertRaises(expected_exception, *args, **kwargs)
+
+    # Reimplemented to provide special behavior when
+    # _ignore_not_implemented_error is True
+    def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs):
+        # Verifies that an exception with the type expected_exception and message
+        # matching the regular expression defined by expected_regex is thrown.
+        # If the test is instantiated for a non-native device type (like XLA)
+        # then the message is not validated.
+
+        # Checks whether the test is instantiated for a device type by testing
+        # if the test class has defined the device_type attribute and,
+        # if so, tests whether the instantiated device type is native or not
+        if hasattr(self, 'device_type') and self.device_type not in NATIVE_DEVICES and self.device_type != "mps":  # type: ignore[attr-defined]
+            # empty string matches any string
+            expected_regex = ''
+
+        if self._ignore_not_implemented_error:
+            context = AssertRaisesContextIgnoreNotImplementedError(  # type: ignore[call-arg]
+                expected_exception, self, expected_regex)
+            return context.handle('assertRaisesRegex', args, kwargs)  # type: ignore[attr-defined, arg-type]
+        else:
+            return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
+
+    # Verifies that no unraisable exceptions are raised by callable.  Unlike regular
+    # exceptions, these do not actually propagate to the caller and are
+    # suppressed.  We must test for them specially.
+    def assertNoUnraisable(self, callable, *args, **kwargs):
+        raised = None
+
+        def record_unraisable(unraisable):
+            nonlocal raised
+            raised = unraisable
+
+        # Disable GC when running the callable to prevent spurious flakiness
+        # from unlucky GCs inside the callable
+        prev = gc.isenabled()
+        gc.disable()
+        try:
+            with unittest.mock.patch("sys.unraisablehook", record_unraisable):
+                callable(*args, **kwargs)
+        finally:
+            if prev:
+                gc.enable()
+
+        self.assertIsNone(raised)
+
+    # TODO: Support context manager interface
+    # NB: The kwargs forwarding to callable robs the 'subname' parameter.
+    # If you need it, manually apply your callable in a lambda instead.
+    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
+        subname = None
+        if 'subname' in kwargs:
+            subname = kwargs['subname']
+            del kwargs['subname']
+        try:
+            callable(*args, **kwargs)
+        except exc_type as e:
+            self.assertExpected(str(e), subname)
+            return
+        # Don't put this in the try block; the AssertionError will catch it
+        self.fail(msg="Did not raise when expected to")
+
+    def assertNotWarn(self, callable, msg=''):
+        r"""
+        Test if :attr:`callable` does not raise a warning.
+        """
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            with set_warn_always_context(True):
+                callable()
+            self.assertTrue(len(ws) == 0, msg)
+
+    @contextmanager
+    def assertWarnsOnceRegex(self, category, regex=''):
+        """Context manager for code that *must always* warn
+
+        This filters expected warnings from the test and fails if
+        the expected warning is not caught. It uses set_warn_always() to force
+        TORCH_WARN_ONCE to behave like TORCH_WARN
+        """
+        pattern = re.compile(regex)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            with set_warn_always_context(True):
+                yield
+            if len(ws) == 0:
+                self.fail('no warning caught')
+            self.assertTrue(any(type(w.message) is category for w in ws))
+            self.assertTrue(
+                any(re.match(pattern, str(w.message)) for w in ws),
+                f'{pattern}, {[w.message for w in ws if type(w.message) is category]}')
+
+    def assertExpected(self, s, subname=None):
+        r"""
+        Test that a string matches the recorded contents of a file
+        derived from the name of this test and subname.  This file
+        is placed in the 'expect' directory in the same directory
+        as the test script. You can automatically update the recorded test
+        output using --accept.
+
+        If you call this multiple times in a single function, you must
+        give a unique subname each time.
+        """
+        if not isinstance(s, str):
+            raise TypeError("assertExpected is strings only")
+
+        def remove_prefix(text, prefix):
+            if text.startswith(prefix):
+                return text[len(prefix):]
+            return text
+        # NB: we take __file__ from the module that defined the test
+        # class, so we place the expect directory where the test script
+        # lives, NOT where test/common_utils.py lives.  This doesn't matter in
+        # PyTorch where all test scripts are in the same directory as
+        # test/common_utils.py, but it matters in onnx-pytorch
+        module_id = self.__class__.__module__
+        munged_id = remove_prefix(self.id(), module_id + ".")
+        test_file = os.path.realpath(sys.modules[module_id].__file__)  # type: ignore[type-var]
+        expected_file = os.path.join(os.path.dirname(test_file),  # type: ignore[type-var, arg-type]
+                                     "expect",
+                                     munged_id)
+
+        subname_output = ""
+        if subname:
+            expected_file += "-" + subname
+            subname_output = f" ({subname})"
+        expected_file += ".expect"
+        expected = None
+
+        def accept_output(update_type):
+            print(f"Accepting {update_type} for {munged_id}{subname_output}:\n\n{s}")
+            with open(expected_file, 'w') as f:
+                # Adjust for producer_version, leave s unmodified
+                s_tag = re.sub(r'(producer_version): "[0-9.]*"',
+                               r'\1: "CURRENT_VERSION"', s)
+                f.write(s_tag)
+
+        try:
+            with open(expected_file) as f:
+                expected = f.read()
+        except OSError as e:
+            if e.errno != errno.ENOENT:
+                raise
+            elif expecttest.ACCEPT:
+                return accept_output("output")
+            else:
+                raise RuntimeError(
+                      f"I got this output for {munged_id}{subname_output}:\n\n{s}\n\n"
+                      "No expect file exists; to accept the current output, run:\n"
+                      f"python {__main__.__file__} {munged_id} --accept") from None
+
+        # a hack for JIT tests
+        if IS_WINDOWS:
+            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
+            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)
+
+        # Adjust for producer_version
+        expected = expected.replace(
+            'producer_version: "CURRENT_VERSION"',
+            f'producer_version: "{torch.onnx.producer_version}"'
+        )
+        if expecttest.ACCEPT:
+            if expected != s:
+                return accept_output("updated output")
+        else:
+            if hasattr(self, "assertMultiLineEqual"):
+                # Python 2.7 only
+                # NB: Python considers lhs "old" and rhs "new".
+                self.assertMultiLineEqual(expected, s)
+            else:
+                self.assertEqual(s, expected)
+
+    def assertExpectedStripMangled(self, s, subname=None):
+        s = re.sub(r'__torch__[^ ]+', '', s)
+        self.assertExpected(s, subname)
+
+    def assertGreaterAlmostEqual(self, first, second, places=None, msg=None, delta=None):
+        """Assert that ``first`` is greater than or almost equal to ``second``.
+
+        The equality of ``first`` and ``second`` is determined in a similar way to
+        the ``assertAlmostEqual`` function of the standard library.
+        """
+        if delta is not None and places is not None:
+            raise TypeError("specify delta or places not both")
+
+        if first >= second:
+            return
+
+        diff = second - first
+        if delta is not None:
+            if diff <= delta:
+                return
+
+            standardMsg = f"{first} not greater than or equal to {second} within {delta} delta"
+        else:
+            if places is None:
+                places = 7
+
+            if round(diff, places) == 0:
+                return
+
+            standardMsg = f"{first} not greater than or equal to {second} within {places} places"
+
+        msg = self._formatMessage(msg, standardMsg)
+        raise self.failureException(msg)
+
+    def assertAtenOp(self, onnx_model, operator, overload_name=""):
+        all_aten_nodes = [p for p in onnx_model.graph.node
+                          if p.op_type == "ATen" and p.domain == "org.pytorch.aten"]
+        self.assertTrue(all_aten_nodes)
+
+        for op in all_aten_nodes:
+            attrs = {attr.name: attr.s.decode() for attr in op.attribute}
+            if attrs.get("operator") == operator:
+                break
+
+        self.assertEqual(attrs["operator"], operator)  # type: ignore[possibly-undefined]
+        self.assertEqual(attrs.get("overload_name", ""), overload_name)
+
+    def check_nondeterministic_alert(self, fn, caller_name, should_alert=True):
+        '''Checks that an operation produces a nondeterministic alert when
+        expected while `torch.use_deterministic_algorithms(True)` is set.
+
+        Args:
+          fn (callable): Function to check for a nondeterministic alert
+
+          caller_name (str): Name of the operation that produces the
+              nondeterministic alert. This name is expected to appear at the
+              beginning of the error/warning message.
+
+          should_alert (bool, optional): If True, then the check will only pass
+              if calling `fn` produces a nondeterministic error/warning with the
+              expected message. If False, then the check will only pass if
+              calling `fn` does not produce an error. Default: `True`.
+        '''
+
+        alert_message = '^' + caller_name + ' does not have a deterministic implementation, but you set'
+
+        # Check that errors are thrown correctly
+        with DeterministicGuard(True):
+            if should_alert:
+                with self.assertRaisesRegex(
+                        RuntimeError,
+                        alert_message,
+                        msg='expected a non-deterministic error, but it was not raised'):
+                    fn()
+
+            else:
+                # If a nondeterministic error is not expected, make sure
+                # that it is not raised
+                try:
+                    fn()
+                except RuntimeError as e:
+                    if 'does not have a deterministic implementation' in str(e):
+                        self.fail(
+                            'did not expect non-deterministic error message, '
+                            + 'but got one anyway: "' + str(e) + '"')
+                    # Reraise exceptions unrelated to nondeterminism
+                    raise
+
+        # Check that warnings are thrown correctly
+        with DeterministicGuard(True, warn_only=True):
+            if should_alert:
+                with self.assertWarnsRegex(
+                        UserWarning,
+                        alert_message):
+                    fn()
+            else:
+                with warnings.catch_warnings(record=True) as w:
+                    warnings.simplefilter("always")
+                    fn()
+                    for warning in w:
+                        if isinstance(warning, UserWarning):
+                            self.assertTrue(re.search(alert_message, str(warning)) is None)
+
+    # run code in subprocess and capture exceptions.
+    @staticmethod
+    def run_process_no_exception(code, env=None):
+        import subprocess
+
+        with subprocess.Popen(
+            [sys.executable, "-c", code],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+        ) as p:
+            (stdout, stderr) = p.communicate()
+            return (stdout, stderr)
+
+    # returns captured stderr
+    @staticmethod
+    def runWithPytorchAPIUsageStderr(code):
+        env = os.environ.copy()
+        env["PYTORCH_API_USAGE_STDERR"] = "1"
+        # remove CI flag since this is a wrapped test process.
+        # CI flag should be set in the parent process only.
+        env.pop("CI", None)
+        env.pop("TEST_SHOWLOCALS", None)
+        _stdout, stderr = TestCase.run_process_no_exception(code, env=env)
+        return stderr.decode('ascii')
+
+    def _attempt_load_from_subprocess(
+        self,
+        file: pathlib.Path,
+        import_string: str,
+        expected_failure_message: Optional[str] = None
+    ) -> None:
+        """
+        Attempts weights_only `torch.load` in a subprocess. This is used to test that
+        weights_only `torch.load` works as expected without global imports.
+
+        Args:
+            file (pathlib.Path): The path to the checkpoint to load.
+            import_string (str): import string to add to the script
+            exected_failure_message (str, optional): The expected failure message if the
+                checkpoint fails to load. If None, the test will pass
+        """
+        script = f"import torch;{import_string}torch.load(r'{file}', weights_only=True)"
+        cm = (
+            self.assertRaisesRegex(RuntimeError, re.escape(expected_failure_message))
+            if expected_failure_message else contextlib.nullcontext()
+        )
+        with cm:
+            try:
+                subprocess.check_output(
+                    [sys.executable, "-c", script],
+                    # On Windows, opening the subprocess with the default CWD makes `import torch`
+                    # fail, so just set CWD to this script's directory
+                    cwd=os.path.dirname(os.path.realpath(__file__)),
+                    stderr=subprocess.STDOUT,
+                )
+            except subprocess.CalledProcessError as e:
+                raise RuntimeError(e.output.decode("utf-8")) from None
+
+
+class TestCaseBase(TestCase):
+    # Calls to super() in dynamically created classes are a bit odd.
+    # See https://github.com/pytorch/pytorch/pull/118586 for more info
+    # Subclassing this class and then calling super(TestCaseBase) will run
+    # TestCase's setUp, tearDown etc functions
+    pass
+
+
+def download_file(url, binary=True):
+    from urllib.parse import urlsplit
+    from urllib import request, error
+
+    filename = os.path.basename(urlsplit(url)[2])
+    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
+    path = os.path.join(data_dir, filename)
+
+    if os.path.exists(path):
+        return path
+    try:
+        with request.urlopen(url, timeout=15) as f1, open(path, 'wb' if binary else 'w') as f2:
+            data = f1.read()
+            f2.write(data)
+        return path
+    except error.URLError as e:
+        msg = f"could not download test file '{url}'"
+        warnings.warn(msg, RuntimeWarning, stacklevel=2)
+        raise unittest.SkipTest(msg) from e
+
+def find_free_port():
+    """
+    Finds an available port and returns that port number.
+
+    NOTE: If this function is being used to allocate a port to Store (or
+    indirectly via init_process_group or init_rpc), it should be used
+    in conjunction with the `retry_on_connect_failures` decorator as there is a potential
+    race condition where the allocated port may become unavailable before it can be used
+    """
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.bind(('localhost', 0))
+        _, port = sock.getsockname()
+        return port
+
+# Errors that we can get in c10d initialization for which we should retry tests for.
+ADDRESS_IN_USE = "Address already in use"
+CONNECT_TIMEOUT = "connect() timed out."
+
+def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE)):
+    """Reruns a test if the test returns a RuntimeError and the exception
+    contains one of the strings in connect_errors."""
+    # This if block is executed when using this function as a decorator with arguments.
+    if func is None:
+        return partial(retry_on_connect_failures, connect_errors=connect_errors)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        n_retries = 10
+        tries_remaining = n_retries
+        while True:
+            try:
+                return func(*args, **kwargs)
+            except RuntimeError as error:
+                if any(connect_error in str(error) for connect_error in connect_errors):
+                    tries_remaining -= 1
+                    if tries_remaining == 0:
+                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
+                    time.sleep(random.random())
+                    continue
+                raise
+    return wrapper
+
+
+# Decorator to retry upon certain Exceptions.
+def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
+    def deco_retry(f):
+        @wraps(f)
+        def f_retry(*args, **kwargs):
+            mtries, mdelay = tries, delay
+            while mtries > 1:
+                try:
+                    return f(*args, **kwargs)
+                except ExceptionToCheck as e:
+                    msg = f"{e}, Retrying in {mdelay:d} seconds..."
+                    print(msg)
+                    time.sleep(mdelay)
+                    mtries -= 1
+            try:
+                return f(*args, **kwargs)
+            except ExceptionToCheck as e:
+                raise unittest.SkipTest(f"Skipping after {tries} consecutive {str(e)}") from e if skip_after_retries else e
+        return f_retry  # true decorator
+    return deco_retry
+
+
+# FIXME: modernize these to be consistent with make_tensor
+#   and review including them in torch.testing
+# Methods for matrix generation
+
+def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
+    assert rank <= l
+    A = torch.randn(l, l, dtype=dtype, device=device)
+    u, s, vh = torch.linalg.svd(A, full_matrices=False)
+    for i in range(l):
+        if i >= rank:
+            s[i] = 0
+        elif s[i] == 0:
+            s[i] = 1
+    return (u * s.to(dtype).unsqueeze(-2)) @ vh
+
+def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
+    """
+    Returns a random rectangular matrix (batch of matrices)
+    with singular values sampled from a Gaussian with
+    mean `mean` and standard deviation `sigma`.
+    The smaller the `sigma`, the better conditioned
+    the output matrix is.
+    """
+    primitive_dtype = {
+        torch.float: torch.float,
+        torch.double: torch.double,
+        torch.cfloat: torch.float,
+        torch.cdouble: torch.double
+    }
+    x = torch.rand(shape, dtype=dtype, device=device)
+    m = x.size(-2)
+    n = x.size(-1)
+    u, _, vh = torch.linalg.svd(x, full_matrices=False)
+    s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
+        .sort(-1, descending=True).values.to(dtype)
+    return (u * s.unsqueeze(-2)) @ vh
+
+# Returns a noncontiguous (tensor with the same shape and values as t
+# The noncontiguous tensor is constructed such that elements in the innermost
+#   dimension are separated by zeros or (whenever possible) nans
+# TODO: consider more complicated noncontiguity schemes
+def noncontiguous_like(t):
+    # Short-circuits if t is already noncontiguous
+    if not t.is_contiguous():
+        return t
+
+    # Choose a "weird" value that won't be accessed
+    if t.dtype.is_floating_point or t.dtype.is_complex:
+        value = math.nan
+    elif t.dtype == torch.bool:
+        value = True
+    else:
+        value = 12
+
+    result = t.new_empty(t.shape + (2,))
+    result[..., 0] = value
+    result[..., 1] = t.detach()
+    result = result[..., 1]
+    result.requires_grad_(t.requires_grad)
+    return result
+
+# TODO: remove this (prefer make_symmetric_matrices below)
+def random_symmetric_matrix(l, *batches, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    A = (A + A.mT).div_(2)
+    return A
+
+# Creates a symmetric matrix or batch of symmetric matrices
+# Shape must be a square matrix or batch of square matrices
+def make_symmetric_matrices(*shape, device, dtype):
+    assert shape[-1] == shape[-2]
+    t = make_tensor(shape, device=device, dtype=dtype)
+    t = (t + t.mT).div_(2)
+    return t
+
+def random_hermitian_matrix(l, *batches, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    A = (A + A.mH).div_(2)
+    return A
+
+
+def random_symmetric_psd_matrix(l, *batches, **kwargs):
+    """
+    Returns a batch of random symmetric positive-semi-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_symmetric_psd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    return A @ A.mT
+
+
+def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'):
+    """
+    Returns a batch of random Hermitian positive-semi-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device)
+    return A @ A.mH
+
+
+# TODO: remove this (prefer make_symmetric_pd_matrices below)
+def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
+                    dtype=dtype, device=device)
+    return torch.matmul(A, A.mT) \
+        + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5
+
+
+# Creates a symmetric positive-definite matrix or batch of
+#   such matrices
+def make_symmetric_pd_matrices(*shape, device, dtype):
+    assert shape[-1] == shape[-2]
+    t = make_tensor(shape, device=device, dtype=dtype)
+    i = torch.eye(shape[-1], device=device, dtype=dtype) * 1e-5
+    return t @ t.mT + i
+
+def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device):
+    """
+    Returns a batch of random Hermitian positive-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
+                    dtype=dtype, device=device)
+    return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device)
+
+# Creates a full rank matrix with distinct singular values or
+#   a batch of such matrices
+def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
+    with torch.no_grad():
+        t = make_tensor(shape, device=device, dtype=dtype)
+        u, _, vh = torch.linalg.svd(t, full_matrices=False)
+        real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
+        k = min(shape[-1], shape[-2])
+        # We choose the singular values to be "around one"
+        # This is to make the matrix well conditioned
+        # s = [2, 3, ..., k+1]
+        s = torch.arange(2, k + 2, dtype=real_dtype, device=device)
+        # s = [2, -3, 4, ..., (-1)^k k+1]
+        s[1::2] *= -1.
+        # 1 + 1/s so that the singular values are in the range [2/3, 3/2]
+        # This gives a condition number of 9/4, which should be good enough
+        s.reciprocal_().add_(1.)
+        # Note that the singular values need not be ordered in an SVD so
+        # we don't need need to sort S
+        x = (u * s.to(u.dtype)) @ vh
+    x.requires_grad_(requires_grad)
+    return x
+
+def random_matrix(rows, columns, *batch_dims, **kwargs):
+    """Return rectangular matrix or batches of rectangular matrices.
+
+    Parameters:
+      dtype - the data type
+      device - the device kind
+      singular - when True, the output will be singular
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    silent = kwargs.get("silent", False)
+    singular = kwargs.get("singular", False)
+    if silent and not torch._C.has_lapack:
+        return torch.ones(rows, columns, dtype=dtype, device=device)
+
+    A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
+    if A.numel() == 0:
+        return A
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
+    k = min(rows, columns)
+    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
+    if singular:
+        # make matrix singular
+        s[k - 1] = 0
+        if k > 2:
+            # increase the order of singularity so that the pivoting
+            # in LU factorization will be non-trivial
+            s[0] = 0
+    return (u * s.unsqueeze(-2)) @ vh
+
+
+def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
+    """Return rectangular matrix or batches of rectangular matrices with
+    given rank.
+    """
+    B = random_matrix(rows, rank, *batch_dims, **kwargs)
+    C = random_matrix(rank, columns, *batch_dims, **kwargs)
+    return B.matmul(C)
+
+
+def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
+    """Generate indices for a row x cols matrix, preferring at least one index per row if possible."""
+    indices = []  # type: ignore[var-annotated]
+    n_per_row = math.ceil(num_indices / rows)
+    col_indices = list(range(cols))
+
+    for r in range(rows):
+        # Note that this can yield overlapping indices
+        indices.extend((r, c) for c in random.choices(col_indices, k=n_per_row))
+
+    return torch.tensor(indices[:num_indices])
+
+
+def random_sparse_matrix(rows, columns, density=0.01, **kwargs):
+    """Return rectangular random sparse matrix within given density.
+
+    The density of the result approaches to given density as the size
+    of the matrix is increased and a relatively small value of density
+    is specified but higher than min(rows, columns)/(rows * columns)
+    for non-singular matrices.
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+
+    nonzero_elements = max(min(rows, columns), int(rows * columns * density))
+    indices = _generate_indices_prefer_all_rows(rows, columns, nonzero_elements)
+    values = torch.randn(nonzero_elements, dtype=dtype, device=device)
+
+    # ensure that the diagonal dominates
+    values *= torch.tensor([-float(i - j)**2 for i, j in indices], dtype=dtype, device=device).exp()
+    A = torch.sparse_coo_tensor(indices.t(), values, (rows, columns), device=device)
+    return A.coalesce()
+
+
+def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
+    """Return random sparse positive-definite matrix with given density.
+
+    The eigenvalues of the matrix are defined as::
+      arange(1, matrix_size+1)/matrix_size
+
+    Algorithm:
+      A = diag(arange(1, matrix_size+1)/matrix_size)
+      while :
+          
+          R = 
+          A = R^T A R
+    """
+    import math
+    torch = kwargs.get('torch', globals()['torch'])
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    data = {(i, i): float(i + 1) / matrix_size
+            for i in range(matrix_size)}
+
+
+    def multiply(data, N, i, j, cs, sn, left=True):
+        for k in range(N):
+            if left:
+                ik, jk = (k, i), (k, j)
+            else:
+                ik, jk = (i, k), (j, k)
+            aik, ajk = data.get(ik, 0), data.get(jk, 0)
+            aik, ajk = cs * aik + sn * ajk, -sn * aik + cs * ajk
+            if aik:
+                data[ik] = aik
+            else:
+                data.pop(ik, None)
+            if ajk:
+                data[jk] = ajk
+            else:
+                data.pop(jk, None)
+
+    target_nnz = density * matrix_size * matrix_size
+    while len(data) < target_nnz:
+        i = random.randint(0, matrix_size - 1)
+        j = random.randint(0, matrix_size - 1)
+        if i != j:
+            theta = random.uniform(0, 2 * math.pi)
+            cs = math.cos(theta)
+            sn = math.sin(theta)
+            multiply(data, matrix_size, i, j, cs, sn, left=True)
+            multiply(data, matrix_size, i, j, cs, sn, left=False)
+    icoords, jcoords, values = [], [], []
+    for (i, j), v in sorted(data.items()):
+        icoords.append(i)
+        jcoords.append(j)
+        values.append(v)
+    indices_tensor = torch.tensor([icoords, jcoords])
+    return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device)
+
+# FIXME: remove this by updating test suites using it
+def do_test_dtypes(self, dtypes, layout, device):
+    for dtype in dtypes:
+        if dtype != torch.float16:
+            out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
+            self.assertIs(dtype, out.dtype)
+            self.assertIs(layout, out.layout)
+            self.assertEqual(device, out.device)
+
+# FIXME: remove this by updating test suites using it
+def do_test_empty_full(self, dtypes, layout, device):
+    shape = torch.Size([2, 3])
+
+    def check_value(tensor, dtype, layout, device, value, requires_grad):
+        self.assertEqual(shape, tensor.shape)
+        self.assertIs(dtype, tensor.dtype)
+        self.assertIs(layout, tensor.layout)
+        self.assertEqual(tensor.requires_grad, requires_grad)
+        if tensor.is_cuda and device is not None:
+            self.assertEqual(device, tensor.device)
+        if value is not None:
+            fill = tensor.new(shape).fill_(value)
+            self.assertEqual(tensor, fill)
+
+    def get_int64_dtype(dtype):
+        module = '.'.join(str(dtype).split('.')[1:-1])
+        if not module:
+            return torch.int64
+        return operator.attrgetter(module)(torch).int64
+
+    default_dtype = torch.get_default_dtype()
+    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
+    check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False)
+    for dtype in dtypes:
+        for rg in {dtype.is_floating_point, False}:
+            int64_dtype = get_int64_dtype(dtype)
+            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
+            check_value(v, dtype, layout, device, None, rg)
+            out = v.new()
+            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
+                        dtype, layout, device, None, rg)
+            check_value(v.new_empty(shape), dtype, layout, device, None, False)
+            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
+            check_value(torch.empty_like(v), dtype, layout, device, None, False)
+            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
+
+            if dtype is not torch.float16 and layout != torch.sparse_coo:
+                fv = 3
+                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
+                check_value(v, dtype, layout, device, fv, rg)
+                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
+                out = v.new()
+                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
+                            dtype, layout, device, fv + 2, rg)
+                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
+                            int64_dtype, layout, device, fv + 3, False)
+                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
+                check_value(torch.full_like(v, fv + 5,
+                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
+                            int64_dtype, layout, device, fv + 5, False)
+
+# FIXME: improve load_tests() documentation here
+running_script_path = None  # type: ignore[var-annotated]
+def set_running_script_path():
+    global running_script_path
+    try:
+        running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
+        if running_file.endswith('.py'):  # skip if the running file is not a script
+            running_script_path = running_file
+    except Exception:
+        pass
+
+def check_test_defined_in_running_script(test_case):
+    if running_script_path is None:
+        return
+    test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
+    assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \
+        f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \
+        "accidentally import a unittest.TestCase from another file?"
+
+def load_tests(loader, tests, pattern):
+    set_running_script_path()
+    test_suite = unittest.TestSuite()
+    for test_group in tests:
+        if not DISABLE_RUNNING_SCRIPT_CHK:
+            for test in test_group:
+                check_test_defined_in_running_script(test)
+        if test_group._tests:
+            test_suite.addTest(test_group)
+    return test_suite
+
+# FIXME: document this and move it to test_serialization
+class BytesIOContext(io.BytesIO):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+# Tentative value for nondet_tol for gradcheck when backward implementation
+# relies on nondeterministic operations, i.e., those listed here:
+# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+#
+# For more information see https://github.com/pytorch/pytorch/issues/56202
+GRADCHECK_NONDET_TOL = 1e-12
+
+TEST_WITH_SLOW_GRADCHECK: bool = TestEnvironment.def_flag(
+    "TEST_WITH_SLOW_GRADCHECK",
+    env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK",
+)
+
+skipIfSlowGradcheckEnv = unittest.skipIf(
+    TEST_WITH_SLOW_GRADCHECK,
+    "Tests that don't use gradcheck don't need to run on slow_gradcheck CI",
+)
+
+
+def gradcheck(fn, inputs, **kwargs):
+    # Wrapper around gradcheck that enables certain keys by default.
+    # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
+    # forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks
+    # to be disabled to default for the public-facing api to avoid breaking user code.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
+    default_values = {
+        "check_batched_grad": True,
+        "fast_mode": True,
+    }
+
+    if TEST_WITH_SLOW_GRADCHECK:
+        default_values["fast_mode"] = False
+
+    for key, value in default_values.items():
+        # default value override values explicitly set to None
+        k = kwargs.get(key)
+        kwargs[key] = k if k is not None else value
+
+    return torch.autograd.gradcheck(fn, inputs, **kwargs)
+
+def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
+    # Wrapper around gradgradcheck that enables certain keys by default
+    # See gradcheck above for an explanation of why we need something like this.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck
+    default_values = {
+        "check_batched_grad": True,
+        "fast_mode": True,
+    }
+
+    if TEST_WITH_SLOW_GRADCHECK:
+        default_values["fast_mode"] = False
+
+    for key, value in default_values.items():
+        # default value override values explicitly set to None
+        k = kwargs.get(key)
+        kwargs[key] = k if k is not None else value
+
+    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
+
+
+def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
+    # call assert function rather than returning a bool since it's nicer
+    # if we get whether this failed on the gradcheck or the gradgradcheck.
+    test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
+    test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))
+
+
+@contextmanager
+def set_cwd(path: str) -> Iterator[None]:
+    old_cwd = os.getcwd()
+    try:
+        os.chdir(path)
+        yield
+    finally:
+        os.chdir(old_cwd)
+
+
+# FIXME: delete this
+# Using @toleranceOverride specific to your test is the recommended way
+# of doing this. These are just some values that worked for test_nn.
+dtype2prec_DONTUSE = {torch.float: 1e-5,
+                      torch.double: 1e-5,
+                      torch.half: 1e-2,
+                      torch.bfloat16: 1e-1}
+
+# FIXME: move to test_sparse or sparse utils
+# This is a wrapper that wraps a test to run this test twice, one with
+# coalesced=True, another with coalesced=False for coalesced/uncoalesced sparse tensors.
+def coalescedonoff(f):
+    @wraps(f)
+    def wrapped(self, *args, **kwargs):
+        f(self, *args, **kwargs, coalesced=True)
+        f(self, *args, **kwargs, coalesced=False)
+    return wrapped
+
+
+def is_coalesced_indices(s):
+    indices = s._indices()
+    hash_coeffs = (1,) + s.shape[s.sparse_dim() - 1:0:-1]
+    hash_indices = torch.tensor(hash_coeffs, device=s.device).cumprod(-1).flip(-1)
+    if s.sparse_dim() > 1:
+        hash_indices.unsqueeze_(-1)
+        hash_indices = (indices * hash_indices).sum(0)
+    else:
+        hash_indices = indices * hash_indices
+
+    # check if indices are sorted
+    res = torch.allclose(hash_indices, hash_indices.sort()[0])
+
+    # check if there are no repeated indices
+    res = res and torch.allclose(hash_indices, hash_indices.unique())
+
+    return res
+
+
+@contextlib.contextmanager
+def disable_gc():
+    if gc.isenabled():
+        try:
+            gc.disable()
+            yield
+        finally:
+            gc.enable()
+    else:
+        yield
+
+
+def find_library_location(lib_name: str) -> Path:
+    # return the shared library file in the installed folder if exist,
+    # else the file in the build folder
+    torch_root = Path(torch.__file__).resolve().parent
+    path = torch_root / 'lib' / lib_name
+    if os.path.exists(path):
+        return path
+    torch_root = Path(__file__).resolve().parents[2]
+    return torch_root / 'build' / 'lib' / lib_name
+
+def skip_but_pass_in_sandcastle(reason):
+    """
+    Similar to unittest.skip, however in the sandcastle environment it just
+    "passes" the test instead to avoid creating tasks complaining about tests
+    skipping continuously.
+    """
+    def decorator(func):
+        if not IS_SANDCASTLE:
+            func.__unittest_skip__ = True
+            func.__unittest_skip_why__ = reason
+            return func
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
+            return
+        return wrapper
+
+    return decorator
+
+def mock_wrapper(method):
+    """
+    Returns a function that calls the real implementation of a method
+    in addition to passing args to a mock object.
+    """
+    mock = MagicMock()
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        mock(*args, **kwargs)
+        return method(self, *args, **kwargs)
+    wrapper.mock = mock  # type: ignore[attr-defined]
+    return wrapper
+
+def get_tensors_from(args, kwargs):
+    """ Returns a set of all Tensor objects in the given args and kwargs. """
+    return set([arg for arg in args if isinstance(arg, Tensor)] +
+               [v for v in kwargs.values() if isinstance(v, Tensor)])
+
+
+# Returns scalar tensor representation of a list of integer byte values
+def bytes_to_scalar(byte_list: list[int], dtype: torch.dtype, device: torch.device):
+    dtype_to_ctype: dict[torch.dtype, Any] = {
+        torch.int8: ctypes.c_int8,
+        torch.uint8: ctypes.c_uint8,
+        torch.uint16: ctypes.c_uint16,
+        torch.uint32: ctypes.c_uint32,
+        torch.uint64: ctypes.c_uint64,
+        torch.int16: ctypes.c_int16,
+        torch.int32: ctypes.c_int32,
+        torch.int64: ctypes.c_int64,
+        torch.bool: ctypes.c_bool,
+        torch.float32: ctypes.c_float,
+        torch.complex64: ctypes.c_float,
+        torch.float64: ctypes.c_double,
+        torch.complex128: ctypes.c_double,
+    }
+    ctype = dtype_to_ctype[dtype]
+    num_bytes = ctypes.sizeof(ctype)
+
+    def check_bytes(byte_list):
+        for byte in byte_list:
+            assert 0 <= byte <= 255
+
+    if dtype.is_complex:
+        assert len(byte_list) == (num_bytes * 2)
+        check_bytes(byte_list)
+        real = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list[:num_bytes])).value
+        imag = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list[num_bytes:])).value
+        res = real + 1j * imag
+    else:
+        assert len(byte_list) == num_bytes
+        check_bytes(byte_list)
+        res = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list)).value
+
+    return torch.tensor(res, device=device, dtype=dtype)
+
+
+def copy_func(f):
+    """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
+                           argdefs=f.__defaults__,
+                           closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__  # type: ignore[attr-defined]
+    return g
+
+
+def xfail_inherited_tests(tests):
+    """
+    Given a list of test names which are defined by a superclass of the
+    class this decorates, mark them as expected failure.  This is useful
+    if you are doing poor man's parameterized tests by subclassing a generic
+    test class.
+    """
+    def deco(cls):
+        for t in tests:
+            # NB: expectedFailure operates by mutating the method in question,
+            # which is why you have to copy the function first
+            setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t))))
+        return cls
+    return deco
+
+
+def skip_but_pass_in_sandcastle_if(condition, reason):
+    """
+    Similar to unittest.skipIf, however in the sandcastle environment it just
+    "passes" the test instead to avoid creating tasks complaining about tests
+    skipping continuously.
+    """
+    def decorator(func):
+        if condition:
+            if IS_SANDCASTLE:
+                @wraps(func)
+                def wrapper(*args, **kwargs):
+                    print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
+                return wrapper
+            else:
+                func.__unittest_skip__ = True
+                func.__unittest_skip_why__ = reason
+
+        return func
+
+    return decorator
+
+def dtype_name(dtype):
+    """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
+    return str(dtype).split('.')[1]
+
+
+@functools.lru_cache
+def get_cycles_per_ms() -> float:
+    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
+    """
+
+    def measure() -> float:
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        torch.cuda._sleep(1000000)
+        end.record()
+        end.synchronize()
+        cycles_per_ms = 1000000 / start.elapsed_time(end)
+        return cycles_per_ms
+
+    # Get 10 values and remove the 2 max and 2 min and return the avg.
+    # This is to avoid system disturbance that skew the results, e.g.
+    # the very first cuda call likely does a bunch of init, which takes
+    # much longer than subsequent calls.
+    #
+    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
+    # and seems to return stable values. Therefore, we enable caching
+    # using lru_cache decorator above.
+    num = 10
+    vals = [measure() for _ in range(num)]
+    vals = sorted(vals)
+    return mean(vals[2 : num - 2])
+
+
+# OpInfo utils
+
+T = TypeVar('T')
+def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
+    """
+    Returns the first sample from an iterable of samples, like those returned by OpInfo.
+    The test will be skipped if no samples are available.
+    """
+    try:
+        return next(iter(samples))
+    except StopIteration as e:
+        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e
+
+# this helper method is to recursively
+# clone the tensor-type input of operators tested by OpInfo
+def clone_input_helper(input):
+    if isinstance(input, torch.Tensor):
+        return torch.clone(input)
+
+    if isinstance(input, Sequence):
+        return tuple(map(clone_input_helper, input))
+
+    return input
+
+@contextmanager
+def custom_op(opname, symbolic_fn, opset_version):
+    """Context manager/decorator to test ONNX export with custom operator"""
+    try:
+        register_custom_op_symbolic(opname, symbolic_fn, opset_version)
+        yield
+    finally:
+        unregister_custom_op_symbolic(opname, opset_version)
+
+
+def outs_and_grads(fn, graph_inps, inps):
+    outs = fn(*graph_inps)
+    for out in pytree.tree_leaves(outs):
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            out.sum().backward(retain_graph=True)
+    grads = [inp.grad for inp in pytree.tree_leaves(inps) if isinstance(inp, torch.Tensor)]
+    for inp in pytree.tree_leaves(inps):
+        if isinstance(inp, torch.Tensor):
+            inp.grad = None
+    return outs, grads
+
+def compare_equal_outs_and_grads(test, m1, m2, inps):
+    r1, g1 = outs_and_grads(m1, inps, inps)
+    r2, g2 = outs_and_grads(m2, inps, inps)
+    test.assertEqual(r1, r2)
+    test.assertEqual(g1, g2)
+
+class TestGradients(TestCase):
+    exact_dtype = True
+
+    # Copies inputs to inplace operations to avoid inplace modifications
+    #   to leaves requiring gradient
+    def _get_safe_inplace(self, inplace_variant):
+        @wraps(inplace_variant)
+        def _fn(t, *args, **kwargs):
+            return inplace_variant(t.clone(), *args, **kwargs)
+
+        return _fn
+
+    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
+                      check_batched_grad=None, check_batched_forward_grad=False):
+        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
+        # NB: check_backward_ad does not affect gradgradcheck (always True)
+        if variant is None:
+            self.skipTest("Skipped! Variant not implemented.")
+        if not op.supports_dtype(dtype, torch.device(device).type):
+            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
+
+        def is_inplace(variant):
+            if hasattr(variant, "__wrapped__"):
+                return variant.__wrapped__ is op.get_inplace()
+            return variant is op.get_inplace()
+
+        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
+
+        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
+                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)
+
+        for sample in samples:
+            if sample.broadcasts_input and is_inplace(variant):
+                continue
+
+            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
+            #   and tensors passed as kwargs. The following creates a function that accepts just
+            #   the tensors that require grad as varargs, and then recomposes them back into the
+            #   original input.
+
+            # Creates gradcheck inputs by identifying tensors requiring grad
+            all_args = None
+            if is_iterable_of_tensors(sample.input):
+                all_args = chain(sample.input, sample.args, sample.kwargs.values())
+            else:
+                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))  # type: ignore[assignment]
+            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))  # type: ignore[union-attr]
+
+            # Verifies sample input tensors should have no grad
+            # This may happen if the same tensor is used in two different SampleInputs
+            for t in gradcheck_args:
+                self.assertIsNone(t.grad,
+                                  "A sampled input has a gradient before running autograd. "
+                                  "This usually means that (at least) one input tensor is reused "
+                                  "across different SampleInputs. "
+                                  "Please create a new tensor for each SampleInput.")
+
+            def _input_recomposition_helper(inputs, inp, input_idx):
+                if is_iterable_of_tensors(inp):
+                    tensor_list = []
+                    for x in inp:
+                        if isinstance(x, torch.Tensor) and x.requires_grad:
+                            tensor_list.append(inputs[input_idx])
+                            input_idx = input_idx + 1
+                        else:
+                            tensor_list.append(x)
+                    return tensor_list, input_idx
+                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
+                    return inputs[input_idx], input_idx + 1
+                else:
+                    return inp, input_idx
+
+            def fn(*inputs):
+                # Puts inputs back into sample properly
+                positional_args = []
+                input_idx = 0
+                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
+                positional_args.append(inp)
+
+                for x in sample.args:
+                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
+                    positional_args.append(inp)
+
+                # Recreates kwargs
+                kwargs = {}
+                for k, v in sample.kwargs.items():
+                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
+                    kwargs[k] = inp
+
+                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
+                if sample.output_process_fn_grad is not None:
+                    return sample.output_process_fn_grad(output)
+                return output
+
+            if check == 'gradcheck':
+                if check_batched_grad is None:
+                    check_batched_grad = op.check_batched_grad
+                self.assertTrue(gradcheck(fn, gradcheck_args,
+                                          check_batched_grad=check_batched_grad,
+                                          check_grad_dtypes=True,
+                                          nondet_tol=op.gradcheck_nondet_tol,
+                                          fast_mode=op.gradcheck_fast_mode,
+                                          check_forward_ad=check_forward_ad,
+                                          check_backward_ad=check_backward_ad,
+                                          check_undefined_grad=True,
+                                          check_batched_forward_grad=check_batched_forward_grad))
+            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
+                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
+                for gen_non_contig_grad_outputs in (False, True):
+                    kwargs = {
+                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
+                        "check_batched_grad": op.check_batched_gradgrad,
+                        "check_grad_dtypes": True,
+                        "nondet_tol": op.gradcheck_nondet_tol,
+                        "fast_mode": op.gradcheck_fast_mode
+                    }
+                    if check == "fwgrad_bwgrad":
+                        kwargs["check_fwd_over_rev"] = True
+                        kwargs["check_rev_over_rev"] = False
+                        kwargs["check_batched_grad"] = False
+                        kwargs["check_undefined_grad"] = False
+
+                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
+            else:
+                self.assertTrue(False, msg="Unknown check requested!")
+
+    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
+                          check_batched_grad=None, check_batched_forward_grad=False):
+        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
+                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
+                                  check_batched_forward_grad=check_batched_forward_grad)
+
+    def _skip_helper(self, op, device, dtype):
+        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
+            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
+        if not op.supports_autograd and not op.supports_forward_ad:
+            self.skipTest("Skipped! autograd not supported.")
+
+
+
+
+# Base TestCase for NT tests; used to define common helpers, etc.
+class NestedTensorTestCase(TestCase):
+    def assertEqualIgnoringNestedInts(self, a, b):
+        # unbinding NJTs allows us to compare them as essentially equal without
+        # caring about exact nested int comparison
+        def _unbind_njts(x):
+            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
+                return x.unbind()
+            else:
+                return x
+
+        self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
+
+    def assertEqualNoncontigAware(self, a, b):
+        # assertEqual() doesn't take into account lengths, so hack around this
+        # by comparing unbound components and shapes
+        self.assertEqualIgnoringNestedInts(a, b)
+
+        def _get_njt_shapes(x):
+            return (
+                x.shape
+                if isinstance(x, torch.Tensor) and x.is_nested
+                else None
+            )
+
+        a_shapes = pytree.tree_map(_get_njt_shapes, a)
+        b_shapes = pytree.tree_map(_get_njt_shapes, b)
+        self.assertEqual(a_shapes, b_shapes)
+
+    @contextlib.contextmanager
+    def branch_nested_state(self):
+        """Context manager to branch and restore the nested tensor state."""
+        nested_tensor_module = torch.nested._internal.nested_tensor
+        original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
+        original_tensor_id_counter = nested_tensor_module._tensor_id_counter
+        try:
+            yield
+        finally:
+            nested_tensor_module._tensor_id_counter = original_tensor_id_counter
+            nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
+
+
+def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
+    from torch._dynamo.trace_rules import _as_posix_path
+
+    if file is None:
+        file = inspect.stack()[1 + skip].filename  # skip one frame
+
+    file = _as_posix_path(file)
+    s = _as_posix_path(str(e))
+
+    # Remove everything that looks like stack frames in NOT this file
+    def repl_frame(m):
+        if m.group(1) != file:
+            return ""
+        # Don't accept top-level, even for this script, these will wobble
+        # depending on how the testing script was invoked
+        if m.group(2) == "":
+            return ""
+
+        return m.group(0)
+
+    s = re.sub(r'  File "([^"]+)", line \d+, in (.+)\n(    .+\n( +[~^]+ *\n)?)+', repl_frame, s)
+    s = re.sub(r"line \d+", "line N", s)
+    s = re.sub(r".py:\d+", ".py:N", s)
+    s = re.sub(r'https:/([a-zA-Z0-9_.-]+)', r'https://\1', s)
+    s = re.sub(file, _as_posix_path(os.path.basename(file)), s)
+    s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s)
+    # 3.10 CALL_FUNCTION bytecode compatibility for dynamo graph break messages
+    s = re.sub(
+        r"attempting to trace CALL_FUNCTION:.*$",
+        "attempting to trace CALL: a function call, e.g. f(x, y):",
+        s,
+        flags=re.MULTILINE,
+    )
+    if suppress_suffix:
+        s = re.sub(r"\n*Set TORCH_LOGS.+", "", s, flags=re.DOTALL)
+        s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
+        s = re.sub(r"\n*Set TORCHDYNAMO_VERBOSE=1.+", "", s, flags=re.DOTALL)
+    if suppress_prefix:
+        s = re.sub(r"Cannot export model.+\n\n", "", s)
+    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
+    return s
+
+
+@contextmanager
+def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
+    """Wrap around operations you want to ensure are not leaking tensor memory.
+
+    This code intentionally ignores other reference cycles, which can be benign and which we have plenty
+    of in pytorch code.  It focuses on any reference cycles that directly or indirectly result holding a Tensor alive,
+    since this is likely a more serious leak than typical python refcycles.
+
+    limit specifies how many tensors to dump debug graphs for (default=1)
+    """
+    def match_obj(obj):
+        return isinstance(obj, matched_type)
+
+    try:
+        gc.collect()
+        gc.set_debug(gc.DEBUG_SAVEALL)
+        garbage_objs = []  # type: ignore[var-annotated]
+
+        # run the user code, after cleaning any existing refcycles, and then check for new ones
+        # also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr
+        yield garbage_objs
+
+        gc.collect()
+        garbage_objs.extend(filter(match_obj, gc.garbage))
+        num_garbage_objs = len(garbage_objs)
+        if num_garbage_objs > 0:
+            warnings.warn(
+                f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?", stacklevel=2
+            )
+            try:
+                import objgraph  # type: ignore[import-not-found,import-untyped]
+                warnings.warn(
+                    f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png", stacklevel=2
+                )
+                for g in garbage_objs[:limit]:
+                    objgraph.show_backrefs([g], max_depth=10)
+            except ImportError:
+                warnings.warn("`pip install objgraph` to enable memory leak debugging", stacklevel=2)
+
+    finally:
+        gc.set_debug(0)
+
+
+def remove_cpp_extensions_build_root():
+    """
+    Removes the default root folder under which extensions are built.
+    """
+    default_build_root = cpp_extension.get_default_build_root()
+    if os.path.exists(default_build_root):
+        if IS_WINDOWS:
+            # rmtree returns permission error: [WinError 5] Access is denied
+            # on Windows, this is a workaround
+            subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
+        else:
+            shutil.rmtree(default_build_root, ignore_errors=True)
+
+
+def install_cpp_extension(extension_root):
+    # Wipe the build / install dirs if they exist
+    build_dir = os.path.join(extension_root, "build")
+    install_dir = os.path.join(extension_root, "install")
+    for d in (build_dir, install_dir):
+        if os.path.exists(d):
+            shutil.rmtree(d)
+
+    # Build the extension
+    cmd = [sys.executable, "-m", "pip", "install", extension_root, "-v", "--no-build-isolation", "--root", install_dir]
+    return_code = shell(cmd, cwd=extension_root, env=os.environ)
+    if return_code != 0:
+        raise RuntimeError(f"build failed for cpp extension at {extension_root}")
+
+    mod_install_dir = None
+    # install directory is the one that is named site-packages
+    for root, directories, _ in os.walk(install_dir):
+        for directory in directories:
+            if "-packages" in directory:
+                mod_install_dir = os.path.join(root, directory)
+
+    if mod_install_dir is None:
+        raise RuntimeError(f"installation failed for cpp extension at {extension_root}")
+
+    if mod_install_dir not in sys.path:
+        sys.path.insert(0, mod_install_dir)
+
+
+# Decorator to provide a helper to load inline extensions to a temp directory
+def scoped_load_inline(func):
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        def load_inline(*args, **kwargs):
+            if IS_WINDOWS:
+                # TODO(xmfan): even using TemporaryDirectoryName will result in permission error
+                return cpp_extension.load_inline(*args, **kwargs)
+
+            assert "build_directory" not in kwargs
+            with TemporaryDirectoryName() as temp_dir_name:
+                if kwargs.get("verbose", False):
+                    print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr)
+                kwargs["build_directory"] = temp_dir_name
+                return cpp_extension.load_inline(*args, **kwargs)
+
+        return func(*args, load_inline=load_inline, **kwargs)
+    return wrapper
+
+def recover_orig_fp32_precision(fn):
+    @contextlib.contextmanager
+    def recover():
+        old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision  # type: ignore[attr-defined]
+        old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision  # type: ignore[attr-defined]
+        old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision  # type: ignore[attr-defined]
+        old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision  # type: ignore[attr-defined]
+        old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision  # type: ignore[attr-defined]
+        old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
+        try:
+            yield
+        finally:
+            torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p  # type: ignore[attr-defined]
+            torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p  # type: ignore[attr-defined]
+            torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p  # type: ignore[attr-defined]
+            torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p  # type: ignore[attr-defined]
+            torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p  # type: ignore[attr-defined]
+            torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
+
+    return recover()(fn)
+
+def skipIfPythonVersionMismatch(predicate):
+    vi = sys.version_info
+
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if predicate(vi.major, vi.minor, vi.micro):
+                return fn(self, *args, **kwargs)
+            else:
+                raise unittest.SkipTest("Python version mismatch")
+        return wrap_fn
+    return dec_fn
+
+# Decorator to patch multiple test class members for the duration of the subtest
+def patch_test_members(updates: dict[str, Any]):
+    def decorator(test_func):
+        @wraps(test_func)
+        def wrapper(self, *args, **kwargs):
+            # Store the original values of the specified members
+            original_values = {member: getattr(self, member) for member in updates}
+
+            # Update the members before running the subtest
+            for member, value in updates.items():
+                setattr(self, member, value)
+
+            # Run the test function, allowing subtests to run
+            try:
+                return test_func(self, *args, **kwargs)
+            finally:
+                # Restore the original values of the specified members after the subtest finishes
+                for member, original_value in original_values.items():
+                    setattr(self, member, original_value)
+
+        return wrapper
+    return decorator
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad3ec2d6303ec73625aa55b3ccf394e65303121c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd61c567ffa7b378ccd90219fed27927f902e6b6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network1.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4643266fdf53d5874e9ee7f03c3d87d1295e6c0c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/data/__pycache__/network2.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..faa55161bfe564d1c541db32250df7667088ec15
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b8dd6198744e794659bd9a5b72d2f863ebe49f5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4170e168817518ec40f18efb41f18d17133c7a6c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85831cdb018a4b0ba48acdcec5d49bffecb010c6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4ad912fa563b018f5a74027f901c759e1e24ab5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccfd2d227a2e0416e6bebfeda8a513a0c8f9868c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f089c9fc60f5d6b34eed0564b530f40f17695de
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96afe0e443e2e57a051e0a55be1b745fad497cc7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acc7005c6b9e3d64d1ca50714839b0732d41b5a5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py
@@ -0,0 +1 @@
+# mypy: allow-untyped-defs
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bdb1218e034a4949fbc08202bf1b35b8c6d7b77
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcab081021b671b7bcd0bb637cd123353651e477
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..60c744ac1a84cfb9220221a583a4849b6039c353
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
@@ -0,0 +1,103 @@
+# mypy: allow-untyped-defs
+
+import sys
+from functools import partial, wraps
+
+import torch
+import torch.distributed as dist
+from torch.distributed import rpc
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    TEST_SKIPS,
+    tp_transports,
+)
+
+
+TEST_GPU_NUM = 4
+
+
+class ShardedTensorTestBase(MultiProcessTestCase):
+    @property
+    def world_size(self):
+        return TEST_GPU_NUM
+
+    def init_pg(self, backend="nccl"):
+        if backend not in ["nccl", "gloo", "mpi", "hccl"]:
+            raise RuntimeError(f"Backend {backend} not supported!")
+
+        dist.init_process_group(
+            backend=backend,
+            world_size=self.world_size,
+            rank=self.rank,
+            init_method=f"file://{self.file_name}",
+        )
+
+        # set device for nccl pg for collectives
+        if backend == "nccl":
+            torch.cuda.set_device(self.rank)
+
+    def init_rpc(self):
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            _transports=tp_transports()
+        )
+        rpc_backend_options.init_method = f"file://{self.file_name}"
+        for rank in range(self.world_size):
+            rpc_backend_options.set_device_map(
+                f"worker{rank}", {rank: self.rank, self.rank: rank}
+            )
+
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+    def init_comms(self, init_rpc=True, backend="nccl"):
+        if init_rpc:
+            self.init_rpc()
+        self.init_pg(backend=backend)
+
+    def destroy_comms(self, destroy_rpc=True):
+        # Wait for all ranks to reach here before starting shutdown.
+        dist.barrier()
+
+        if destroy_rpc:
+            rpc.shutdown()
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def assert_sharded_tensor_equal(self, st1, st2):
+        st1_local_shards = st1.local_shards()
+        st2_local_shards = st2.local_shards()
+        self.assertEqual(len(st1_local_shards), len(st2_local_shards))
+        for i, st1_local_shard in enumerate(st1_local_shards):
+            self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
+            self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)
+
+        self.assertEqual(st1.metadata(), st2.metadata())
+        self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
+        self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
+
+
+# wrapper to initialize comms (processgroup + rpc)
+def with_comms(func=None, init_rpc=True, backend="nccl"):
+    if func is None:
+        return partial(
+            with_comms,
+            init_rpc=init_rpc,
+            backend=backend,
+        )
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        self.init_comms(init_rpc=init_rpc, backend=backend)
+        func(self, *args, **kwargs)
+        self.destroy_comms(destroy_rpc=init_rpc)
+
+    return wrapper
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bf8f36c93707f326ab3c53c1414fdd2f5b98a6d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abfc4be00fcec1c294d651b07c8b4f70643c1485
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5724f2b5c12dfc94e1f27889ed0c861495879434
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83bc3a35102a051d42587352c2dcb7967510903
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py
@@ -0,0 +1,137 @@
+# mypy: allow-untyped-defs
+
+import builtins
+
+import torch
+from torch.distributed._shard.sharding_spec import (
+    ChunkShardingSpec,
+    EnumerableShardingSpec,
+    ShardMetadata,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+    get_chunked_dim_size,
+    get_split_size,
+)
+
+
+def generate_chunk_sharding_specs_for_test(sharding_dim):
+    return [
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        ),
+        # Test different ordering. (Case 1)
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+            ],
+        ),
+        # Test different ordering. (Case 2)
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:3/cuda:3",
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+            ],
+        ),
+    ]
+
+
+def generate_enumerable_sharding_specs_for_test():
+    return [
+        EnumerableShardingSpec(
+            [
+                ShardMetadata(
+                    shard_offsets=[0, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:0/cuda:0",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:1/cuda:1",
+                ),
+                ShardMetadata(
+                    shard_offsets=[0, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:2/cuda:2",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:3/cuda:3",
+                ),
+            ]
+        )
+    ]
+
+
+def generate_local_weight_sharding_params_for_test(
+    local_weight, sharded_dim, gpu_num, spec, rank
+):
+    """
+    Shard the local weight based the given spec, so we can compare against
+    the one from sharded tensor.
+
+    Args:
+        local_weight: weight matrix to be sharded.
+        sharded_dim: The dimension which we shard on.
+        gpu_num: number of ranks.
+        spec: sharding spec.
+        rank: # of cuda process.
+
+    Returns:
+        start_pos: start position of sharded weight on the given rank.
+        chunk_size: chunk size of sharded weight on the given rank.
+    """
+    sharding_dim_size = local_weight.size(sharded_dim)
+    split_size = get_split_size(sharding_dim_size, gpu_num)
+    current_offsets = 0
+    start_pos = current_offsets
+    for idx, placement in enumerate(spec.placements):
+        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+        if rank == placement.rank():
+            start_pos = current_offsets
+            break
+        current_offsets += chunk_size
+    return start_pos, chunk_size
+
+
+def clone_module_parameter(module, param_name):
+    """
+    Clone a parameter from a given existing module.
+
+    Args:
+        module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
+        param_name (str): Name of the parameter of ``module`` that needs to be cloned.
+
+    Returns: cloned tensor as :class:`torch.nn.Parameter`.
+    """
+    tensor = getattr(module, param_name)
+    return torch.nn.Parameter(tensor.detach().clone())
+
+
+def gen_binary_op_func(python_op, inplace=False):
+    src_lines = ["def f(lhs, rhs):"]
+    if "torch" in python_op:
+        src_lines.append(f"  return {python_op}(lhs, rhs)\n")
+    elif inplace:
+        src_lines.append(f"  lhs {python_op}= rhs\n  return lhs\n")
+    else:
+        src_lines.append(f"  return lhs {python_op} rhs\n")
+
+    code_str = "\n".join(src_lines)
+    g = {"torch": torch}
+    builtins.exec(code_str, g)
+    return g["f"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe82a8dc43f8f876cb4c8d0c000cda9a32d46fb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py
@@ -0,0 +1,56 @@
+# mypy: allow-untyped-defs
+
+import copy
+import random
+
+import torch
+from torch.distributed._shard import sharded_tensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+
+
+PLACEMENTS = [
+    "rank:0/cuda:0",
+    "rank:1/cuda:1",
+    "rank:2/cuda:2",
+    "rank:3/cuda:3",
+]
+
+DEFAULT_GPU_NUM = 4
+
+
+def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
+    spec_list = []
+    for i in range(len(sharding_dims)):
+        random.Random(seed + i).shuffle(PLACEMENTS)
+        spec_list.append(
+            ChunkShardingSpec(
+                dim=sharding_dims[i],
+                placements=copy.deepcopy(PLACEMENTS),
+            )
+        )
+    return spec_list
+
+
+class MyShardedModel2(torch.nn.Module):
+    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
+        super().__init__()
+        if spec is not None:
+            self.sharded_tensor2 = sharded_tensor.rand(
+                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
+            )
+        else:
+            self.sharded_tensor2 = None
+        self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))
+
+
+class MyShardedModel1(torch.nn.Module):
+    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
+        super().__init__()
+        if spec is not None:
+            self.sharded_tensor1 = sharded_tensor.rand(
+                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
+            )
+        else:
+            self.sharded_tensor1 = None
+        self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2))
+        self.submodule = MyShardedModel2(spec, group, init_rrefs)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9390da489851872ec1d0715a0b3e46275e5752b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py
@@ -0,0 +1,41 @@
+# mypy: allow-untyped-defs
+
+import torch
+import torch.nn as nn
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+
+class SimpleMegatronLM(nn.Module):
+    def __init__(self, linear_size, rank=None, dtype=torch.float32):
+        super().__init__()
+        self.fc1 = nn.Linear(*linear_size[0], dtype=dtype)
+        self.gelu = nn.GELU()
+        self.fc2 = nn.Linear(*linear_size[1], dtype=dtype)
+        if rank is not None:
+            self.fc1.cuda(rank)
+            self.fc2.cuda(rank)
+
+    def forward(self, inp):
+        return self.fc2(self.gelu(self.fc1(inp)))
+
+    def get_weights(self):
+        if isinstance(self.fc1.weight, ShardedTensor):
+            weight1 = self.fc1.weight.local_tensor()
+        else:
+            weight1 = self.fc1.weight
+
+        if isinstance(self.fc2.weight, ShardedTensor):
+            weight2 = self.fc2.weight.local_tensor()
+        else:
+            weight2 = self.fc2.weight
+
+        return (weight1, weight2)
+
+    def get_biases(self):
+        return (self.fc1.bias, self.fc2.bias)
+
+    def get_weight_grads(self):
+        return (self.fc1.weight.grad, self.fc2.weight.grad)
+
+    def get_bias_grads(self):
+        return (self.fc1.bias.grad, self.fc2.bias.grad)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e62993a3d2b32fdff6261915013e6fd76468f2b8
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed25f72fcb421379ecd1e4f1179dc2ec1b18d632
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c749ca2d541659cb0b9ef67242b48aa235831cb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -0,0 +1,1019 @@
+# mypy: allow-untyped-defs
+
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import contextlib
+import copy
+import functools
+import itertools
+import sys
+import types
+from collections.abc import Callable, Iterator, Sequence
+from dataclasses import dataclass
+from functools import partial, wraps
+from typing import Any, cast, Optional, TypeVar, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed._local_tensor import (
+    LocalIntNode,
+    LocalTensor,
+    LocalTensorMode,
+    maybe_disable_local_tensor_mode,
+    maybe_run_for_local_tensor,
+)
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
+from torch.distributed.tensor._redistribute import redistribute_local_tensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.testing._internal.common_distributed import (
+    ACCELERATOR_DIST_BACKENDS,
+    MultiProcContinuousTest,
+    MultiProcessTestCase,
+    MultiThreadedTestCase,
+    run_subtests,
+    skip_if_lt_x_gpu,
+    TEST_SKIPS,
+)
+from torch.testing._internal.common_utils import (
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_PRIVATEUSE1,
+    TEST_XPU,
+)
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+
+
+DEVICE_COUNT: int
+
+if TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1:
+    DEVICE_TYPE = torch.accelerator.current_accelerator().type
+    DEVICE_COUNT = torch.accelerator.device_count()
+    PG_BACKEND = dist.Backend.default_device_backend_map[DEVICE_TYPE]
+else:
+    DEVICE_TYPE = "cpu"
+    PG_BACKEND = "gloo"
+
+NUM_DEVICES = 4
+
+# We use this as a proxy for "multiple GPUs exist"
+if (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1) and DEVICE_COUNT > 1:
+    # when we actually have multiple GPUs, relax the requirement to smaller counts.
+    NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)
+
+T = TypeVar("T")
+
+
+# simple RMSNorm layer for testing
+class RMSNormPython(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x)
+        return output * self.weight
+
+
+class MLPModule(nn.Module):
+    def __init__(self, device, bias: bool = True):
+        super().__init__()
+        torch.manual_seed(5)
+        self.net1 = nn.Linear(10, 16, bias=bias, device=device)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(16, 10, bias=bias, device=device)
+
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))
+
+    def reset_parameters(self):
+        self.net1.reset_parameters()
+        self.net2.reset_parameters()
+
+
+class MLPStacked(nn.Module):
+    def __init__(self, device, n_layers: int = 2):
+        super().__init__()
+        self.layers = nn.ModuleList([MLPModule(device) for i in range(n_layers)])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+@dataclass
+class ModelArgs:
+    n_layers: int = 2
+    vocab_size: int = 8
+    max_seq_len: int = 16
+    dim: int = 16
+    n_heads: int = 4
+    dropout_p: float = 0.1
+    use_attn_mask: bool = True
+    weight_tying: bool = True
+    checkpoint_activations: bool = False
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.dim % args.n_heads == 0
+        self.head_dim = args.dim // args.n_heads
+        self.n_heads = args.n_heads
+        self.dropout_p = args.dropout_p
+        self.resid_dropout = nn.Dropout(args.dropout_p)
+        self.use_attn_mask = args.use_attn_mask
+
+        self.wq = nn.Linear(args.dim, args.dim, bias=False)
+        self.wk = nn.Linear(args.dim, args.dim, bias=False)
+        self.wv = nn.Linear(args.dim, args.dim, bias=False)
+        self.wo = nn.Linear(args.dim, args.dim, bias=False)
+
+    def forward(self, x):
+        bsz, seq_len, _ = x.size()
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
+        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
+        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
+
+        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+
+        output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            None,
+            self.dropout_p if self.training else 0,
+            self.use_attn_mask,
+        )
+        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return self.resid_dropout(self.wo(output))
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout_p):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim)
+        self.gelu = nn.GELU()
+        self.w2 = nn.Linear(hidden_dim, dim)
+        self.resid_dropout = nn.Dropout(dropout_p)
+
+    def forward(self, x):
+        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention_norm = nn.LayerNorm(args.dim)
+        self.attention = Attention(args)
+        self.ffn_norm = nn.LayerNorm(args.dim)
+        self.feed_forward = FeedForward(
+            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
+        )
+
+    def forward(self, x):
+        h = x + self.attention(self.attention_norm(x))
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+# A toy transformer model, partly inspired by the nanoGPT model:
+# https://github.com/karpathy/nanoGPT.
+class Transformer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.vocab_size is not None
+        assert args.max_seq_len is not None
+        self.model_args = args
+        self.max_seq_len = args.max_seq_len
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
+        self.dropout = nn.Dropout(args.dropout_p)
+        self.layers = nn.ModuleList()
+        for _ in range(args.n_layers):
+            self.layers.append(TransformerBlock(args))
+        self.norm = nn.LayerNorm(args.dim)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+        if args.weight_tying:
+            self.output.weight = self.tok_embeddings.weight
+        self.checkpoint_activations = args.checkpoint_activations
+
+    def forward(self, tokens):
+        _bsz, seq_len = tokens.size()
+        assert seq_len <= self.max_seq_len
+        h = self.tok_embeddings(tokens)
+        pos = torch.arange(0, seq_len, device=tokens.device)
+        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
+        h = h + p
+        h = self.dropout(h)
+        for layer in self.layers:
+            if self.checkpoint_activations:
+                h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
+            else:
+                h = layer(h)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
+
+    @staticmethod
+    def parallelize(
+        module: "Transformer",
+        device_mesh: DeviceMesh,
+        use_seq_parallel: bool,
+        local_output_for_attn: bool = False,
+    ) -> nn.Module:
+        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+        # Parallelize the root submodules.
+        if use_seq_parallel:
+            root_plan = {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(1)
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(0)
+                ),
+                "norm": SequenceParallel(),
+            }
+        else:
+            root_plan = {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Replicate()
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Replicate()
+                ),
+            }
+
+        module_tp = parallelize_module(module, device_mesh, root_plan)
+        # Parallelize the attention and feed forward submodules.
+        for layer in module_tp.layers:
+            layer_parallelize_plan = {}
+            if use_seq_parallel:
+                layer_parallelize_plan["attention"] = PrepareModuleInput(
+                    input_layouts=Shard(1),
+                    desired_input_layouts=Replicate(),
+                )
+                # shard the RMSNorms
+                layer_parallelize_plan["attention_norm"] = SequenceParallel()
+                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
+            layer_parallelize_plan["attention.wq"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wk"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wv"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wo"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            layer_parallelize_plan["feed_forward.w1"] = (
+                ColwiseParallel(input_layouts=Shard(1))
+                if use_seq_parallel
+                else ColwiseParallel()
+            )
+            layer_parallelize_plan["feed_forward.w2"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            parallelize_module(layer, device_mesh, layer_parallelize_plan)
+
+        # Parallelize the output submodule. If weight tying is enabled, we need to
+        # make sure output.weight is sharded consistently as tok_embeddings.weight,
+        # at the cost of the all_reduce operation using RowwiseParallel.
+        output_parallelize_plan = (
+            ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+            )
+            if use_seq_parallel
+            else ColwiseParallel(output_layouts=Replicate())
+        )
+        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+        if local_output_for_attn:
+            for layer in module_tp.layers:
+                layer.attention.n_heads = (
+                    module_tp.model_args.n_heads // device_mesh.size()
+                )
+
+        # Manually set output.weight so that parameters and gradients are shared.
+        if module_tp.model_args.weight_tying:
+            module_tp.output.weight = module_tp.tok_embeddings.weight
+
+        return module_tp
+
+
+def skip_unless_torch_gpu(method: T) -> T:
+    """
+    Test decorator which skips the test unless there's a GPU available to torch.
+
+    >>> # xdoctest: +SKIP
+    >>> @skip_unless_torch_gpu
+    >>> def test_some_method(self) -> None:
+    >>>   ...
+    """
+    # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
+    return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))
+
+
+class DTensorContinuousTestBase(MultiProcContinuousTest):
+    @classmethod
+    def device_type(cls) -> str:
+        # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < cls.world_size
+        ):
+            return "cpu"
+        else:
+            return DEVICE_TYPE
+
+    @classmethod
+    def backend_str(cls) -> str:
+        backend = dist.get_default_backend_for_device(DEVICE_TYPE)
+        return backend
+
+
+class DTensorTestBase(MultiProcessTestCase):
+    @property
+    def is_local_tensor_enabled(self) -> bool:
+        return False
+
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    @property
+    def device_type(self) -> str:
+        # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
+        if (
+            not (TEST_CUDA or TEST_XPU or TEST_HPU or TEST_PRIVATEUSE1)
+            or DEVICE_COUNT < self.world_size
+        ):
+            return "cpu"
+        else:
+            return DEVICE_TYPE
+
+    @property
+    def backend(self) -> str:
+        backend = dist.get_default_backend_for_device(self.device_type)
+        return backend
+
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(self.rank)
+
+    def build_device_mesh(self) -> DeviceMesh:
+        return init_device_mesh(self.device_type, (self.world_size,))
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        if backend is None:
+            backend = self.backend
+
+        requires_gpu = any(
+            gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS
+        )
+        if requires_gpu and torch.accelerator.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        curr_backend = dist.get_default_backend_for_device(self.device_type)
+
+        if backend not in [
+            "nccl",
+            "gloo",
+            "mpi",
+            f"cpu:gloo,{self.device_type}:{curr_backend}",
+            "hccl",
+            "xccl",
+            "fake",
+            "cpu:gloo,xpu:xccl",
+        ]:
+            raise RuntimeError(f"Backend {backend} not supported!")
+
+        device_id = None
+        if "nccl" in backend or "xccl" in backend:
+            # set device for nccl pg for collectives
+            # TODO: if users want to enable testing across hosts, we may need
+            # to change this part.
+            torch.accelerator.set_device_index(self.rank)
+            # we only need to set device_id for nccl backend with eager init
+            device_id = (
+                torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
+            )
+
+        # For nccl backend, bind the device to the process if device_id is not None
+        # so the nccl communicator is immediately formed and we can use `ncclCommSplit`
+        # for form subgroup to avoid unnecessary overhead.
+        dist.init_process_group(
+            backend=backend,
+            world_size=self.world_size,
+            rank=self.rank,  # pyre-ignore[16]
+            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
+            device_id=device_id,
+        )
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        # Wait for all ranks to reach here before starting shutdown.
+        # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
+        # dist.all_reduce(torch.zeros((1,), device="cuda" if TEST_CUDA else "cpu"))
+        # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
+        #  test_dtensor.py  -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
+        if device_id is None:
+            device_id = (
+                torch.cuda.current_device() if self.device_type == "cuda" else self.rank
+            )
+
+        if self.device_type == "cpu":
+            # NOTE: when `device_id` is not None, barrier() will choose the accelerator
+            # of the most pripority, which means if the test specifies to use CPU for
+            # testing while CUDA is available on the host, the barrier() will use CUDA.
+            # To avoid this and better respect `self.device_type`, we add this branch to
+            # enforce barrier() to use CPU when `self.device_type` is CPU and other
+            # accelerator is also available.
+            dist.barrier()
+        else:
+            dist.barrier(device_ids=[device_id])
+
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None:
+        """
+        This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``.
+        Unlike _test_op where the DTensor sharding is generated by DTensorConverter,
+        this function takes in DTensor object directly as argument and test the equality
+        of calling op on full_tensor() and DTensor.
+        """
+        # call full_tensor() on DTensor args/kwargs
+        args_flattened, args_spec = tree_flatten(args)
+        full_tensor_args_flattened = tuple(
+            arg.full_tensor().detach().clone() if isinstance(arg, DTensor) else arg
+            for arg in args_flattened
+        )
+        full_tensor_args = tree_unflatten(full_tensor_args_flattened, args_spec)
+        full_tensor_kwargs = {
+            k: v.full_tensor() if isinstance(v, DTensor) else v
+            for k, v in kwargs.items()
+        }
+
+        out_flattened, _ = tree_flatten(
+            op_call(*full_tensor_args, **full_tensor_kwargs)
+        )
+        d_out_flattened, _ = tree_flatten(op_call(*args, **kwargs))
+        d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened]
+        self.assertEqual(out_flattened, d_out_full_tensor_flattened)
+
+    # pyre-ignore[2]:
+    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
+        out = op_call(*args, **kwargs)
+        dtc = DTensorConverter(mesh, args, kwargs)
+        for d_args, d_kwargs in dtc:
+            # pyre can't find assertTrue anymore?
+            self.assertEqual(dtc.successful(), True)
+            d_out = op_call(*d_args, **d_kwargs)
+            self.assertEqual(d_out.full_tensor(), out)
+
+    def run_subtests(self, *args, **kwargs):
+        return run_subtests(self, *args, **kwargs)
+
+
+TestFunc = Callable[[...], object]
+
+
+# wrapper to initialize comms (processgroup)
+def with_comms(
+    eager_init: Union[TestFunc, bool] = False, backend: Optional[str] = None
+) -> TestFunc:
+    def decorator(func, eager_init: bool = False, backend: Optional[str] = None):
+        @wraps(func)  # pyre-ignore[6]
+        def wrapper(
+            self,
+            *args: tuple[object],
+            **kwargs: dict[str, Any],  # type: ignore[misc]
+        ) -> None:
+            # just passthrough if harness doesn't
+            # support init_pg e.g., DTensorOpTestBase
+            if not hasattr(self, "init_pg"):
+                func(self, *args, **kwargs)
+                return
+
+            self.init_pg(eager_init, backend)
+
+            try:
+                func(self, *args, **kwargs)  # type: ignore[misc]
+            except Exception as e:
+                dist.destroy_process_group()
+                raise e
+
+            self.destroy_pg()
+
+        return wrapper
+
+    return (
+        decorator(func=eager_init)
+        if callable(eager_init)
+        else partial(decorator, eager_init=eager_init, backend=backend)
+    )
+
+
+class DTensorOpTestBase(MultiThreadedTestCase):
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    @property
+    def device_type(self) -> str:
+        return DEVICE_TYPE
+
+    def build_device_mesh(self):
+        return init_device_mesh(self.device_type, (self.world_size,))
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_threads()
+
+
+# This is a class for converting args/kwargs of an op into distributed args/kwargs
+class DTensorConverter:
+    def __init__(
+        self,
+        mesh: DeviceMesh,
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> None:
+        self.hit = 0
+        self.miss = 0
+        self.mesh = mesh
+        self.args = args
+        self.kwargs = kwargs
+        flatten_args, flatten_args_spec = tree_flatten(args)
+        flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
+
+        self.flatten_args: list[object] = flatten_args
+        self.flatten_args_spec: TreeSpec = flatten_args_spec
+        self.flatten_kwargs: list[object] = flatten_kwargs
+        self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
+
+        choices_for_args = [
+            self.gen_sharding_choices_for_arg(arg)
+            for arg in self.flatten_args
+            if isinstance(arg, torch.Tensor)
+        ]
+
+        choices_for_args.extend(
+            self.gen_sharding_choices_for_arg(arg)
+            for arg in self.flatten_kwargs
+            if isinstance(arg, torch.Tensor)
+        )
+
+        self.sharding_combs: Iterator[Sequence[Placement]] = iter(
+            itertools.product(*choices_for_args)
+        )
+
+    def successful(self) -> bool:
+        return self.hit > 0 and self.miss == 0
+
+    def is_supported_tensor(self, t: torch.Tensor) -> bool:
+        # TODO: dist tensor need to support quantized and sparse
+        # tensors, quantized tensor might be relatively easy, but
+        # sparse tensor have special layouts that we need to possibly
+        # deal with, until we are clear about them, we don't officially
+        # support them.
+        return not any(
+            [
+                t.is_sparse_csr,
+                t.is_sparse,
+                t.is_mkldnn,
+                t.is_quantized,
+                t.is_nested,
+                torch._is_functional_tensor(t),
+                t.is_neg(),
+                t.is_conj(),
+                t.device.type in ("lazy", "meta"),
+                # We need a way to test if a tensor is batched but there
+                # is no official APi to do it
+                # torch._C._is_batched(t),
+            ]
+        )
+
+    def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
+        mesh_size = self.mesh.size()
+        sharding_choices: list[Placement] = [Replicate()]
+        # c10d collective does not support bool tensor
+        # for bool tensor we treat it as replicated
+        if arg.dtype != torch.bool:
+            # only generating choices with: replicate, or sharding
+            # evenly on a dimension that could be sharded
+            sharding_choices = sharding_choices + [
+                Shard(i)
+                for i, s in enumerate(arg.shape)
+                if s > 1 and s % mesh_size == 0
+            ]
+        # TODO: add multi mesh choices
+        # all_choices = itertools.product(
+        #     *(self.mesh.ndim * [sharding_choices])
+        # )
+        return sharding_choices
+
+    def __iter__(self) -> "DTensorConverter":
+        return self
+
+    def __next__(self) -> tuple[tuple[object, ...], dict[str, object]]:
+        try:
+            next_sharding_choices = next(self.sharding_combs)
+            idx = 0
+
+            new_args: list[object] = []
+            for arg in self.flatten_args:
+                if isinstance(arg, torch.Tensor):
+                    new_args.append(
+                        self.to_dist_tensor(
+                            arg, self.mesh, [next_sharding_choices[idx]]
+                        )
+                    )
+                    idx += 1
+                else:
+                    new_args.append(arg)
+
+            new_kwargs: list[object] = []
+            for arg in self.flatten_kwargs:
+                if isinstance(arg, torch.Tensor):
+                    new_kwargs.append(
+                        self.to_dist_tensor(
+                            arg, self.mesh, [next_sharding_choices[idx]]
+                        )
+                    )
+                    idx += 1
+                else:
+                    new_kwargs.append(arg)
+
+            return (
+                tree_unflatten(new_args, self.flatten_args_spec),
+                tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
+            )
+        except StopIteration as e:
+            raise StopIteration from e
+
+    def to_dist_tensor(
+        self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
+    ) -> torch.Tensor:
+        if type(t) is torch.Tensor or type(t) is nn.Parameter or type(t) is LocalTensor:
+            if self.is_supported_tensor(t):
+                self.hit += 1
+                if t.ndim == 0:
+                    # scalar tensor by default will be replicated
+                    r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim)
+                else:
+                    # distribute non-scalar tensors
+                    r = distribute_tensor(t, mesh, placements)
+                if isinstance(t, nn.Parameter):
+                    r = nn.Parameter(  # type: ignore[assignment]
+                        r, requires_grad=r.requires_grad
+                    )
+                return r
+            else:
+                self.miss += 1
+                return t
+        elif torch.overrides.is_tensor_like(t):
+            # Blindly converting tensor subclasses to dist tensor can cause
+            # unpredictable problems, we explicitly disable this conversion
+            # for now (i.e. we don't support DTensor holding tensor subclass
+            # until there's a strong reason later).
+            self.miss += 1
+            return t
+        else:
+            raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")
+
+
+class LocalDTensorOpTestBase(DTensorOpTestBase):
+    @property
+    def is_local_tensor_enabled(self) -> bool:
+        return True
+
+    def _handle_test_skip(self, msg: str) -> None:
+        self.skipTest(msg)
+
+    def _get_local_tensor_mode(self):
+        return LocalTensorMode(frozenset(range(self.world_size)))
+
+    def setUp(self) -> None:
+        super().setUp()
+        torch.autograd._enable_record_function(False)
+
+    def tearDown(self) -> None:
+        from torch.distributed.tensor import _random as random
+
+        random._rng_tracker = None
+        super().tearDown()
+        torch.autograd._enable_record_function(True)
+
+    @property
+    def rank(self):
+        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
+
+    @rank.setter
+    def rank(self, rank):
+        pass
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            fn()
+
+        return types.MethodType(wrapper, self)
+
+    def build_device_mesh(self) -> DeviceMesh:
+        with maybe_disable_local_tensor_mode():
+            return super().build_device_mesh()
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        dist.init_process_group("fake", rank=0, world_size=self.world_size)
+        self._pg = dist.distributed_c10d._get_default_group()
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        dist.destroy_process_group(self._pg)
+        self._pg = None
+
+    def _spawn_processes(self) -> None:
+        pass
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        getattr(self, test_name)()
+
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(0)
+
+
+class LocalDTensorTestBase(DTensorTestBase):
+    @property
+    def is_local_tensor_enabled(self) -> bool:
+        return True
+
+    def _handle_test_skip(self, msg: str) -> None:
+        self.skipTest(msg)
+
+    def _get_local_tensor_mode(self):
+        return LocalTensorMode(frozenset(range(self.world_size)))
+
+    def setUp(self) -> None:
+        super().setUp()
+        torch.autograd._enable_record_function(False)
+
+    def tearDown(self) -> None:
+        from torch.distributed.tensor import _random as random
+
+        random._rng_tracker = None
+        super().tearDown()
+        torch.autograd._enable_record_function(True)
+
+    @property
+    def rank(self):
+        return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
+
+    @rank.setter
+    def rank(self, rank):
+        pass
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            fn()
+
+        return types.MethodType(wrapper, self)
+
+    def build_device_mesh(self) -> DeviceMesh:
+        with maybe_disable_local_tensor_mode():
+            return super().build_device_mesh()
+
+    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
+        dist.init_process_group("fake", rank=0, world_size=self.world_size)
+        self._pg = dist.distributed_c10d._get_default_group()
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        dist.destroy_process_group(self._pg)
+        self._pg = None
+
+    def _spawn_processes(self) -> None:
+        pass
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        getattr(self, test_name)()
+
+    def init_manual_seed_for_rank(self) -> None:
+        torch.manual_seed(0)
+
+
+def make_wrapped(fn, ctxs):
+    @functools.wraps(fn)
+    def wrapped(self):
+        torch._dynamo.reset()
+        stack = contextlib.ExitStack()
+        for ctx in ctxs:
+            if callable(ctx):
+                stack.enter_context(ctx(self))
+            else:
+                stack.enter_context(ctx)
+        try:
+            out = fn(self)
+        finally:
+            stack.close()
+        return out
+
+    return wrapped
+
+
+def create_local_tensor_test_class(
+    orig_cls, skipped_tests=None, base_class=LocalDTensorTestBase
+):
+    if skipped_tests is None:
+        skipped_tests = []
+
+    dct = orig_cls.__dict__.copy()
+    for name in list(dct.keys()):
+        fn = dct[name]
+        if not callable(fn):
+            continue
+        elif name in skipped_tests:
+            dct[name] = lambda self: self.skipTest("Skipped test")
+        elif name.startswith("test_"):
+            ctxs = [
+                lambda test: test._get_local_tensor_mode(),
+            ]
+            dct[name] = make_wrapped(fn, ctxs)
+
+    cls = type(
+        orig_cls.__name__ + "WithLocalTensor",
+        (base_class,) + orig_cls.__bases__,
+        dct,
+    )
+    cls.__file__ = __file__
+    return cls
+
+
+@maybe_run_for_local_tensor
+def map_local_tensor_for_rank(tensor, rank, func):
+    return func(tensor, rank)
+
+
+@maybe_run_for_local_tensor
+def map_local_for_rank(rank, func):
+    return func(rank)
+
+
+def reduce_local_int(val, func):
+    return func(val.node._local_ints)
+
+
+def _convert_shard_order_dict_to_ShardOrder(shard_order):
+    """Convert shard_order dict to ShardOrder"""
+    return tuple(
+        ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
+        for tensor_dim, mesh_dims in shard_order.items()
+    )
+
+
+# TODO(zpcore): remove once the native redistribute supports shard_order arg
+def redistribute(
+    dtensor_input,
+    device_mesh,
+    placements,
+    shard_order,
+    use_graph_based_transform=True,
+):
+    """
+    wrapper function to support shard_order for redistribution
+    This is a simpler version of Redistribute, only considers the forward.
+    """
+    if placements is None:
+        placements = shard_order_to_placement(shard_order, device_mesh)
+    placements = tuple(placements)
+    old_spec = dtensor_input._spec
+    new_spec = copy.deepcopy(old_spec)
+    new_spec.placements = placements
+    if shard_order is not None:
+        new_spec.shard_order = shard_order
+    else:
+        new_spec.shard_order = ()
+    if old_spec == new_spec:
+        return dtensor_input
+    dtensor_input = DTensor.from_local(
+        redistribute_local_tensor(
+            dtensor_input.to_local(),
+            old_spec,
+            new_spec,
+            use_graph_based_transform=use_graph_based_transform,
+        ),
+        device_mesh,
+    )
+    dtensor_input._spec = copy.deepcopy(new_spec)
+    return dtensor_input  # returns DTensor
+
+
+# TODO(zpcore): remove once the native distribute_tensor supports
+# shard_order arg
+def patched_distribute_tensor(
+    input_tensor,
+    device_mesh,
+    placements,
+    shard_order,
+    use_graph_based_transform=True,
+):
+    """wrapper function to support shard_order for tensor distribution"""
+    if placements is None:
+        placements = shard_order_to_placement(shard_order, device_mesh)
+    placements = tuple(placements)
+    tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
+    # fix the shard order
+    return redistribute(
+        tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
+    )
+
+
+# TODO(zpcore): remove once the native redistribute supports shard_order arg
+def make_full_tensor(dtensor_input):
+    """wrapper function to support DTensor.full_tensor"""
+    return redistribute(
+        dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
+    ).to_local()
+
+
+def shard_order_to_placement(shard_order, mesh):
+    """convert shard_order to placement with only Replicate() and Shard()"""
+    placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
+    if shard_order is not None:
+        for entry in shard_order:
+            tensor_dim = entry.tensor_dim
+            mesh_dims = entry.mesh_dims
+            for mesh_dim in mesh_dims:
+                placements[mesh_dim] = Shard(tensor_dim)
+    return tuple(placements)
+
+
+def generate_shard_orders(mesh, tensor_rank):
+    # Generate all possible sharding placement of tensor with rank
+    # `tensor_rank` over mesh.
+    def _split_list(lst: list, N: int):
+        def compositions(n: int, k: int):
+            # yields lists of length k, positive ints summing to n
+            for cuts in itertools.combinations(range(1, n), k - 1):
+                # add 0 and n as sentinels, then take consecutive differences
+                yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]
+
+        length = len(lst)
+        for comp in compositions(length, N):
+            result = []
+            start = 0
+            for size in comp:
+                result.append(lst[start : start + size])
+                start += size
+            yield result
+
+    all_mesh = list(range(mesh.ndim))
+    all_device_order = list(itertools.permutations(all_mesh))
+    for device_order in all_device_order:
+        # split on device orders, and assign each device order segment to a tensor dim
+        for num_split in range(1, mesh.ndim + 1):
+            for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
+                for tensor_dims in itertools.combinations(
+                    range(tensor_rank), len(splitted_list)
+                ):
+                    shard_order = {}
+                    assert len(tensor_dims) == len(splitted_list)
+                    for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
+                        shard_order[tensor_dim] = device_order[
+                            mesh_dims[0] : mesh_dims[-1] + 1
+                        ]
+                    yield _convert_shard_order_dict_to_ShardOrder(shard_order)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a57ca2639916b24d2aa6fc2fed5a7051aa3d91
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py
@@ -0,0 +1,194 @@
+# mypy: allow-untyped-defs
+
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import io
+import logging
+import os
+import shutil
+import tempfile
+from collections.abc import Callable
+from functools import wraps
+from typing import Any, cast, IO, Optional
+
+# introduced as collections.abc.Buffer in Python 3.12
+from typing_extensions import Buffer
+
+import torch.distributed as dist
+from torch.distributed.checkpoint._extension import (
+    ExtensionRegistry,
+    StreamTransformExtension,
+)
+
+
+class Rot13Example(StreamTransformExtension):
+    """
+    This is an example stream transform extension which just does rot13 on each
+    alphanumeric character of the stream.  It is mainly intended as a demonstration
+    and for testing; there isn't a production use case for this.
+    """
+
+    def __init__(self, chunk_size: int = io.DEFAULT_BUFFER_SIZE) -> None:
+        super().__init__()
+        self._chunk_size = chunk_size
+
+    @staticmethod
+    def from_descriptor(version: str) -> "Rot13Example":
+        if version.partition(".")[0] != "1":
+            raise ValueError(f"Unknown extension {version=}")
+        return Rot13Example()
+
+    @staticmethod
+    def registry_name() -> str:
+        return "stream.rot13"
+
+    def get_descriptor(self) -> str:
+        return f"{self.registry_name()}/1"
+
+    @staticmethod
+    def _rot13bytes(b: Buffer, count: int) -> None:
+        b = memoryview(b)
+        for i in range(count):
+            ch = b[i]
+            if ch >= ord("A") and ch <= ord("Z"):
+                ch += ord("a") - ord("A")
+            elif ch >= ord("a") and ch <= ord("z"):
+                ch += ord("A") - ord("a")
+            b[i] = ch
+
+    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
+        class Writer(io.RawIOBase):
+            def __init__(self, output: IO[bytes]) -> None:
+                self.output = output
+
+            def writeable(self) -> bool:
+                return True
+
+            def write(self, b: Buffer) -> Optional[int]:
+                # Don't mutate the input
+                chunk = bytearray(b)
+                Rot13Example._rot13bytes(chunk, len(chunk))
+                return self.output.write(chunk)
+
+            def flush(self) -> None:
+                self.output.flush()
+
+        return cast(IO[bytes], Writer(output))
+
+    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
+        class Reader(io.RawIOBase):
+            def __init__(self, input: IO[bytes]) -> None:
+                self.input = input
+
+            def readable(self) -> bool:
+                return True
+
+            def readinto(self, b: Buffer) -> Optional[int]:
+                if hasattr(self.input, "readinto"):
+                    count = self.input.readinto(b)
+                else:
+                    # It's possible self.input is an IO[bytes] with no readinto method.
+                    # In that case, we emulate with a read and copy.  In practice,
+                    # all of the current concrete extensions have readinto.
+                    view = memoryview(b)
+                    r = self.input.read(len(view))
+                    if r is None:
+                        count = None
+                    else:
+                        count = len(r)
+                        view[:count] = r
+                if count == 0 or count is None:
+                    return count
+
+                Rot13Example._rot13bytes(b, count)
+                return count
+
+            def seekable(self) -> bool:
+                return self.input.seekable()
+
+            def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
+                return self.input.seek(offset, whence)
+
+            def tell(self) -> int:
+                return self.input.tell()
+
+        return cast(IO[bytes], Reader(input))
+
+
+def get_test_extension_registry() -> ExtensionRegistry:
+    registry = ExtensionRegistry()
+    registry.register(Rot13Example)
+    return registry
+
+
+def with_temp_dir(
+    func: Optional[Callable] = None,
+) -> Optional[Callable]:
+    """
+    Wrapper to initialize temp directory for distributed checkpoint.
+    """
+    assert func is not None
+
+    @wraps(func)
+    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
+        if dist.is_initialized():
+            # Only create temp_dir when rank is 0
+            if dist.get_rank() == 0:
+                temp_dir = tempfile.mkdtemp()
+                print(f"Using temp directory: {temp_dir}")
+            else:
+                temp_dir = ""
+            object_list = [temp_dir]
+
+            # Broadcast temp_dir to all the other ranks
+            os.sync()
+            dist.broadcast_object_list(object_list)
+            self.temp_dir = object_list[0]
+            os.sync()
+        else:
+            temp_dir = tempfile.mkdtemp()
+            print(f"No process group initialized, using temp directory: {temp_dir}")
+            self.temp_dir = temp_dir
+
+        try:
+            func(self, *args, **kwargs)
+        finally:
+            if dist.is_initialized() and dist.get_rank() == 0:
+                shutil.rmtree(self.temp_dir, ignore_errors=True)
+            else:
+                shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+    return wrapper
+
+
+def with_checkpoint_logging(
+    func: Optional[Callable] = None,
+    logger_name: str = "torch.distributed.checkpoint",
+    level: int = logging.INFO,
+) -> Optional[Callable]:
+    """
+    Wrapper to configure checkpoint logging for distributed tests.
+
+    Args:
+        func: The test function to wrap
+        logger_name: Name of the logger to configure (default: 'torch.distributed.checkpoint')
+        level: Logging level to set (default: logging.INFO)
+    """
+    assert func is not None
+
+    @wraps(func)
+    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
+        # Get the logger and store original level
+        target_logger = logging.getLogger(logger_name)
+        original_level = target_logger.level
+
+        # Set the desired logging level
+        target_logger.setLevel(level)
+
+        try:
+            func(self, *args, **kwargs)
+        finally:
+            # Restore original logging level
+            target_logger.setLevel(original_level)
+
+    return wrapper
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78e312306ba2500afa3722d6271c645d25f97cf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py
@@ -0,0 +1,170 @@
+# mypy: allow-untyped-defs
+
+# Owner(s): ["oncall: distributed"]
+
+import copy
+from itertools import chain
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.distributed._sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
+from torch.distributed.checkpoint.state_dict import (
+    _PG,
+    _STATE,
+    set_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.tensor import DTensor
+
+
+class VerifyStateDictMixin:
+    def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False):
+        if isinstance(dist_tensor, (DTensor, ShardedTensor)):
+            dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")
+
+        if offload_to_cpu:
+            orig_tensor = orig_tensor.cpu()
+            dist_tensor = dist_tensor.cpu()
+        self.assertTrue(isinstance(dist_tensor, torch.Tensor))
+        self.assertTrue(torch.allclose(orig_tensor, dist_tensor))
+
+    def _verify_msd(
+        self,
+        msd: dict[str, Any],
+        dist_msd: dict[str, Any],
+        options: StateDictOptions = StateDictOptions(),
+        offload_to_cpu=False,
+    ) -> None:
+        if not options.ignore_frozen_params:
+            self.assertEqual(len(msd), len(dist_msd))
+        for fqn, param in msd.items():
+            dist_param = dist_msd.get(fqn)
+            if not options.ignore_frozen_params:
+                self.assertIsNotNone(dist_param, f"{fqn=}")
+                try:
+                    self._compare_tensor(param, dist_param, offload_to_cpu)
+                except AssertionError as e:
+                    raise AssertionError(
+                        f"{fqn} has mismatched value {param} {dist_param}"
+                    ) from e
+            elif dist_param is None:
+                self.assertFalse(param.requires_grad, f"{fqn=}")
+
+    def _verify_osd(
+        self,
+        model: nn.Module,
+        optim: torch.optim.Optimizer,
+        osd: dict[str, Any],
+        dist_osd: dict[str, Any],
+    ) -> None:
+        params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
+        param_pid_mapping = dict(zip(params, range(len(params)), strict=True))
+        fqn_pid_mapping = {}
+        for fqn, param in model.named_parameters():
+            pid = param_pid_mapping[param]
+            fqn_pid_mapping[fqn] = pid
+            fqn_pid_mapping[pid] = fqn
+        # Check optimizer_state_dict state
+
+        self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE]))
+        for pid, states in osd[_STATE].items():
+            fqn = fqn_pid_mapping[pid]
+            dist_states = dist_osd[_STATE].get(fqn, None)
+            self.assertIsNotNone(dist_states, fqn)
+            self.assertEqual(len(states), len(dist_states))
+            for key, state in states.items():
+                dist_state = states.get(key, None)
+                self.assertIsNotNone(dist_state)
+                self._compare_tensor(state, dist_state)
+
+        # Check optimizer_state_dict param_group
+        old_dist_osd_pg = dist_osd[_PG]
+        if len(osd[_PG]) != len(dist_osd[_PG]):
+            self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG]))
+            new_pg = copy.deepcopy(dist_osd[_PG][0])
+            new_pg["params"] = []
+            for dist_group in dist_osd[_PG]:
+                new_pg["params"].extend(dist_group["params"])
+            dist_osd[_PG] = [new_pg]
+
+        self.assertEqual(len(osd[_PG]), len(dist_osd[_PG]))
+        for group, dist_group in zip(osd[_PG], dist_osd[_PG], strict=True):
+            self.assertEqual(len(group), len(dist_group))
+            for key, value in group.items():
+                # Below doesn't work because param_groups can have None
+                # values.
+                # dist_value = dist_group.get(key, None)
+                # self.assertIsNotNone(dist_value, (dist_group, group))
+                dist_value = dist_group[key]
+                if key == "params":
+                    fqns = [fqn_pid_mapping[pid] for pid in value]
+                    self.assertEqual(sorted(fqns), sorted(dist_value))
+                else:
+                    self.assertEqual(value, dist_value)
+        dist_osd[_PG] = old_dist_osd_pg
+
+    def _verify_osd_by_load(
+        self,
+        model: nn.Module,
+        optim: torch.optim.Optimizer,
+        new_optim: torch.optim.Optimizer,
+        dist_osd: dict[str, Any],
+    ) -> None:
+        new_dist_osd = _gather_state_dict(dist_osd)
+        set_state_dict(
+            model,
+            optimizers=new_optim,
+            model_state_dict={},
+            optim_state_dict=new_dist_osd,
+        )
+        self.assertEqual(optim.state_dict(), new_optim.state_dict())
+
+
+class FusionEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+
+
+class FusionEmbeddingWithHook(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+        self._register_state_dict_hook(FusionEmbeddingWithHook._state_dict_hook)
+        self._register_load_state_dict_pre_hook(
+            FusionEmbeddingWithHook._load_state_dict_hook, with_module=True
+        )
+
+    def _state_dict_hook(self, destination, prefix, keep_vars):
+        """Remove "embedding" from the original embedding in the state_dict
+        name. This keeps the original state dict name for the embedding
+        from before fusing with the FusionEmbedding.
+        """
+        key = prefix + "embedding.weight"
+        new_key = prefix + "weight"
+        destination[new_key] = destination[key]
+        del destination[key]
+
+    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+        """Apply extra "embedding" prefix to the state_dict key to
+        account for the FusionEmbedding wrapping.
+        """
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+
+
+class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
+    # _fqn_modifiers is a private function as a contract between DSD. When users change the state_dict
+    # keys, they need to provide a mapping from the new key to the original key. This is used to ensure
+    # consistency between the state_dict keys and fqn.
+    def _fqn_modifiers(self) -> dict[str, str]:
+        return {
+            "weight": "embedding",
+        }
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..32498f6d14917511f599af30e6afc3c5972280fc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
@@ -0,0 +1,748 @@
+# mypy: allow-untyped-defs
+
+import contextlib
+import enum
+import logging
+import os
+import threading
+from typing import NamedTuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.nn as nn
+from torch.distributed import rpc
+from torch.distributed.nn import RemoteModule
+from torch.nn.parallel import DistributedDataParallel
+from torch.testing._internal.common_distributed import (
+    requires_gloo,
+    requires_nccl,
+    skip_if_lt_x_gpu,
+    skip_if_rocm_multiprocess,
+)
+from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+NUM_EM_ROW = 2
+D_SPARSE = 3
+D_DENSE = 2
+D_HID = 3
+D_OUT = 1
+NUM_TRAINERS = 4
+# Trainers + the master + the remote worker
+WORLD_SIZE = NUM_TRAINERS + 2
+TRAINER_RANKS = list(range(NUM_TRAINERS))
+REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
+MASTER_RANK = REMOTE_WORKER_RANK + 1
+
+
+class DdpMode(enum.Enum):
+    # Don't apply DDP
+    NONE = enum.auto()
+    # Apply DDP to the top level nn.Module
+    OUTSIDE = enum.auto()
+    # Embed DDP inside the top level nn.Module
+    INSIDE = enum.auto()
+
+
+def init_logger():
+    logger = logging.getLogger(__name__)
+    level = logging.DEBUG if "debug" in os.environ else logging.INFO
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+
+gLogger = init_logger()
+
+
+class FeatureSet(NamedTuple):
+    """A feature set has 2 types of features"""
+
+    dense_features: torch.Tensor
+    sparse_features: torch.LongTensor
+    values: torch.Tensor
+
+
+def _call_method(method, rref, *args, **kwargs):
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_method(method, rref, *args, **kwargs):
+    args_tup = tuple([method, rref] + list(args))
+    return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
+
+
+def _remote_method_async(method, rref, *args, **kwargs):
+    args_tup = tuple([method, rref] + list(args))
+    return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
+
+
+class RemoteEM(nn.Module):
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        gLogger.info("Initing RemoteEM with %s %s", num_embeddings, embedding_dim)
+        super().__init__()
+        init_em = [0.5] * embedding_dim
+        self.em = nn.EmbeddingBag(
+            num_embeddings,
+            embedding_dim,
+            _weight=torch.tensor([init_em] * num_embeddings),
+        )
+
+    def forward(self, input: torch.Tensor):
+        gLogger.debug("Running RemoteEM.forward() on: %s", input)
+        return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
+
+
+# Return a linear module with predefined parameters.
+def getLinear(d_in, d_out):
+    l = nn.Linear(d_in, d_out, bias=False)
+    w = torch.ones((d_out, d_in))
+    w[0][0] = -1
+    w.requires_grad_()
+    l.weight.data = w
+    return l
+
+
+class RemoteNet(nn.Module):
+    def __init__(self, d_in: int, d_out: int):
+        gLogger.info("Initing RemoteNet with %s %s", d_in, d_out)
+        super().__init__()
+        self.fc = getLinear(d_in, d_out)
+        self.relu = nn.ReLU()
+
+    def forward(self, input: torch.Tensor):
+        gLogger.debug("Running RemoteNet.forward() on: %s", input)
+        return self.relu(self.fc(input))
+
+
+class HybridModel(nn.Module):
+    def __init__(
+        self,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+        process_group_for_ddp: dist.ProcessGroup = None,
+    ):
+        super().__init__()
+        self.remote_em_rref = remote_em_rref
+        self.remote_net_rref = remote_net_rref
+        self.fc1 = getLinear(D_DENSE, D_DENSE)
+        self.fc2 = getLinear(D_HID, D_OUT)
+
+        self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
+            self.fc2.parameters()
+        )
+        self.ddp_params = ()
+
+        if process_group_for_ddp is not None:
+            self.non_ddp_params, self.ddp_params = (
+                tuple(self.fc1.parameters()),
+                tuple(self.fc2.parameters()),
+            )
+            gLogger.info("Use DDP for the second local net.")
+            self.fc2 = DistributedDataParallel(
+                self.fc2, check_reduction=True, process_group=process_group_for_ddp
+            )
+
+        gLogger.info(
+            "HybridModel has %s groups of parameters.", len(list(self.parameters()))
+        )
+
+    def forward(self, input: FeatureSet):
+        gLogger.debug("Running HybridModel.forward on %s", input)
+        sparse = _remote_method(
+            RemoteEM.forward, self.remote_em_rref, input.sparse_features
+        )
+        # The same size of mini batch.
+        assert sparse.shape[0] == input.dense_features.shape[0]
+        dense = self.fc1(input.dense_features)
+        x = torch.cat((dense, sparse), 1)
+        gLogger.debug("Concatenated feature: %s", x)
+        x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
+        return self.fc2(x)
+
+
+class Trainer:
+    def __init__(
+        self,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+        ddp_mode: DdpMode,
+        rank: int,
+    ):
+        self.rank = rank
+        self.trainer_group = (
+            dist.new_group(TRAINER_RANKS)
+            if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
+            else None
+        )
+        self.remote_em_rref = remote_em_rref
+        self.remote_net_rref = remote_net_rref
+        self.hybrid_module = HybridModel(
+            self.remote_em_rref,
+            self.remote_net_rref,
+            self.trainer_group if ddp_mode == DdpMode.INSIDE else None,
+        )
+        self.ddp_params, self.non_ddp_params = (
+            self.hybrid_module.ddp_params,
+            self.hybrid_module.non_ddp_params,
+        )
+        if ddp_mode == DdpMode.OUTSIDE:
+            gLogger.info("Wrapping the whole hybrid module into DDP.")
+            self.ddp_params += self.non_ddp_params
+            self.non_ddp_params = ()
+            self.hybrid_module = DistributedDataParallel(
+                self.hybrid_module,
+                check_reduction=True,
+                process_group=self.trainer_group,
+            )
+        gLogger.info(
+            "Succeeded in creating a HybridModel instance with "
+            "%s ddp params and %s other local params.",
+            len(self.ddp_params),
+            len(self.non_ddp_params),
+        )
+
+    def destroy_pg(self):
+        if self.trainer_group:
+            dist.destroy_process_group(self.trainer_group)
+
+    def train_batch(
+        self,
+        mini_batch: FeatureSet,
+        trainer_has_less_inputs: bool,
+        simulate_uneven_inputs: bool,
+    ):
+        grads_dict = None
+
+        if not simulate_uneven_inputs:
+            input_batches = [mini_batch]
+        else:
+            # Split into microbatches, and trim to simulate uneven inputs.
+            dense_features = mini_batch.dense_features
+            sparse_features = mini_batch.sparse_features
+            values = mini_batch.values
+
+            dense_microbatch = torch.split(dense_features, 2)
+            sparse_microbatch = torch.split(sparse_features, 2)
+            values_microbatch = torch.split(values, 2)
+            batches = []
+            for d, s, v in zip(
+                dense_microbatch, sparse_microbatch, values_microbatch, strict=True
+            ):
+                feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
+                batches.append(feature_set)
+
+            if trainer_has_less_inputs:
+                input_batches = batches[: len(batches) // 2]
+                gLogger.info(
+                    "Trainer reduced input patches from %s "
+                    "to %s to simulate uneven inputs.",
+                    len(batches),
+                    len(input_batches),
+                )
+            else:
+                input_batches = batches
+
+        with (
+            self.hybrid_module.join()
+            if simulate_uneven_inputs
+            else contextlib.nullcontext()
+        ):
+            for b in input_batches:
+                with dist_autograd.context() as context_id:
+                    output = self.hybrid_module.forward(b)
+                    loss = (output * mini_batch.values).sum()
+                    dist_autograd.backward(context_id, [loss])
+                    grads_dict = dist_autograd.get_gradients(context_id)
+                    gLogger.info(
+                        "Loss is %s for mini batch: %s. Grads dict has %s entries: %s",
+                        loss,
+                        mini_batch,
+                        len(grads_dict),
+                        grads_dict,
+                    )
+        return (
+            tuple(grads_dict[param] for param in self.ddp_params),
+            tuple(grads_dict[param] for param in self.non_ddp_params),
+        )
+
+
+def get_training_examples():
+    n = 16
+    training_examples = FeatureSet(
+        dense_features=torch.zeros((n, D_DENSE)),
+        sparse_features=torch.zeros(n, dtype=torch.long),
+        values=torch.zeros(n),
+    )
+    idx = 0
+    # Every example has another one that has exactly the same features but an
+    # opposite value. Therefore, their grads cancel each other in all-reduce.
+    for value in (-1, 1):
+        for x in (-1.0 * value, 1.0 * value):
+            for y in (1.0 * value, -1.0 * value):
+                for z in (0, 1):
+                    training_examples.dense_features[idx, :] = torch.tensor((x, y))
+                    training_examples.sparse_features[idx] = z
+                    training_examples.values[idx] = value
+                    idx += 1
+
+    # Split the examples among NUM_TRAINERS trainers
+    assert 0 == (n % NUM_TRAINERS)
+    examples_per_trainer = int(n / NUM_TRAINERS)
+    return [
+        FeatureSet(
+            dense_features=training_examples.dense_features[
+                start : start + examples_per_trainer, :
+            ],
+            sparse_features=training_examples.sparse_features[
+                start : start + examples_per_trainer
+            ],
+            values=training_examples.values[start : start + examples_per_trainer],
+        )
+        for start in range(0, n, examples_per_trainer)
+    ]
+
+
+shutdown_signal = threading.Condition()
+
+
+def set_shutdown_signal():
+    global shutdown_signal
+    with shutdown_signal:
+        shutdown_signal.notify()
+
+
+class DdpUnderDistAutogradTest(RpcAgentTestFixture):
+    @property
+    def world_size(self) -> int:
+        return WORLD_SIZE
+
+    def remote_worker_name(self) -> str:
+        # The name has to be consistent with that in 'dist_init' decorator.
+        return f"worker{REMOTE_WORKER_RANK}"
+
+    def trainer_name(self, rank):
+        # The name has to be consistent with that in 'dist_init' decorator.
+        return f"worker{rank}"
+
+    def _remote_worker_process(self, ddp_mode):
+        gLogger.info("The remote worker is running.")
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
+        global shutdown_signal
+        with shutdown_signal:
+            shutdown_signal.wait()
+        gLogger.info("Exiting remote worker.")
+        dist.destroy_process_group()
+
+    def _trainer_process(self, rank: int):
+        gLogger.info("Running the trainer #%s...", rank)
+        gLogger.info(
+            "Initing trainer process group by trainer #%s with ranks %s",
+            rank,
+            TRAINER_RANKS,
+        )
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        gLogger.info("Waiting for shutdown signal on trainer #%s...", rank)
+
+        global shutdown_signal
+        with shutdown_signal:
+            shutdown_signal.wait()
+        gLogger.info("Exiting the trainer #%s...", rank)
+        dist.destroy_process_group()
+
+    def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
+        gLogger.info("Running the master process...")
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        remote_em_rref = rpc.remote(
+            self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
+        )
+        remote_net_rref = rpc.remote(
+            self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
+        )
+        gLogger.info("Created remote rrefs on master")
+        self.do_test_on_master(
+            ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
+        )
+
+    def do_test_on_master(
+        self,
+        ddp_mode: DdpMode,
+        simulate_uneven_inputs: bool,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+    ):
+        if simulate_uneven_inputs:
+            gLogger.info(
+                "Running DDP + RPC test with simulating uneven inputs across trainers."
+            )
+
+        trainer_rrefs = []
+        for rank in TRAINER_RANKS:
+            trainer = self.trainer_name(rank)
+            trainer_rrefs.append(
+                rpc.remote(
+                    trainer,
+                    Trainer,
+                    args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
+                )
+            )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
+        training_examples = get_training_examples()
+        for _ in range(3):
+            futures = []
+            num_trainers = len(trainer_rrefs)
+            for idx, trainer_rref in enumerate(trainer_rrefs):
+                # Half the trainers will deplete inputs earlier than the rest.
+                trainer_has_less_inputs = (
+                    simulate_uneven_inputs and idx < num_trainers // 2
+                )
+                futures.append(
+                    _remote_method_async(
+                        Trainer.train_batch,
+                        trainer_rref,
+                        training_examples[idx],
+                        trainer_has_less_inputs,
+                        simulate_uneven_inputs,
+                    )
+                )
+
+            for future in futures:
+                ddp_grads, non_ddp_grads = future.wait()
+                # When there are uneven inputs, it is not necessary that grads
+                # cancel each other out, since some trainers contribute 0 grad.
+                if not simulate_uneven_inputs:
+                    for grad in ddp_grads:
+                        self.assertEqual(
+                            grad,
+                            torch.zeros_like(grad),
+                            msg=f"The grad for any ddp parameter should be zeros, because "
+                            "the training examples' grads cancel each other. Received "
+                            f"gradient {grad}",
+                        )
+                for grad in non_ddp_grads:
+                    self.assertNotEqual(
+                        grad,
+                        torch.zeros_like(grad),
+                        msg="The grad for any non-ddp parameter shouldn't be zeros",
+                    )
+
+        # Destroy process groups
+        for trainer_rref in trainer_rrefs:
+            _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
+
+        # Send shutdown signals.
+        for rank in TRAINER_RANKS:
+            trainer = self.trainer_name(rank)
+            rpc.rpc_sync(trainer, set_shutdown_signal, args=())
+
+        rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())
+
+    def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
+        if self.rank == MASTER_RANK:
+            self._master_process(ddp_mode, simulate_uneven_inputs)
+        elif self.rank == REMOTE_WORKER_RANK:
+            self._remote_worker_process(ddp_mode)
+        elif self.rank in TRAINER_RANKS:
+            self._trainer_process(self.rank)
+        else:
+            raise RuntimeError(f"Unknown process rank: {self.rank}")
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_no_ddp(self):
+        self._do_test(DdpMode.NONE)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_outside(self):
+        self._do_test(DdpMode.OUTSIDE)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_outside_uneven_inputs(self):
+        self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_inside(self):
+        self._do_test(DdpMode.INSIDE)
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonDdpComparisonTest(RpcAgentTestFixture):
+    @property
+    def world_size(self) -> int:
+        return NUM_TRAINERS
+
+    def trainer_name(self, rank):
+        # The name has to be consistent with that in 'dist_init' decorator.
+        return f"worker{rank}"
+
+    @staticmethod
+    def get_remote_grads(rref, context_id):
+        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]
+
+
+class DdpComparisonTest(CommonDdpComparisonTest):
+    def _run_test_ddp_comparision(self, simulate_uneven_inputs=False):
+        gLogger.info("Running trainer rank: %s", self.rank)
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            # Postfix file_name with "pg" since file_name is also used by RPC agent
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+        net = nn.Linear(2, 3)
+        ddp_net = DistributedDataParallel(net)
+
+        # Odd ranks join early if simulate_uneven_inputs.
+        num_inputs = 1
+        if simulate_uneven_inputs:
+            if self.rank % 2 == 0:
+                num_inputs += 2
+        inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
+
+        if simulate_uneven_inputs:
+            gLogger.info(
+                "Rank %s training with %s inputs.", self.rank, len(inputs_list)
+            )
+
+        # Use distributed autograd. The gradients will be in RPC context map.
+        grads_dict = {}
+        with ddp_net.join(simulate_uneven_inputs):
+            for i, inputs in enumerate(inputs_list):
+                with dist_autograd.context() as context_id:
+                    loss = ddp_net(inputs).norm()
+                    dist_autograd.backward(context_id, [loss])
+                    grads_dict = dist_autograd.get_gradients(context_id)
+                gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict)
+
+                # Use local autograd. The gradients will be in each variable's '.grad'.
+                ddp_net.zero_grad()
+                loss = ddp_net(inputs).norm()
+                loss.backward()
+
+                # The gradients should be the same
+                for param in net.parameters():
+                    self.assertTrue(
+                        param in grads_dict,
+                        msg=f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}",
+                    )
+                    self.assertEqual(
+                        grads_dict[param],
+                        param.grad,
+                        msg=f"The grads for param {param} are different under local "
+                        f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
+                    )
+        dist.destroy_process_group()
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_comparison(self):
+        self._run_test_ddp_comparision()
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_comparison_uneven_inputs(self):
+        # test with simulating uneven inputs in DDP
+        self._run_test_ddp_comparision(simulate_uneven_inputs=True)
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_dist_autograd_sparse_grads(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        model = nn.EmbeddingBag(10, 3, sparse=True)
+        ddp_model = DistributedDataParallel(model)
+
+        # Different inputs for each
+        input = torch.LongTensor(10).random_(0, 10)
+        offsets = torch.LongTensor([0, 4])
+
+        # Run local.
+        loss = ddp_model(input, offsets).sum()
+        loss.backward()
+
+        with dist_autograd.context() as context_id:
+            loss = ddp_model(input, offsets).sum()
+            dist_autograd.backward(context_id, [loss])
+            grads_dict = dist_autograd.get_gradients(context_id)
+            self.assertEqual(1, len(grads_dict))
+            self.assertEqual(model.weight.grad, grads_dict[model.weight])
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_dist_autograd_local_vs_remote(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        # Use two different remote device input string, w/ and w/o the default
+        # device string "cpu", respectively.
+        for remote_device in ["worker0/cpu", "worker0"]:
+            remote_layer1 = RemoteModule(
+                remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
+            )
+            layer1 = nn.Linear(10, 5, False)
+            # Start with the same parameters for remote and local
+            layer1.weight = remote_layer1.module_rref.to_here().weight
+
+            # Run local case.
+            layer2 = nn.Linear(5, 1)
+            inputs = torch.rand((10, 10))
+            ddp_model = DistributedDataParallel(layer2)
+            loss = ddp_model(layer1(inputs)).sum()
+            loss.backward()
+
+            # Run remote case.
+            with dist_autograd.context() as context_id:
+                loss = ddp_model(remote_layer1(inputs)).sum()
+                dist_autograd.backward(context_id, [loss])
+                grads_dict = dist_autograd.get_gradients(context_id)
+                dist.barrier()
+                self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
+                self.assertEqual(
+                    layer1.weight.grad,
+                    rpc.rpc_sync(
+                        "worker0",
+                        CommonDdpComparisonTest.get_remote_grads,
+                        args=(remote_layer1.module_rref, context_id),
+                    ),
+                )
+
+
+class CudaDdpComparisonTest(CommonDdpComparisonTest):
+    @skip_if_lt_x_gpu(NUM_TRAINERS)
+    @requires_nccl()
+    @dist_init
+    @skip_if_rocm_multiprocess
+    def test_ddp_dist_autograd_local_vs_remote_gpu(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        remote_layer1 = RemoteModule(
+            remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
+        )
+        layer1 = nn.Linear(10, 7, False)
+        # Start with the same parameters for remote and local
+        layer1.weight = remote_layer1.module_rref.to_here().weight
+
+        layer2 = nn.Linear(7, 5).cuda(self.rank)
+        ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])
+
+        remote_layer3 = RemoteModule(
+            remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
+        )
+        layer3 = nn.Linear(5, 3, False)
+        # Start with the same parameters for remote and local
+        layer3.weight = remote_layer3.module_rref.to_here().weight
+
+        layer4 = nn.Linear(3, 1).cuda(self.rank)
+        ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])
+
+        # Run local case.
+        inputs = torch.rand((10, 10))
+        loss = ddp_layer4(
+            layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
+        ).sum()
+        loss.backward()
+
+        # Run remote case.
+        with dist_autograd.context() as context_id:
+            loss = ddp_layer4(
+                remote_layer3(
+                    ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
+                ).cuda(self.rank)
+            ).sum()
+            dist_autograd.backward(context_id, [loss])
+            grads_dict = dist_autograd.get_gradients(context_id)
+            dist.barrier()
+            self.assertEqual(
+                layer1.weight.grad,
+                rpc.rpc_sync(
+                    "worker0",
+                    CommonDdpComparisonTest.get_remote_grads,
+                    args=(remote_layer1.module_rref, context_id),
+                ),
+            )
+            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
+            self.assertEqual(
+                layer3.weight.grad,
+                rpc.rpc_sync(
+                    "worker0",
+                    CommonDdpComparisonTest.get_remote_grads,
+                    args=(remote_layer3.module_rref, context_id),
+                ),
+            )
+            self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..45bd2d1035b1b190520d712666d7d449adc25664
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py
@@ -0,0 +1,10420 @@
+# mypy: allow-untyped-defs
+
+import copy
+import itertools
+import json
+import math
+import operator
+import os
+import random
+import re
+import sys
+import tempfile
+import time
+import unittest
+from collections import defaultdict, namedtuple, OrderedDict
+from collections.abc import Callable
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from datetime import timedelta
+from functools import reduce
+from typing import Any, NamedTuple, Union
+
+import numpy as np
+
+import torch
+import torch.cuda
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
+import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
+import torch.nn as nn
+import torch.nn.functional as F
+from torch._utils_internal import (
+    TEST_MASTER_ADDR as MASTER_ADDR,
+    TEST_MASTER_PORT as MASTER_PORT,
+)
+from torch.autograd import DeviceType
+from torch.cuda.amp import autocast, GradScaler
+from torch.distributed.algorithms.ddp_comm_hooks import (
+    default_hooks as default,
+    post_localSGD_hook as post_localSGD,
+    powerSGD_hook as powerSGD,
+    quantization as quantization_hooks,
+)
+from torch.distributed.distributed_c10d import (
+    _get_default_group,
+    _get_pg_config,
+    get_world_size,
+)
+from torch.distributed.optim import _apply_optimizer_in_backward
+from torch.distributed.utils import (
+    _sync_module_states,
+    _verify_param_shape_across_processes,
+)
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
+from torch.profiler import ExecutionTraceObserver, ProfilerActivity
+from torch.testing._internal.common_distributed import (
+    captured_output,
+    cleanup_temp_dir,
+    DistTestCases,
+    init_multigpu_helper,
+    initialize_temp_directories,
+    MultiProcessTestCase,
+    nccl_skip_if_lt_x_gpu,
+    require_n_gpus_for_nccl_backend,
+    requires_nccl_version,
+    simple_sparse_reduce_tests,
+    skip_if_lt_x_gpu,
+    skip_if_no_gpu,
+    skip_if_odd_worldsize,
+    skip_if_rocm_multiprocess,
+    skip_if_small_worldsize,
+    TEST_SKIPS,
+    verify_ddp_error_logged,
+    with_dist_debug_levels,
+    with_nccl_blocking_wait,
+)
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    instantiate_parametrized_tests,
+    IS_FBCODE,
+    IS_MACOS,
+    IS_SANDCASTLE,
+    IS_WINDOWS,
+    MI200_ARCH,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+    skipIfRocm,
+    skipIfRocmArch,
+    TemporaryFileName,
+)
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils.data.distributed import DistributedSampler
+
+
+try:
+    import torchvision
+
+    HAS_TORCHVISION = True
+except Exception:  # Covering both ImportError and RuntimeError
+    HAS_TORCHVISION = False
+
+if sys.platform == "win32":
+    import msvcrt
+else:
+    import fcntl
+
+
+class NetWithBuffers(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 1, bias=False)
+        self.register_buffer("buffer", torch.randn(1, 2))
+
+    def forward(self, x):
+        self.buffer.add_(1)
+        return self.b(self.a(x))
+
+
+class Foo:
+    def __init__(self, x):
+        # Can be tensor or int
+        self.x = x
+
+    def __eq__(self, other):
+        def eq(value, other):
+            if isinstance(value, torch.Tensor):
+                return torch.equal(value, other)
+            return value == other
+
+        for attr, value in self.__dict__.items():
+            other_value = other.__dict__[attr]
+            if not eq(value, other_value):
+                return False
+        return True
+
+
+f = Foo(10)
+f.bar = 1
+
+
+# Defer instantiation until the seed is set so that randn() returns the same
+# values in all processes.
+def create_collectives_object_test_list():
+    return [
+        {"key1": 3, "key2": 4, "key3": {"nested": True}},
+        f,
+        Foo(torch.randn(3, 3)),
+        "foo",
+        [1, 2, True, "string", [4, 5, "nested"]],
+    ]
+
+
+# Allowlist of distributed backends where profiling collectives is supported.
+PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.NCCL,
+    dist.Backend.GLOO,
+    dist.Backend.MPI,
+    dist.Backend.UCC,
+]
+
+# Allowlist of distributed backends where profiling is supported with use_cuda=True
+CUDA_PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.GLOO,
+    dist.Backend.MPI,
+    dist.Backend.NCCL,
+    dist.Backend.UCC,
+]
+
+# Allowlist of distributed backends where profiling is supported for p2p ops
+SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.MPI,
+    dist.Backend.GLOO,
+    dist.Backend.NCCL,
+    dist.Backend.UCC,
+]
+
+# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
+EXPECTED_FIELDS = ("a", "b")
+TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)
+
+
+class TestNamedTupleInput_1(NamedTuple):
+    a: torch.tensor
+    b: torch.tensor
+
+
+skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
+    not HAS_TORCHVISION, "no torchvision"
+)
+
+BACKEND = os.environ["BACKEND"]
+INIT_METHOD = os.getenv("INIT_METHOD", "env://")
+
+DEFAULT_TIMEOUT = 300
+CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}
+
+
+def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
+    event_list = (
+        profiler.events()
+        if isinstance(profiler, torch.profiler.profile)
+        else profiler.function_events
+    )
+    return [
+        event
+        for event in event_list
+        if (
+            (event.name.endswith(event_name) or event.name.startswith(event_name))
+            and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA)
+        )
+    ]
+
+
+def get_profiler_nccl_meta(prof):
+    """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+    We will need to test metadata obtained from profiler here"""
+    with TemporaryFileName(mode="w+t", suffix=".json") as trace_file:
+        prof.export_chrome_trace(trace_file)
+        with open(trace_file) as f:
+            events = json.load(f)["traceEvents"]
+        print(f"Trace saved to {trace_file}")
+
+        return [e for e in events if e.get("name") == "record_param_comms"]
+
+
+# Base error message substring on unfinished reductions.
+ddp_prev_reduction_unfinished_str = (
+    "Expected to have finished reduction in the prior iteration"
+)
+# Error message substring when find_unused_parameters=True has not been passed
+ddp_recommend_find_unused_params_str = (
+    "passing the keyword argument `find_unused_parameters=True`"
+)
+# Error message substring when find_unused_parameters=True is enabled
+ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
+# Error message substring for possibility of not all model outputs being used
+# in loss computation
+ddp_outputs_not_used_in_loss_str = (
+    "`forward` function outputs participate in calculating loss"
+)
+# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG
+ddp_suggest_debug_mode_str = (
+    "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL"
+)
+
+
+class DDPUnevenTestInput(NamedTuple):
+    name: str
+    model: nn.Module
+    inp: Union[torch.tensor, tuple]
+    sync_interval: int
+    throw_on_early_termination: bool = False
+    hook: Callable = None
+    state: Any = None
+
+
+class _FC2(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(10, 50, bias=True)
+        self.fc.bias.requires_grad = False
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+class Net(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False)
+        self.fc2 = _FC2()
+        self.fc3 = nn.Linear(50, 4, bias=False)
+        self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(
+            torch.tensor([2, 2]).long(), requires_grad=False
+        )
+
+    def forward(self, x):
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return F.softmax(x, dim=1)
+
+
+class LargeNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(1000, 2000, bias=False)
+        self.fc2 = nn.Linear(2000, 500, bias=False)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+
+class Task(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.p = nn.Parameter(torch.ones(2, 2))
+
+    def forward(self, x):
+        return self.p + x
+
+
+class BatchNormNet(nn.Module):
+    def __init__(self, affine=True):
+        super().__init__()
+        self.fc1 = nn.Linear(2, 40, bias=False)
+        self.bn = nn.BatchNorm1d(4, affine=affine)
+        self.fc2 = nn.Linear(40, 4, bias=False)
+
+    def forward(self, x):
+        x = torch.reshape(self.fc1(x), (-1, 4, 10))
+        x = self.bn(x)
+        x = torch.reshape(x, (-1, 40))
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
+
+
+class UnusedParamTwoLinLayerNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 10, bias=False)
+        self.c = nn.Linear(5, 5, bias=False)
+
+    def forward(self, x):
+        a = self.a(x)
+        b = self.b(x)
+        return (a, b)
+
+
+class DictOutputModule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.module = UnusedParamTwoLinLayerNet()
+
+    def forward(self, x):
+        predictions = self.module(x)
+        loss = (predictions[0] + predictions[1]).sum()
+        return {
+            "predictions": predictions,
+            "loss": loss,
+        }
+
+
+class TwoLinLayerNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 1, bias=False)
+
+    def forward(self, x):
+        a = self.a(x)
+        b = self.b(x)
+        return (a, b)
+
+
+class EmbeddingNetDifferentParams(nn.Module):
+    """
+    A module containing an embedding with different dimension or different # of
+    parameters depending on the rank.
+    """
+
+    def __init__(self, rank, diff_num_params=False):
+        super().__init__()
+        embedding_dim = 500 if diff_num_params or rank == 0 else 50
+        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
+        self.lin = nn.Linear(embedding_dim, 1)
+        if diff_num_params:
+            self.lin2 = nn.Linear(1, 1, bias=False)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        return self.lin(x)
+
+
+class ControlFlowToyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(10, 10, bias=False)
+        self.lin2 = nn.Linear(10, 10, bias=False)
+
+    def forward(self, x):
+        # Second layer is used dependent on input x.
+        use_second_layer = torch.equal(x, torch.ones(20, 10, device=x.device))
+        if use_second_layer:
+            return self.lin2(F.relu(self.lin1(x)))
+        else:
+            return F.relu(self.lin1(x))
+
+
+def get_timeout(test_id):
+    test_name = test_id.split(".")[-1]
+    if test_name in CUSTOMIZED_TIMEOUT:
+        return CUSTOMIZED_TIMEOUT[test_name]
+    else:
+        return DEFAULT_TIMEOUT
+
+
+default_pg_timeout = 60
+
+CUSTOM_PG_TIMEOUT = {
+    # This test runs slowly and needs additional time to complete, otherwise can
+    # be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING
+    "test_ddp_uneven_inputs": 300,
+    # This test has a short timeout since it tests being taken down by
+    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
+    "test_ddp_model_diff_across_ranks": 5,
+    # This test has a short timeout since it tests being taken down by
+    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
+    "test_ddp_has_finalized": 5,
+}
+
+
+def require_backend_is_available(backends):
+    def check(backend):
+        if backend == dist.Backend.GLOO:
+            return dist.is_gloo_available()
+        if backend == dist.Backend.NCCL:
+            return dist.is_nccl_available()
+        if backend == dist.Backend.MPI:
+            return dist.is_mpi_available()
+        if backend == dist.Backend.UCC:
+            return dist.is_ucc_available()
+        if backend in DistTestCases.backend_feature["plugin"]:
+            return True
+        return False
+
+    if BACKEND not in backends:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires backend {BACKEND} to be one of {backends}"
+        )
+
+    if not check(dist.Backend(BACKEND)):
+        return skip_but_pass_in_sandcastle(
+            f"Test requires backend {BACKEND} to be available"
+        )
+    return lambda func: func
+
+
+def require_world_size(world_size):
+    if int(os.environ["WORLD_SIZE"]) < world_size:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires world size of {world_size:d}"
+        )
+    return lambda func: func
+
+
+def require_exact_world_size(world_size):
+    if int(os.environ["WORLD_SIZE"]) != world_size:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires an exact world size of {world_size:d}"
+        )
+    return lambda func: func
+
+
+@contextmanager
+def _lock():
+    TEMP_DIR = os.environ["TEMP_DIR"]
+    lockfile = os.path.join(TEMP_DIR, "lockfile")
+    with open(lockfile, "w") as lf:
+        try:
+            if sys.platform == "win32":
+                msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1)
+                yield
+            else:
+                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
+                yield
+        finally:
+            if sys.platform == "win32":
+                msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1)
+            else:
+                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
+            lf.close()
+
+
+@contextmanager
+def _rank_temp_file():
+    if dist.get_rank() == 0:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+    else:
+        name = None
+    object_list = [name]
+    dist.broadcast_object_list(object_list)
+    name = object_list[0]
+    try:
+        yield name
+    finally:
+        if dist.get_rank() == 0:
+            os.remove(name)
+
+
+def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
+    if value is None:
+        value = size
+    if device_id is None:
+        return torch.empty(size, size, size, dtype=dtype).fill_(value)
+    else:
+        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)
+
+
+def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
+    if value is None:
+        value = dim
+    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)
+
+
+def _create_autograd_profiler():
+    return torch.autograd.profiler.profile(record_shapes=True)
+
+
+def _create_torch_profiler():
+    return torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+        ],
+        record_shapes=True,
+    )
+
+
+class Barrier:
+    barrier_id = 0
+
+    @classmethod
+    def init(cls):
+        cls.barrier_id = 0
+        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
+        for f_name in os.listdir(barrier_dir):
+            os.unlink(os.path.join(barrier_dir, f_name))
+
+    @classmethod
+    def sync(cls, wait_for=None, timeout=10):
+        if wait_for is None:
+            wait_for = dist.get_world_size()
+        cls.barrier_id += 1
+        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
+        pid = str(os.getpid())
+        barrier_file = os.path.join(barrier_dir, pid)
+        with _lock():
+            with open(barrier_file, "w") as f:
+                f.write(str(cls.barrier_id))
+
+        start_time = time.time()
+        while True:
+            arrived = 0
+            with _lock():
+                for f_name in os.listdir(barrier_dir):
+                    with open(os.path.join(barrier_dir, f_name)) as f:
+                        data = f.read()
+                        if int(data) >= cls.barrier_id:
+                            arrived += 1
+            if arrived == wait_for:
+                break
+
+            if time.time() - start_time > timeout:
+                raise RuntimeError("barrier timeout")
+            time.sleep(0.1)
+
+
+class TestDistBackend(MultiProcessTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
+        # Not setting MASTER_PORT and get a random free port
+        super().setUpClass()
+
+    def setUp(self):
+        super().setUp()
+        # initialize temp directories
+        initialize_temp_directories()
+        # initialize Barrier
+        Barrier.init()
+        # Skip return code checking for following tests as they are expected to
+        # crash a process due to TORCH_NCCL_ASYNC_ERROR_HANDLING.
+        self.skip_return_code_checks = [self.test_ddp_has_finalized.__wrapped__]
+
+    def tearDown(self):
+        cleanup_temp_dir()
+        super().tearDown()
+
+    @property
+    def init_method(self):
+        return f"{FILE_SCHEMA}{self.file_name}"
+
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
+    @classmethod
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
+        if BACKEND == "nccl" and not torch.cuda.is_available():
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+
+        if torch.cuda.is_available() and torch.cuda.device_count() < int(
+            self.world_size
+        ):
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        try:
+            pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
+            timeout = timedelta(seconds=pg_timeout_seconds)
+            dist.init_process_group(
+                init_method=self.init_method,
+                backend=BACKEND,
+                world_size=int(self.world_size),
+                rank=self.rank,
+                timeout=timeout,
+            )
+        except RuntimeError as e:
+            if "recompile" in e.args[0]:
+                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
+
+            raise
+
+        # Execute barrier prior to running test to ensure that every process
+        # has finished initialization and that the following test
+        # immediately exiting due to a skip doesn't cause flakiness.
+        self._barrier()
+
+        self.run_test(test_name, pipe)
+        self._barrier()
+        dist.destroy_process_group()
+        sys.exit(0)
+
+    # Needed since MultiProcessTestCase assumes a world_size of 4, but we
+    # run these tests under other various world_sizes.
+    @property
+    def world_size(self):
+        return os.environ["WORLD_SIZE"]
+
+
+class DistributedTest:
+    class _DistTestBase:
+        def _barrier(self, *args, **kwargs):
+            Barrier.sync(*args, **kwargs)
+
+        def _init_group_test(self, **kwargs):
+            group = [1, 2]
+            group_id = dist.new_group(group, **kwargs)
+            rank = dist.get_rank()
+            if rank not in group:
+                return ([], None, rank)
+
+            return (group, group_id, rank)
+
+        def _init_full_group_test(self, **kwargs):
+            group = list(range(dist.get_world_size()))
+            group_id = dist.new_group(**kwargs)
+            rank = dist.get_rank()
+            return (group, group_id, rank)
+
+        def _init_global_test(self):
+            group = list(range(dist.get_world_size()))
+            group_id = dist.group.WORLD
+            rank = dist.get_rank()
+            return (group, group_id, rank)
+
+        def _verify_buffers_equal(self, m1, m2):
+            # verify buffers across models
+            m1_buf_dict = dict(m1.module.named_buffers())
+            for name, buf in m2.module.named_buffers():
+                self.assertEqual(buf, m1_buf_dict[name])
+
+            # Verify buffers across ranks.
+            m1_buffers = list(m1.buffers())
+            m2_buffers = list(m2.buffers())
+            for buf1, buf2 in zip(m1_buffers, m2_buffers, strict=True):
+                gathered_bufs = [
+                    torch.empty_like(buf1) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(gathered_bufs, buf1)
+                gathered_bufs_m2 = [
+                    torch.empty_like(buf2) for _ in range(dist.get_world_size())
+                ]
+                for b in gathered_bufs:
+                    self.assertEqual(b, buf1)
+                dist.all_gather(gathered_bufs_m2, buf2)
+                for b in gathered_bufs_m2:
+                    self.assertEqual(b, buf2)
+
+        def _sanity_check_profiler_nccl_meta(self, nccl_meta_events):
+            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+            We test for basic fields in this profiler event that correspond to the nccl communication
+            collectives"""
+            per_coll_meta = defaultdict(list)
+            for e in nccl_meta_events:
+                args = e.get("args", {})
+                collname = args.get("Collective name", "")
+                self.assertNotEqual(collname, "")
+                self.assertNotEqual(args.get("dtype", ""), "")
+
+                per_coll_meta[collname].append(args)
+                if collname == "wait":
+                    continue
+
+                self.assertEqual(args["Process Group Description"], "default_pg")
+                self.assertNotEqual(args["Process Group Ranks"], "")
+
+                self.assertGreaterEqual(args.get("In msg nelems", -1), 0)
+                self.assertGreaterEqual(args.get("Out msg nelems", -1), 0)
+                self.assertGreaterEqual(args.get("Group size", -1), 0)
+                self.assertGreaterEqual(args.get("Global rank start", -1), 0)
+                self.assertGreaterEqual(args.get("Global rank stride", -1), 0)
+
+            # print(per_coll_meta)
+            return per_coll_meta
+
+        def test_dump_DDP_relevant_env_vars(self):
+            with captured_output() as (out, _):
+                _dump_DDP_relevant_env_vars()
+                lines = out.getvalue().splitlines()
+
+            def format_line(var):
+                return f"env:{var}={os.environ.get(var, 'N/A')}"
+
+            # Check relevant env vars
+            vars = [
+                "MASTER_ADDR",
+                "MASTER_PORT",
+                "WORLD_SIZE",
+                "NCCL_TOPO_DUMP_FILE",  # N/A
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING",
+            ]
+            for var in vars:
+                line = format_line(var)
+                self.assertIn(line, lines)
+            # Check irrelevant env vars
+            vars = [
+                "xxx",
+                "yyy",
+                "zzz",
+            ]
+            for var in vars:
+                line = format_line(var)
+                self.assertNotIn(line, lines)
+
+        # GET RANK
+        def test_get_rank(self):
+            test_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir")
+            pid = str(os.getpid())
+            num_processes = dist.get_world_size()
+            with open(os.path.join(test_dir, pid), "w") as f:
+                f.write(str(dist.get_rank()))
+
+            self._barrier()
+
+            all_ranks = set()
+            for f_name in os.listdir(test_dir):
+                with open(os.path.join(test_dir, f_name)) as f:
+                    all_ranks.add(int(f.read()))
+            self.assertEqual(len(all_ranks), num_processes)
+
+            self._barrier()
+
+            if dist.get_rank() == 0:
+                for f_name in os.listdir(test_dir):
+                    os.unlink(os.path.join(test_dir, f_name))
+
+            self._barrier()
+
+        def test_get_backend(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            backend_str = BACKEND.lower()
+            self.assertEqual(dist.get_backend(), backend_str)
+            if dist.get_rank() in group:
+                self.assertEqual(dist.get_backend(group_id), backend_str)
+            else:
+                with self.assertRaisesRegex(
+                    ValueError, "Invalid process group specified"
+                ):
+                    dist.get_backend(group_id)
+
+        def test_Backend_enum_class(self):
+            # test parsing
+            backend = BACKEND.lower()
+            self.assertEqual(dist.Backend(BACKEND.upper()), backend)
+            self.assertEqual(dist.Backend(BACKEND), backend)
+            with self.assertRaises(ValueError):
+                dist.Backend(None)
+            with self.assertRaises(ValueError):
+                dist.Backend(3)
+            with self.assertRaises(ValueError):
+                dist.Backend(["gloo"])
+
+        # Test destroy
+        def test_destroy_group(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            self._barrier()
+            dist.destroy_process_group(group_id)
+
+        # Test get rank and size of group
+        def test_get_rank_size_group(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            if dist.get_rank() in group:
+                self.assertEqual(dist.get_world_size(group_id), 2)
+                self.assertTrue(dist.get_rank(group_id) in list(range(2)))
+            else:
+                self.assertEqual(dist.get_world_size(group_id), -1)
+                self.assertEqual(dist.get_rank(group_id), -1)
+
+        # Test destroy full groups
+        def test_destroy_full_group(self):
+            _, group_id, _ = self._init_full_group_test()
+            self._barrier()
+            dist.destroy_process_group(group_id)
+
+        # Test get rank and size of full group
+        def test_get_rank_size_full_group(self):
+            _, group_id, _ = self._init_full_group_test()
+            self.assertEqual(dist.get_world_size(group_id), dist.get_world_size())
+            self.assertEqual(dist.get_rank(group_id), dist.get_rank())
+
+        def _test_barrier_timeout(self, group_id, timeout):
+            local_rank = dist.get_rank(group_id)
+
+            # Only execute barrier on rank == 0, causing it to timeout
+            if local_rank == 0:
+                expected_time = time.time() + timeout.total_seconds()
+                # In debug mode, we execute a monitored_barrier before the
+                # collective, so assert on that.
+                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
+                    exception_ctx = self.assertRaisesRegex(
+                        Exception, "failed to pass monitoredBarrier"
+                    )
+                else:
+                    exception_ctx = self.assertRaisesRegex(
+                        Exception, " (Timed out|closed|timeout) "
+                    )
+                with exception_ctx:
+                    dist.barrier(group_id)
+                self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            not INIT_METHOD.startswith("file://"),
+            "Requires file:// initialization method. "
+            + "Both tcp:// and env:// rely on the TCP store for which "
+            "reinitialization has proven racy.",
+        )
+        def test_barrier_timeout_global(self):
+            dist.destroy_process_group()
+
+            # Explicitly pass world size to the barrier because we've
+            # just destroyed any state in torch.distributed.
+            self._barrier(wait_for=int(os.environ["WORLD_SIZE"]))
+
+            # Reinitialize global process group
+            timeout = timedelta(seconds=1)
+            dist.init_process_group(
+                init_method=INIT_METHOD,
+                backend=BACKEND,
+                world_size=int(os.environ["WORLD_SIZE"]),
+                rank=self.rank,
+                timeout=timeout,
+            )
+            self._test_barrier_timeout(dist.group.WORLD, timeout)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        def test_barrier_timeout_group(self):
+            timeout = timedelta(seconds=5)
+            _, group_id, _ = self._init_group_test(timeout=timeout)
+            if group_id is not None:
+                self._test_barrier_timeout(group_id, timeout)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        def test_barrier_timeout_full_group(self):
+            timeout = timedelta(seconds=1)
+            _, group_id, _ = self._init_full_group_test(timeout=timeout)
+            if group_id is not None:
+                self._test_barrier_timeout(group_id, timeout)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(2)
+        def test_new_subgroups(self):
+            subgroup_size = 2
+            cur_subgroup, subgroups = dist.new_subgroups(subgroup_size)
+
+            world_size = dist.get_world_size()
+            self.assertEqual(cur_subgroup.size(), subgroup_size)
+            self.assertEqual(len(subgroups), world_size / subgroup_size)
+            self.assertFalse(dist._rank_not_in_group(cur_subgroup))
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_exact_world_size(4)
+        def test_new_subgroups_with_group_param(self):
+            # Initialize global test environment
+            self._init_global_test()
+            # Set up GPU devices for each rank
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            # Create two subgroups: one with ranks [0,2] and another with ranks [1,3]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
+
+            # Further divide the current subgroup into sub-subgroups of size 1
+            cur_sub_subgroup, sub_subgroups = dist.new_subgroups(
+                group_size=1, group=cur_subgroup
+            )
+            # Verify we have 2 sub-subgroups (one for each rank in the original subgroup)
+            self.assertEqual(len(sub_subgroups), 2)
+            # Verify the current process's sub-subgroup has size 1
+            self.assertEqual(cur_sub_subgroup.size(), 1)
+            # Verify the current process is in its assigned sub-subgroup
+            self.assertFalse(dist._rank_not_in_group(group=cur_sub_subgroup))
+
+            # Clean up by destroying all created process groups
+            for sub_subgroup in sub_subgroups:
+                dist.destroy_process_group(sub_subgroup)
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_no_gpu
+        def test_new_subgroups_group_size_exceeds_world_size(self):
+            with self.assertRaisesRegex(ValueError, "must not exceed"):
+                dist.new_subgroups(100)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_world_size_not_divisible_by_group_size(self):
+            expected_msg = f"The world size ({dist.get_world_size()}) must be divisible by 'group_size=3'"
+            with self.assertRaisesRegex(
+                ValueError,
+                re.escape(expected_msg),
+            ):
+                dist.new_subgroups(3)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_by_enumeration(self):
+            _group, _group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
+            if device_id >= 4:
+                self.assertIsNone(cur_subgroup)
+            else:
+                self.assertEqual(cur_subgroup.size(), 2)
+                self.assertEqual(len(subgroups), 2)
+                if device_id == 0 or device_id == 2:
+                    self.assertEqual(cur_subgroup, subgroups[0])
+                else:
+                    self.assertEqual(cur_subgroup, subgroups[1])
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self):
+            _group, group_id, _rank = self._init_global_test()
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            world_size = get_world_size(group_id)
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "The new group's rank should be within the world_size set by init_process_group",
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[0, 1], [world_size, 2]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_no_gpu
+        def test_new_subgroups_by_enumeration_negative_input_rank(self):
+            self._init_global_test()
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "The new group's rank should be within the world_size set by init_process_group",
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[-1, -2], [-3, -4]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_overlap_not_allowed(self):
+            with self.assertRaisesRegex(
+                ValueError, "Rank 1 has appeared in both subgroup"
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[0], [1, 2], [1, 3]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_average_parameters(self):
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Sequential(
+                nn.Conv2d(3, 3, kernel_size=3, padding=1),
+                nn.ReLU(),
+                nn.Linear(1, 5, bias=False),
+            ).cuda(device_id)
+            # Test global model averaging
+            for p in model.parameters():
+                p.data = torch.ones_like(p.data)
+            model_averaging_utils.average_parameters(
+                params=model.parameters(), process_group=None
+            )
+            # Every element will be the same as the input.
+            for p in model.parameters():
+                self.assertEqual(p.data, torch.ones_like(p.data))
+
+            # Test partial model averaging
+            for p in model.parameters():
+                p.data = torch.ones_like(p.data) * rank
+            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
+            model_averaging_utils.average_parameters(
+                params=model.parameters(), process_group=group_nccl
+            )
+            if not dist._rank_not_in_group(group_nccl):
+                # Every element on device 0 or 1 should be the average of 0 and 1, i.e., 0.5.
+                for p in model.parameters():
+                    self.assertEqual(p.data, torch.ones_like(p.data) * 0.5)
+            else:
+                # Every element on device not in the subgroup should remain the same.
+                for p in model.parameters():
+                    self.assertEqual(p.data, torch.ones_like(p.data) * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_periodic_model_averager(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            expected_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(20):
+                    # Reset the parameters at every step.
+                    param.data = copy.deepcopy(tensor)
+                    for params in model.parameters():
+                        # mock grad
+                        params.grad = torch.ones_like(param.data)
+                    averager.average_parameters(model.parameters())
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        self.assertEqual(param.data, expected_avg_tensor)
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        self.assertEqual(param.data, tensor)
+
+        @skip_if_lt_x_gpu(2)
+        def test_periodic_model_averager_param_group(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(20):
+                    # Reset the parameters at every step.
+                    for param_group in opt.param_groups:
+                        for params in param_group["params"]:
+                            # mock grad
+                            params.grad = torch.ones_like(param.data) * rank
+                            params.data = torch.ones_like(param.data) * rank
+                    averager.average_parameters(opt.param_groups)
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        for param_group in opt.param_groups:
+                            for params in param_group["params"]:
+                                if params.grad is None:
+                                    continue
+                                self.assertEqual(
+                                    param.data,
+                                    torch.ones_like(param.data)
+                                    * sum(range(world_size))
+                                    / world_size,
+                                )
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        for param_group in opt.param_groups:
+                            for params in param_group["params"]:
+                                if params.grad is None:
+                                    continue
+                                self.assertEqual(
+                                    param.data, torch.ones_like(param.data) * rank
+                                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(
+            self,
+        ):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            expected_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = hierarchicalSGD.HierarchicalModelAverager(
+                    # Run the global averaging at a period of 4,
+                    # which is equivalent to the above periodic model averaging test case.
+                    period_group_size_dict=OrderedDict([(period, world_size)]),
+                    warmup_steps=warmup_steps,
+                )
+
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(20):
+                    # Reset the parameters at every step.
+                    param.data = copy.deepcopy(tensor)
+                    for params in model.parameters():
+                        # mock grad
+                        params.grad = torch.ones_like(param.data)
+                    averager.average_parameters(model.parameters())
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        self.assertEqual(param.data, expected_avg_tensor)
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        self.assertEqual(param.data, tensor)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_exact_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_3_level_hierarchical_model_averager(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            # Set up such a hierarchical model averaging as follows:
+            # after the first 10 warmup steps,
+            # run model averaging every 2 steps within each subgroup of size 2,
+            # run model averaging every 4 steps within each subgroup of size 3,
+            # and run the global model averaging every 8 steps.
+            # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
+            warmup_steps = 10
+            subgroup_size1 = 2
+            subgroup_avg_period1 = 2
+            subgroup_size2 = 4
+            subgroup_avg_period2 = 4
+            global_avg_period = 8
+            period_group_size_dict = OrderedDict(
+                [
+                    (subgroup_avg_period1, subgroup_size1),
+                    (subgroup_avg_period2, subgroup_size2),
+                    (global_avg_period, world_size),
+                ]
+            )
+            averager = hierarchicalSGD.HierarchicalModelAverager(
+                period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
+            )
+            self.assertEqual(dist.get_pg_count(), len(period_group_size_dict))
+
+            subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
+            subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
+            real_group_ranks_res1 = _get_pg_config(subgroup1)["ranks"]
+            real_group_ranks_res2 = _get_pg_config(subgroup2)["ranks"]
+
+            expect_group_ranks_res1 = (
+                rank // subgroup_size1 * subgroup_size1
+                + np.array(list(range(subgroup_size1)))
+            ).tolist()
+            expect_group_ranks_res2 = (
+                rank // subgroup_size2 * subgroup_size2
+                + np.array(list(range(subgroup_size2)))
+            ).tolist()
+            self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)
+            self.assertEqual(real_group_ranks_res2, expect_group_ranks_res2)
+
+            expected_avg_tensor_within_subgroup1 = (
+                torch.ones_like(param.data)
+                * sum(real_group_ranks_res1)
+                / subgroup_size1
+            )
+            expected_avg_tensor_within_subgroup2 = (
+                torch.ones_like(param.data)
+                * sum(real_group_ranks_res2)
+                / subgroup_size2
+            )
+            expected_global_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            for step in range(25):
+                # Reset the parameters at every step.
+                param.data = copy.deepcopy(tensor)
+                for params in model.parameters():
+                    # mock grad
+                    params.grad = torch.ones_like(param.data)
+                averager.average_parameters(model.parameters())
+                if step == 16 or step == 24:
+                    # Run global model averaging when `step` can be divided by 8.
+                    self.assertEqual(param.data, expected_global_avg_tensor)
+                elif step == 12 or step == 20:
+                    # Run model averaging within subgroup when `step` can be divided by 4 but not by 8.
+                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
+                elif step == 10 or step == 14 or step == 18 or step == 22:
+                    # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8.
+                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
+                else:
+                    # No model averaging, so the parameters are not updated.
+                    self.assertEqual(param.data, tensor)
+
+        # Coalescing manager (sync mode)
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+            "Coalescing manager currently tests with NCCL only; internal test flaky",
+        )
+        def test_coalescing_manager(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            num_colls = 2
+            size_per_coll = 8
+            small_tensors = [
+                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+            ]
+
+            with dist._coalescing_manager():
+                for i in range(num_colls):
+                    dist.all_reduce(small_tensors[i])
+
+            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+            dist.all_reduce(big_tensor)
+
+            for i in range(num_colls):
+                self.assertEqual(
+                    small_tensors[i],
+                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
+                )
+
+            self._barrier()
+
+        # Coalescing manager (async mode)
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+            "Coalescing manager currently tests with NCCL only; internal test flaky",
+        )
+        def test_coalescing_manager_async(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            num_colls = 2
+            size_per_coll = 8
+            small_tensors = [
+                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+            ]
+
+            with dist._coalescing_manager(async_ops=True) as cm:
+                for i in range(num_colls):
+                    dist.all_reduce(small_tensors[i])
+            cm.wait()
+
+            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+            dist.all_reduce(big_tensor)
+
+            for i in range(num_colls):
+                self.assertEqual(
+                    small_tensors[i],
+                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
+                )
+
+            self._barrier()
+
+        # NCCL Batch SEND RECV
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_nccl(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            p2p_op_list = []
+            recv_tensors = [None for _ in range(world_size)]
+            expected_tensors = [None for _ in range(world_size)]
+
+            for val in ["1", "0"]:
+                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
+                for src in range(world_size):
+                    send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_(
+                        src
+                    )
+                    recv_tensors[src] = _build_tensor(
+                        src + 1, value=-1, device_id=device_id
+                    ).fill_(-1)
+                    expected_tensors[src] = _build_tensor(
+                        src + 1, value=-1, device_id=device_id
+                    ).fill_(rank)
+                    recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
+                    p2p_op_list.append(recv_op)
+                    send_op = dist.P2POp(dist.isend, send_tensor, src)
+                    p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+                for src in range(world_size):
+                    self.assertEqual(recv_tensors[src], expected_tensors[src])
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_ring_exchange_nccl(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            send_tensor = _build_tensor(world_size, device_id=device_id)
+            recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id)
+            send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
+            recv_op = dist.P2POp(
+                dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
+            )
+            reqs = dist.batch_isend_irecv([send_op, recv_op])
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_self_nccl(self):
+            self._barrier()
+            # Ensure the process group has been fully initialized (needed by
+            # the first sub-group batch_isend_irecv call)
+            dist.barrier()
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            p2p_op_list = []
+
+            if rank == 0:
+                send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, 0)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, 0)
+                p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_no_rank_zero_nccl(self):
+            self._barrier()
+            # Ensure the process group has been fully initialized (needed by
+            # the first sub-group batch_isend_irecv call)
+            dist.barrier()
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            p2p_op_list = []
+
+            if rank == 1:
+                peer = 2
+            elif rank == 2:
+                peer = 1
+
+            if rank in [1, 2]:
+                send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, peer)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, peer)
+                p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+            self._barrier()
+
+        # GLOO Batch SEND RECV CPU
+        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
+        def test_batch_isend_irecv_gloo(self):
+            self._barrier()
+            rank = dist.get_rank()
+            p2p_op_list = []
+
+            for src in range(dist.get_world_size()):
+                if src == rank:
+                    continue
+                send_tensor = _build_tensor(rank + 1)
+                recv_tensor = _build_tensor(src + 1, value=-1)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, src)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, src)
+                p2p_op_list.append(send_op)
+
+            reqs = dist.batch_isend_irecv(p2p_op_list)
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        # GLOO Batch SEND RECV CPU with provided tags
+        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
+        def test_batch_isend_irecv_gloo_tags(self):
+            self._barrier()
+            rank = dist.get_rank()
+            p2p_op_list = []
+
+            for src in range(dist.get_world_size()):
+                if src == rank:
+                    continue
+                send_tensor = _build_tensor(rank + 1)
+                recv_tensor = _build_tensor(src + 1, value=-1)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank)
+                p2p_op_list.append(send_op)
+
+            reqs = dist.batch_isend_irecv(p2p_op_list)
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        # NCCL Batch SEND RECV Op Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_op_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            if rank == 0:
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                device_id = rank_to_GPU[rank][0]
+                with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
+                    send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                    send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
+                    dist.batch_isend_irecv([send_op])
+
+        # NCCL Batch SEND RECV p2p_op_list Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_op_list_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            if rank == 0:
+                with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
+                    dist.batch_isend_irecv([1, 2])
+
+        # NCCL Batch SEND RECV Mixed Backend Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_mixed_backend_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            group_gloo = dist.new_group(ranks=[0, 1], backend="gloo")
+            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
+            if rank == 0:
+                with self.assertRaisesRegex(
+                    ValueError, "All ops need to use the same group"
+                ):
+                    send_tensor = _build_tensor(rank + 1)
+                    send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
+                    send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl)
+                    dist.batch_isend_irecv([send_op_gloo, send_op_nccl])
+
+        # NCCL SEND RECV
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def _test_send_recv_nccl(self, profiler_ctx=None):
+            # TODO: now that nccl send/recv is supported, there does not seem to
+            # be a need to have nccl send/recv be tested separately.
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            tensor = _build_tensor(rank + 1, device_id=device_id)
+            profiler_cls = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with profiler_cls as prof:
+                for src in range(world_size):
+                    if src == rank:
+                        # Send mode
+                        for dst in range(world_size):
+                            if dst == rank:
+                                continue
+                            dist.send(tensor, dst)
+                    else:
+                        # Recv mode
+                        expected_tensor = _build_tensor(src + 1)
+                        output_tensor = _build_tensor(
+                            src + 1, value=-1, device_id=device_id
+                        )
+                        dist.recv(output_tensor, src)
+                        self.assertEqual(output_tensor, expected_tensor)
+
+                self._barrier()
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(
+                            event_name, prof, dedup_gpu_user_annotation=True
+                        )
+                        self.assertTrue(events)
+                        # Event order is not deterministic, so simply assert their shape
+                        # is found in the following list.
+                        expected_shapes = [
+                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
+                        ]
+                        for event in events:
+                            self.assertTrue(event.input_shapes in expected_shapes)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_send_recv_nccl(self):
+            self._test_send_recv_nccl()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_send_recv_nccl_autograd_profiler(self):
+            profiler_ctx = torch.autograd.profiler.profile(record_shapes=True)
+            self._test_send_recv_nccl(profiler_ctx)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_nccl_torch_profiler(self):
+            profiler_ctx = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+            )
+            self._test_send_recv_nccl(profiler_ctx)
+
+        # SEND RECV
+        def _test_send_recv(self, profiler_ctx):
+            rank = dist.get_rank()
+            send_size = rank + 1
+            tensor = _build_tensor(send_size)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for src in range(dist.get_world_size()):
+                    if src == rank:
+                        # Send mode
+                        for dst in range(dist.get_world_size()):
+                            if dst == rank:
+                                continue
+                            dist.send(tensor, dst)
+                    else:
+                        # Recv mode
+                        recv_size = src + 1
+                        expected_tensor = _build_tensor(recv_size)
+                        output_tensor = _build_tensor(recv_size, value=-1)
+                        dist.recv(output_tensor, src)
+                        self.assertEqual(output_tensor, expected_tensor)
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from all other ranks.
+                        event_count = sum(e.count for e in events)
+                        expected_event_count = dist.get_world_size() - 1
+                        self.assertEqual(event_count, expected_event_count)
+                        # Event order is not deterministic, so simply assert their shape
+                        # is found in the following list.
+                        expected_shapes = [
+                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
+                        ]
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertTrue(event.input_shapes in expected_shapes)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv(self):
+            self._test_send_recv(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_send_recv(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv(profiler_ctx=torch_profiler_ctx)
+
+        # SEND RECV ANY SOURCE
+        def _test_send_recv_any_source(self, profiler_ctx):
+            rank = dist.get_rank()
+            send_recv_size = 10
+            tensor = _build_tensor(send_recv_size, value=rank)
+            recv_ranks = []
+            irecv_ranks = []
+
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for dst in range(dist.get_world_size()):
+                    if dst == rank:
+                        # Recv mode
+                        for dst in range(dist.get_world_size()):
+                            if dst == rank:
+                                continue
+
+                            for recv in ["recv", "irecv"]:
+                                output_tensor = _build_tensor(send_recv_size, value=-1)
+
+                                if recv == "recv":
+                                    sender = dist.recv(output_tensor)
+                                    recv_ranks.append(sender)
+                                elif recv == "irecv":
+                                    work = dist.irecv(output_tensor)
+                                    work.wait()
+                                    sender = work._source_rank()
+                                    irecv_ranks.append(sender)
+
+                                # Assert the scalar value "sender" that should be
+                                # equal to the rank of the sender is equal to all
+                                # values in the received tensor.
+                                self.assertTrue(output_tensor.eq(sender).all())
+                    else:
+                        # Send mode
+                        dist.send(tensor, dst)  # recv
+                        dist.send(tensor, dst)  # irecv
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from other rank twice.
+                        self.assertEqual(
+                            sum(event.count for event in events),
+                            2 * (dist.get_world_size() - 1),
+                        )
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
+
+                # Each rank would have 2 * (world_size - 1) sends, verify that
+                # globally we receive the same amount on the other end.
+                recv_ranks_tensor = torch.cat(
+                    (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0
+                )
+                global_recv_ranks = [
+                    torch.empty_like(recv_ranks_tensor)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(global_recv_ranks, recv_ranks_tensor)
+                global_recv_ranks_list = []
+                for tensor in global_recv_ranks:
+                    global_recv_ranks_list += tensor.tolist()
+
+                from itertools import groupby
+
+                global_recv_ranks_list.sort()
+                frequency = [
+                    len(list(group)) for key, group in groupby(global_recv_ranks_list)
+                ]
+                self.assertEqual(dist.get_world_size(), len(frequency))
+                self.assertEqual(
+                    [2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        def test_send_recv_any_source(self):
+            self._test_send_recv_any_source(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        def test_send_recv_any_source_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_send_recv_any_source(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_any_source_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv_any_source(profiler_ctx=torch_profiler_ctx)
+
+        # SEND RECV WITH TAG
+        def _test_send_recv_with_tag(self, profiler_ctx):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            send_recv_size = 10
+            tensor = _build_tensor(send_recv_size, value=rank)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for dst in range(world_size):
+                    if dst == rank:
+                        # Recv mode
+                        for src in range(world_size):
+                            if src == rank:
+                                continue
+                            output_tensor = _build_tensor(send_recv_size, value=-1)
+                            dist.recv(output_tensor, src, tag=src)
+                            self.assertTrue(output_tensor.eq(src).all())
+                    else:
+                        # Send mode
+                        dist.send(tensor, dst, tag=rank)
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from all other ranks
+                        event_count = sum(e.count for e in events)
+                        expected_event_count = dist.get_world_size() - 1
+                        self.assertEqual(event_count, expected_event_count)
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertEqual(event.name, event_name)
+                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_with_tag(self):
+            self._test_send_recv_with_tag(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_with_tag_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            return self._test_send_recv_with_tag(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_with_tag_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv_with_tag(profiler_ctx=torch_profiler_ctx)
+
+        # ISEND
+        def _test_isend(self, profiler_ctx):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                if rank == 0:
+                    requests = [
+                        dist.isend(_build_tensor(dest, 10), dest)
+                        for dest in range(1, world_size)
+                    ]
+                    for request in requests:
+                        request.wait()
+                        self.assertTrue(request.is_completed())
+                else:
+                    tensor = _build_tensor(rank, -1)
+                    dist.recv(tensor, 0)
+                    self.assertEqual(tensor, _build_tensor(rank, 10))
+
+                self._barrier()
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    expected_event_name = (
+                        f"{backend}:send" if rank == 0 else f"{backend}:recv"
+                    )
+                    events = get_profiling_event(expected_event_name, prof)
+                    event_count = sum(e.count for e in events)
+                    expected_count = dist.get_world_size() - 1 if rank == 0 else 1
+                    self.assertEqual(expected_count, event_count)
+                    # Event ordering is not guaranteed, so simply ensure the shapes are
+                    # found in the following map.
+                    expected_shapes = {
+                        r: [[r] * 3] for r in range(1, dist.get_world_size())
+                    }
+                    for event in events:
+                        self.assertTrue(event.is_async)
+                        self.assertEqual(event.name, expected_event_name)
+                        if rank == 0:
+                            self.assertTrue(
+                                event.input_shapes in expected_shapes.values()
+                            )
+                        else:
+                            self.assertEqual(event.input_shapes, expected_shapes[rank])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        def test_isend(self):
+            self._test_isend(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        def test_isend_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_isend(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_isend_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            self._test_isend(profiler_ctx=torch_profiler_ctx)
+
+        # IRECV
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support irecv"
+        )
+        def test_irecv(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+
+            if rank == 0:
+                expected_tensors = [
+                    _build_tensor(src, -1) for src in range(1, world_size)
+                ]
+                requests = [
+                    dist.irecv(expected_tensors[src - 1], src)
+                    for src in range(1, world_size)
+                ]
+
+                for src in range(1, world_size):
+                    requests[src - 1].wait()
+                    self.assertTrue(requests[src - 1].is_completed())
+                    self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
+            else:
+                tensor = _build_tensor(rank, 10)
+                dist.send(tensor, 0)
+
+            self._barrier()
+
+        # BROADCAST
+        def _test_broadcast_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+            with_options=False,
+        ):
+            for dtype, value, requires_cuda in [
+                (torch.float, -1e-10, False),
+                (torch.double, -1e-100, False),
+                (torch.half, -0.1, True),
+                (torch.int8, -2, False),
+                (torch.uint8, 129, False),
+                (torch.int, -1e5, False),
+                (torch.long, -1e15, False),
+            ]:
+                if requires_cuda and not cuda:
+                    continue
+                for src in group:
+                    expected_tensor = _build_tensor(src + 1, value, dtype)
+                    if cuda:
+                        expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    if rank == src:
+                        if with_options:
+                            opts = dist.BroadcastOptions()
+                            opts.rootTensor = 0
+                            opts.rootRank = src
+                            self.call_dist_op(
+                                ":broadcast",
+                                True,
+                                group_id.broadcast,
+                                [expected_tensor],
+                                opts,
+                            )
+                        else:
+                            self.call_dist_op(
+                                ":broadcast",
+                                False,
+                                dist.broadcast,
+                                expected_tensor,
+                                src,
+                                group_id,
+                            )
+                    else:
+                        tensor = _build_tensor(src + 1, -1, dtype)
+                        if cuda:
+                            tensor = tensor.cuda(rank_to_GPU[rank][0])
+                        if with_options:
+                            opts = dist.BroadcastOptions()
+                            opts.rootTensor = 0
+                            opts.rootRank = src
+                            self.call_dist_op(
+                                ":broadcast", True, group_id.broadcast, [tensor], opts
+                            )
+                        else:
+                            self.call_dist_op(
+                                ":broadcast",
+                                False,
+                                dist.broadcast,
+                                tensor,
+                                src,
+                                group_id,
+                            )
+                        self.assertEqual(tensor.size(), expected_tensor.size())
+                        self.assertEqual(
+                            tensor.ne(expected_tensor).max(), torch.tensor(False)
+                        )
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and Nccl backend supports CUDA allReduce",
+        )
+        @skip_if_no_gpu
+        def test_broadcast_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl",
+            "Only NCCL backend supports high priority stream",
+        )
+        @skip_if_no_gpu
+        def test_nccl_high_priority_stream(self):
+            group, _, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            new_port = str(MASTER_PORT + 1)
+            os.environ["MASTER_PORT"] = new_port
+            gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size())
+            store, rank, size = next(gen_iterator)
+            store = dist.PrefixStore(new_port, store)
+
+            opts = dist.ProcessGroupNCCL.Options()
+            opts.is_high_priority_stream = False
+            group_id = dist.ProcessGroupNCCL(store, rank, size, opts)
+
+            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)
+
+        # REDUCE
+        def _test_reduce_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            for src in group:
+                tensor = _build_tensor(src + 1).fill_(
+                    master_value if rank == src else worker_value
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                self.call_dist_op(
+                    ":reduce",
+                    False,
+                    dist.reduce,
+                    tensor,
+                    src,
+                    op,
+                    group_id,
+                    tensor_shapes=[tensor.shape],
+                )
+                if rank == src:
+                    self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_sum_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + 10 * (len(group) - 1),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        # REDUCE TWICE
+        def _test_reduce_twice_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            for src in group:
+                tensors = [
+                    _build_tensor(src + 1).fill_(
+                        master_value if rank == src else worker_value
+                    )
+                    for i in range(2)
+                ]
+                if cuda:
+                    for i in range(2):
+                        tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0])
+                self.call_dist_op(
+                    ":reduce",
+                    False,
+                    dist.reduce,
+                    tensors[0],
+                    src,
+                    op,
+                    group_id,
+                    secondary_op_call=lambda: dist.reduce(
+                        tensors[1], src, op, group_id
+                    ),
+                    tensor_shapes=[tensors[0].shape],
+                )
+                if rank == src:
+                    for tensor in tensors:
+                        self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_sum_twice(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_twice_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_sum_cuda_twice(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_reduce_twice_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + 10 * (len(group) - 1),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports reduce_scatter_v"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_scatter_v_cuda(self):
+            self._barrier()
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            input_split_sizes = [src + 1 for src in group]
+            start_len = sum(input_split_sizes[:rank])
+            end_len = start_len + input_split_sizes[rank]
+            sum_len = sum(input_split_sizes)
+            master_value = 2
+            worker_value = 10
+
+            for async_val in [True, False]:
+                tensor = _build_tensor(sum_len, worker_value, device_id=device_id)
+                tensor[start_len:end_len].fill_(master_value)
+                out_tensor = (
+                    torch.empty(
+                        input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                    )
+                    .fill_(-1)
+                    .cuda(device_id)
+                )
+
+                req = dist.reduce_scatter(
+                    out_tensor,
+                    list(torch.split(tensor, input_split_sizes)),
+                    dist.ReduceOp.SUM,
+                    group_id,
+                    async_val,
+                )
+                if async_val:
+                    req.wait()
+
+                expected_value = 2 + (10 * (len(group) - 1))
+                expected_tensor = torch.empty(
+                    input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                )
+                expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id)
+
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        # Test reduce_scatter_tensor accepting single tensor as input
+        def _reduce_scatter_tensor_helper(
+            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
+        ):
+            if cuda:
+                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
+                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
+            tensor_shapes = [tensor_out.shape]
+            self.call_dist_op(
+                ":reduce_scatter_tensor",
+                False,
+                dist.reduce_scatter_tensor,
+                tensor_out,
+                tensor_in,
+                dist.ReduceOp.SUM,
+                group_id,
+                False,
+                expect_event=False,
+                tensor_shapes=tensor_shapes,
+            )
+            return tensor_out
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor"
+        )
+        @skip_if_no_gpu
+        def test_reduce_scatter_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_out = torch.zeros(size, dtype=torch.int64)
+
+            # Concatenated input
+            tensor_in = torch.arange(len(group) * size)
+            tensor_out = self._reduce_scatter_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+            # Check result
+            expected_tensor = torch.arange(rank * size, (rank + 1) * size) * len(group)
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+            # Stacked input
+            tensor_in = torch.reshape(tensor_in, (len(group), size))
+            tensor_out = self._reduce_scatter_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+            # Check result
+            # Should be the same as the result in concatenated case
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        def call_dist_op(
+            self,
+            profiling_title_postfix,
+            is_async,
+            op,
+            *args,
+            expect_event=True,
+            secondary_op_call=None,
+            profile_cuda=False,
+            tensor_shapes=None,
+            **kwargs,
+        ):
+            op_calls = [lambda: op(*args, **kwargs)]
+            if secondary_op_call is not None:
+                op_calls.append(secondary_op_call)
+
+            autograd_profiler_ctx = torch.autograd.profiler.profile(
+                use_cuda=profile_cuda, record_shapes=True
+            )
+
+            # TODO: move this test to use torch.profiler once kineto issues are
+            # fixed internally.
+            with autograd_profiler_ctx:
+                works = [op_call() for op_call in op_calls]
+                if is_async:
+                    for work in works:
+                        work.wait()
+
+            if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS:
+                # We are only interested in the backend's implementation not the dispatcher wrapper.
+                events = get_profiling_event(
+                    dist.get_backend() + profiling_title_postfix, autograd_profiler_ctx
+                )
+                # DETAIL debug mode can use a pg wrapper that issues more collectives
+                # under the hood
+                if dist.get_debug_level() != dist.DebugLevel.DETAIL:
+                    self.assertEqual(len(events), len(op_calls))
+                for e in events:
+                    self.assertTrue(e.is_async)
+                    self.assertEqual(e.count, 1)
+                    self.assertGreaterEqual(e.cpu_time, 0)
+                    # Verify tensor shapes if given
+                    # DETAIL debug mode can use a pg wrapper that issues more collectives
+                    # under the hood
+                    if (
+                        tensor_shapes is not None
+                        and dist.get_debug_level() != dist.DebugLevel.DETAIL
+                    ):
+                        self.assertEqual(
+                            e.input_shapes,
+                            tensor_shapes,
+                            f"event shape: {e.input_shapes} vs tensor {tensor_shapes}",
+                        )
+
+        # ALL REDUCE
+        def _test_all_reduce_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+            dtype=torch.float,
+            async_op=False,
+        ):
+            for src in group:
+                curr_value = master_value if rank == src else worker_value
+
+                tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                if tensor.dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(tensor).shape]
+                else:
+                    tensor_shapes = [tensor.shape]
+                self.call_dist_op(
+                    ":all_reduce",
+                    async_op,
+                    dist.all_reduce,
+                    tensor,
+                    op,
+                    group_id,
+                    async_op=async_op,
+                    tensor_shapes=tensor_shapes,
+                )
+                # Currently, only Gloo backend has profiling tested with CUDA enabled.
+                # Only run cuda profiling test for one rank to speed up since
+                # running with different src_rank does not affect the correctness.
+                if (
+                    src == 0
+                    and cuda
+                    and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS
+                ):
+                    self.call_dist_op(
+                        ":all_reduce",
+                        async_op,
+                        dist.all_reduce,
+                        tensor,
+                        op,
+                        group_id,
+                        async_op=async_op,
+                        profile_cuda=True,
+                        tensor_shapes=tensor_shapes,
+                    )
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum_async(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                async_op=True,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda_async(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+                async_op=True,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                complex(2, 3),
+                complex(10, 11),
+                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_complex_unsupported_ops(self):
+            unsupported_ops = [
+                dist.ReduceOp.MAX,
+                dist.ReduceOp.MIN,
+                dist.ReduceOp.PRODUCT,
+                dist.ReduceOp.BAND,
+                dist.ReduceOp.BOR,
+                dist.ReduceOp.BXOR,
+            ]
+            _group, group_id, _rank = self._init_global_test()
+            for unsupported_op in unsupported_ops:
+                with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
+                    dist.all_reduce(
+                        _build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda_complex(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                complex(2, 3),
+                complex(10, 11),
+                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        # SPARSE ALL REDUCE
+        def _test_sparse_all_reduce_sum(self, fn):
+            _group, group_id, rank = self._init_global_test()
+
+            tests = simple_sparse_reduce_tests(
+                rank, dist.get_world_size(), num_inputs=1
+            )
+            for inputs, outputs in tests:
+                tensors = [fn(input) for input in inputs]
+                dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
+                self.assertEqual(tensors[0], outputs[0])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
+        )
+        def test_sparse_all_reduce_sum(self):
+            self._test_sparse_all_reduce_sum(lambda t: t)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
+        )
+        @skip_if_no_gpu
+        def test_sparse_all_reduce_sum_cuda(self):
+            self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
+
+        # ALL REDUCE - COALESCED
+        @staticmethod
+        def _all_reduce_coalesced_sum_test_cases(group_size):
+            return (
+                [2, 3, complex(2, 3)],
+                [10, 11, complex(10, 11)],
+                [
+                    2 + 10 * (group_size - 1),
+                    3 + 11 * (group_size - 1),
+                    complex(2, 3) + complex(10, 11) * (group_size - 1),
+                ],
+                [torch.float, torch.float, torch.cfloat],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_product_test_cases(group_size):
+            return (
+                [1, 2],
+                [3, 4],
+                [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
+                [torch.float, torch.float],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_min_test_cases(group_size):
+            return (
+                [1, 4],
+                [2, 3],
+                [1, 3],
+                [torch.float, torch.float],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_max_test_cases(group_size):
+            return (
+                [1, 4],
+                [2, 3],
+                [2, 4],
+                [torch.float, torch.float],
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_coalesced_max_complex_unsupported(self):
+            _group, group_id, _rank = self._init_global_test()
+            with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
+                dist.all_reduce_coalesced(
+                    [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
+                )
+
+        def _test_all_reduce_coalesced_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            test_case_func = {
+                dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases,
+                dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases,
+                dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases,
+                dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases,
+            }[op]
+
+            master_values, worker_values, expected_values, dtypes = test_case_func(
+                len(group)
+            )
+
+            for src in group:
+                curr_values = master_values if rank == src else worker_values
+                tensors = [
+                    _build_tensor(src + 1, val, dtype=dtype)
+                    for dtype, val in zip(dtypes, curr_values, strict=True)
+                ]
+                if cuda:
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                tensor_shapes = []
+                for tensor in tensors:
+                    if tensor.dtype == torch.complex64:
+                        tensor_shapes.append(torch.view_as_real(tensor).shape)
+                    else:
+                        tensor_shapes.append(tensor.shape)
+                self.call_dist_op(
+                    ":all_reduce",
+                    False,
+                    dist.all_reduce_coalesced,
+                    tensors,
+                    op,
+                    group_id,
+                    tensor_shapes=tensor_shapes,
+                )
+                expected_tensors = [
+                    _build_tensor(src + 1, expected_value, dtype=dtype)
+                    for dtype, expected_value in zip(
+                        dtypes, expected_values, strict=True
+                    )
+                ]
+                self.assertEqual(tensors, expected_tensors)
+
+            self._barrier()
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.MIN,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.MIN,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        # SCATTER
+        def _test_scatter_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, -1, dtype=dtype)
+                expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+                tensors = (
+                    [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
+                    if rank == dest
+                    else []
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                if dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
+                else:
+                    tensor_shapes = [t.shape for t in tensors]
+                self.call_dist_op(
+                    ":scatter",
+                    False,
+                    dist.scatter,
+                    tensor,
+                    src=dest,
+                    scatter_list=tensors,
+                    group=group_id,
+                    expect_event=False,
+                    tensor_shapes=tensor_shapes,
+                )
+                self.assertEqual(tensor, expected_tensor)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_checks(self):
+            group, _group_id, rank = self._init_global_test()
+            one = torch.ones([1])
+
+            # Specify scatter_list argument only on source rank.
+            output = one.clone() * -1
+            if rank == 0:
+                scatter_list = [one.clone() * i for i in group]
+                dist.scatter(output, src=0, scatter_list=scatter_list)
+            else:
+                dist.scatter(output, src=0)
+            self.assertEqual(output, one * rank)
+
+            # Don't specify src argument.
+            output = one.clone() * -1
+            if rank == 0:
+                scatter_list = [one.clone() * i for i in group]
+                dist.scatter(output, scatter_list=scatter_list)
+            else:
+                dist.scatter(output)
+            self.assertEqual(output, one * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_scatter_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_scatter_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_scatter_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @skip_if_small_worldsize
+        def test_scatter_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        # GATHER
+        def _test_gather_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, rank)
+                tensors = (
+                    [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                self.call_dist_op(
+                    ":gather",
+                    False,
+                    dist.gather,
+                    tensor,
+                    dst=dest,
+                    gather_list=tensors,
+                    group=group_id,
+                    expect_event=False,
+                    tensor_shapes=[tensors[0].shape] if len(tensors) > 0 else None,
+                )
+                if rank == dest:
+                    expected_tensors = [_build_tensor(dest + 1, i) for i in group]
+                    for t1, t2 in zip(tensors, expected_tensors, strict=True):
+                        self.assertEqual(t1, t2)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather_checks(self):
+            group, _group_id, rank = self._init_global_test()
+            one = torch.ones([1])
+
+            # Specify gather_list argument only on destination rank.
+            if rank == 0:
+                gather_list = [one.clone() for _ in group]
+                dist.gather(one * rank, dst=0, gather_list=gather_list)
+                for i in group:
+                    self.assertEqual(gather_list[i], one * i)
+            else:
+                dist.gather(one * rank, dst=0)
+
+            # Don't specify dst argument.
+            if rank == 0:
+                gather_list = [one.clone() for _ in group]
+                dist.gather(one * rank, gather_list=gather_list)
+                for i in group:
+                    self.assertEqual(gather_list[i], one * i)
+            else:
+                dist.gather(one * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_gather_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @skip_if_small_worldsize
+        def test_gather_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        # ALL GATHER
+        def _test_all_gather_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+                tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
+                allgather = dist.all_gather
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                if tensors[0].dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
+                else:
+                    tensor_shapes = [tensors[0].shape]
+                self.call_dist_op(
+                    ":all_gather",
+                    False,
+                    allgather,
+                    tensors,
+                    tensor,
+                    group_id,
+                    False,
+                    tensor_shapes=tensor_shapes,
+                )
+
+                expected_tensors = [
+                    _build_tensor(dest + 1, i, dtype=dtype) for i in group
+                ]
+                for t1, t2 in zip(tensors, expected_tensors, strict=True):
+                    self.assertEqual(t1, t2)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_gather_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports all_gather_v"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_v_cuda(self):
+            self._barrier()
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            output_split_sizes = [dst + 1 for dst in group]
+            sum_len = sum(output_split_sizes)
+            value = 2
+
+            for async_val in [True, False]:
+                tensor = (
+                    torch.empty(
+                        output_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                    )
+                    .fill_(value)
+                    .cuda(device_id)
+                )
+                out_tensor = _build_tensor(sum_len, -1, device_id=device_id)
+
+                req = dist.all_gather(
+                    list(torch.split(out_tensor, output_split_sizes)),
+                    tensor,
+                    group_id,
+                    async_val,
+                )
+                if async_val:
+                    req.wait()
+
+                expected_value = value
+                expected_tensor = _build_tensor(
+                    sum_len, expected_value, device_id=device_id
+                )
+
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        # Test all_gather accepting single tensor as output
+        def _all_gather_into_tensor_helper(
+            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
+        ):
+            if cuda:
+                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
+                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
+            if tensor_out.dtype == torch.complex64:
+                tensor_shapes = [torch.view_as_real(tensor_in).shape]
+            else:
+                tensor_shapes = [tensor_in.shape]
+            self.call_dist_op(
+                ":all_gather_into_tensor",
+                False,
+                dist.all_gather_into_tensor,
+                tensor_out,
+                tensor_in,
+                group_id,
+                False,
+                expect_event=False,
+                tensor_shapes=tensor_shapes,
+            )
+            return tensor_out
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_into_cat_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_in = torch.ones([size, size]) * rank
+            # Concatenated output
+            tensor_out = torch.ones([len(group) * size, size]) * (-1)
+            tensor_out = self._all_gather_into_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+
+            # Check result
+            # Concatenate all blocks into a bigger tensor
+            expected_tensor = torch.cat([torch.ones([size, size]) * i for i in group])
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_into_stack_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_in = torch.ones([size, size]) * rank
+            # Stacked output
+            tensor_out = torch.ones([len(group), size, size]) * (-1)
+            tensor_out = self._all_gather_into_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+
+            # Check result
+            # Stack all blocks into a bigger tensor
+            expected_tensor = torch.stack([torch.ones([size, size]) * i for i in group])
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        def _run_all_gather_coalesced_and_verify(
+            self, output_tensor_lists, input_tensors, expected_tensors, group_id
+        ):
+            """
+            Helper that runs all_gather_coalesced and returns true if output
+            matches expectations.
+            """
+            tensor_shapes = []
+            for input_tensor in input_tensors:
+                if input_tensor.dtype == torch.complex64:
+                    tensor_shapes.append(torch.view_as_real(input_tensor).shape)
+                else:
+                    tensor_shapes.append(input_tensor.shape)
+            self.call_dist_op(
+                ":all_gather",
+                False,
+                dist.all_gather_coalesced,
+                output_tensor_lists,
+                input_tensors,
+                group_id,
+                tensor_shapes=tensor_shapes,
+            )
+
+            for l1, l2 in zip(output_tensor_lists, expected_tensors, strict=True):
+                for t1, t2 in zip(l1, l2, strict=True):
+                    if not torch.equal(t1, t2):
+                        return False
+            return True
+
+        def _test_all_gather_coalesced_helper(
+            self, group, group_id, rank, dtype=torch.float
+        ):
+            # TODO: Instead we should probably go through _rank_not_in_group
+            # mechanism to disable sending tensors
+            if group_id is not None:
+                for test_case_id in range(2, 5):
+                    # Make sure we create tensors of incompatible sizes, e.g.
+                    # [1], [2x2], [3x3x3] ... to be sent in one batch
+                    input_tensors = [
+                        _build_multidim_tensor(
+                            tensor_id, tensor_id, rank + tensor_id, dtype=dtype
+                        )
+                        for tensor_id in range(1, test_case_id)
+                    ]
+                    output_tensor_lists = [
+                        [
+                            _build_multidim_tensor(
+                                tensor_id, tensor_id, -1, dtype=dtype
+                            )
+                            for tensor_id in range(1, test_case_id)
+                        ]
+                        for _ in group
+                    ]
+                    expected_tensors = [
+                        [
+                            _build_multidim_tensor(
+                                tensor_id, tensor_id, rank_iter + tensor_id, dtype=dtype
+                            )
+                            for tensor_id in range(1, test_case_id)
+                        ]
+                        for rank_iter in group
+                    ]
+                    assert self._run_all_gather_coalesced_and_verify(
+                        output_tensor_lists, input_tensors, expected_tensors, group_id
+                    ), "output tensors do not match expected outputs"
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_simple(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_coalesced_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_with_empty(self):
+            group, group_id, rank = self._init_global_test()
+            input_tensors = [
+                rank * torch.ones([2, 2]),
+                torch.ones([0]),
+                (rank + 1) * torch.ones([3, 3]),
+                torch.ones([0]),
+                torch.ones([0]),
+            ]
+            output_tensors_lists = [
+                [
+                    -1 * torch.ones([2, 2]),
+                    -1 * torch.ones([0]),
+                    -1 * torch.ones([3, 3]),
+                    -1 * torch.ones([0]),
+                    -1 * torch.ones([0]),
+                ]
+                for _ in group
+            ]
+            expected_tensors = [
+                [
+                    r * torch.ones([2, 2]),
+                    torch.ones([0]),
+                    (r + 1) * torch.ones([3, 3]),
+                    torch.ones([0]),
+                    torch.ones([0]),
+                ]
+                for r in group
+            ]
+            assert self._run_all_gather_coalesced_and_verify(
+                output_tensors_lists, input_tensors, expected_tensors, group_id
+            )
+            self._barrier()
+
+        # AllToAll
+        def _test_all_to_all_single_equal_split_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_tensor = torch.ones([size, size], dtype=dtype) * rank
+                expected_tensor = torch.cat(
+                    [torch.ones([1, size], dtype=dtype) * i for i in group]
+                )
+                out_tensor = torch.ones([size, size], dtype=dtype) * -1
+                if cuda:
+                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
+                if dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(in_tensor).shape]
+                else:
+                    tensor_shapes = [in_tensor.shape]
+                self.call_dist_op(
+                    ":all_to_all",
+                    False,
+                    dist.all_to_all_single,
+                    out_tensor,
+                    in_tensor,
+                    group=group_id,
+                    tensor_shapes=tensor_shapes,
+                )
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        def _test_all_to_all_single_unequal_split_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                out_splits = [rank + 1 for _ in group]
+                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
+                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
+                expected_tensor = torch.cat(
+                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
+                )
+                if cuda:
+                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
+                dist.all_to_all_single(
+                    out_tensor, in_tensor, out_splits, in_splits, group=group_id
+                )
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        def _test_all_to_all_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+            dtype=torch.float,
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                in_tensors = [
+                    torch.ones([in_splits[i], size], dtype=dtype) * rank
+                    for i, _ in enumerate(group)
+                ]
+                out_tensors = [
+                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
+                ]
+                expected_tensors = [
+                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
+                ]
+                if cuda:
+                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
+                    expected_tensors = [
+                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
+                    ]
+                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
+                dist.all_to_all(out_tensors, in_tensors, group=group_id)
+                for t1, t2 in zip(out_tensors, expected_tensors, strict=True):
+                    self.assertEqual(t1, t2)
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_equal_split_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_unequal_split_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_single_equal_split_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        def test_all_to_all_single_equal_split_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_single_unequal_split_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        def test_all_to_all_single_unequal_split_group_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        # BARRIER
+        def _test_barrier_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        ):
+            WAIT_TIME = 0.3  # seconds
+
+            for dest in group:
+                expected_time = torch.DoubleTensor(1).fill_(0.0)
+                if cuda:
+                    expected_time = expected_time.cuda(rank_to_GPU[rank][0])
+                if dest == rank:
+                    expected_time.fill_(time.time() + WAIT_TIME)
+                    dist.broadcast(expected_time, dest, group_id)
+                    time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
+                    dist.barrier(group_id)
+                else:
+                    dist.broadcast(expected_time, dest, group_id)
+                    dist.barrier(group_id)
+                    self.assertGreaterAlmostEqual(
+                        float(time.time()),
+                        float(expected_time[0]),
+                        msg=f"destination rank: {dest:d}, my rank: {rank:d}"
+                        + " (if you see this failure, please report in #14554)",
+                    )
+
+            # Use higher timeout for the instance where the test runs
+            # against a subgroup and uses a CUDA tensor for expected time.
+            # The CUDA initialization for the participating processes can
+            # take long enough for the barrier timeout to trigger on the
+            # process that doesn't participate in the group.
+            self._barrier(timeout=20)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        def test_barrier_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        def test_barrier_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        def test_barrier_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        def _model_step(self, model):
+            for param in model.parameters():
+                if param.grad is not None:
+                    with torch.no_grad():
+                        param += param.grad
+                    param.grad = None
+
+        def _model_step_with_zero_grad(self, model):
+            for param in model.parameters():
+                if param.grad is not None:
+                    with torch.no_grad():
+                        param += param.grad
+                    param.grad.requires_grad_(False)
+                    param.grad.zero_()
+
+        def _prepare_dummy_data(self, local_bs):
+            # global_bs for DDP should be divisible by WORLD_SIZE
+            world_size = int(os.environ["WORLD_SIZE"])
+            global_bs = world_size * local_bs
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 4)
+            loss = nn.MSELoss()
+            return global_bs, input_cpu, target, loss
+
+        # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL
+        def _test_DDP_helper(
+            self, model, input_var, target, loss, scale_factor=1.0, memory_format=None
+        ):
+            model.train()
+            output = model(input_var)
+            l = loss(output, target) * scale_factor
+            l.backward()
+            if memory_format is not None:
+                self.assertTrue(output.is_contiguous(memory_format=memory_format))
+
+        def _assert_equal_param(self, param_gpu, param_DDP):
+            self.assertEqual(len(param_gpu), len(param_DDP))
+            for p_gpu, p_DDP in zip(param_gpu, param_DDP, strict=True):
+                self.assertEqual(p_gpu, p_DDP)
+
+        def _test_DDP_niter(
+            self,
+            model_base,
+            model_DDP,
+            input,
+            target,
+            loss,
+            local_bs,
+            rank,
+            batch_size,
+            test_save,
+            offset=None,
+            world_size=0,
+            zero_grad=False,
+            memory_format=None,
+            n_iter=5,
+        ):
+            for idx in range(n_iter):
+                # single cpu/gpu training
+                self._test_DDP_helper(
+                    model_base, input, target, loss, memory_format=memory_format
+                )
+
+                if offset is None:
+                    offset = rank * local_bs
+
+                # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    world_size * local_bs / batch_size if world_size != 0 else 1,
+                    memory_format=memory_format,
+                )
+
+                # Update weights and run a second iteration to shake out errors
+                if zero_grad:
+                    self._model_step_with_zero_grad(model_base)
+                    self._model_step_with_zero_grad(model_DDP)
+                else:
+                    self._model_step(model_base)
+                    self._model_step(model_DDP)
+                self._assert_equal_param(
+                    list(model_base.parameters()), list(model_DDP.module.parameters())
+                )
+
+                # Shuffle the input so that DDP input is different
+                input = input[torch.randperm(batch_size)]
+
+                # save the model in the middle and reload
+                if test_save and idx == 2 and INIT_METHOD.startswith("file://"):
+                    with tempfile.NamedTemporaryFile() as tmp:
+                        if sys.platform == "win32":
+                            torch.save(model_DDP, tmp)
+                            tmp.seek(0)
+                            # weights_only=False as this is legacy code that saves the model
+                            model_DDP = torch.load(tmp, weights_only=False)
+                        else:
+                            torch.save(model_DDP, tmp.name)
+                            # weights_only=False as this is legacy code that saves the model
+                            model_DDP = torch.load(tmp.name, weights_only=False)
+
+            with tempfile.TemporaryFile() as tmp_file:
+                torch.save(model_DDP, tmp_file)
+                tmp_file.seek(0)
+                # weights_only=False as this is legacy code that saves the model
+                saved_model = torch.load(tmp_file, weights_only=False)
+            for k in model_DDP.state_dict():
+                self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
+
+        def _test_DistributedDataParallel(
+            self,
+            gpu_subset,
+            rank,
+            output_device=None,
+            gradient_as_bucket_view=False,
+            static_graph=False,
+            set_static_graph_twice=False,
+        ):
+            # Run a simple end to end DDP model, use result of single node model
+            # as baseline
+
+            # cpu training setup
+            model = Net()
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpu_subset[0])
+
+            # DDP training setup
+            model_DDP = copy.deepcopy(model)
+            model_DDP.cuda(gpu_subset[0])
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP,
+                device_ids=gpu_subset,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+                static_graph=static_graph,
+            )
+
+            if set_static_graph_twice:
+                model_DDP._set_static_graph()
+
+            # test serializable/unserializable
+            with tempfile.NamedTemporaryFile() as tmp:
+                if sys.platform == "win32":
+                    torch.save(model_DDP, tmp)
+                    tmp.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp, weights_only=False)
+                else:
+                    torch.save(model_DDP, tmp.name)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp.name, weights_only=False)
+
+            # dummy data initialization
+            local_bs = len(gpu_subset)
+            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_cpu.cuda(gpu_subset[0]),
+                target.cuda(gpu_subset[0]),
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+            )
+            self._barrier()
+
+        def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False):
+            # Run a simple end to end DDP-CPU model, use result of single node
+            # model as baseline
+            _group, _group_id, rank = self._init_global_test()
+
+            # cpu training setup
+            model_base = Net()
+
+            # DDP-CPU training setup
+            model_DDP = copy.deepcopy(model_base)
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP, gradient_as_bucket_view=gradient_as_bucket_view
+            )
+
+            # dummy data initialization
+            local_bs = 2
+            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_base,
+                model_DDP,
+                input_cpu,
+                target,
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                False,
+                zero_grad=True,
+            )
+            self._barrier()
+
+            return model_DDP
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_DistributedDataParallelCPU(self):
+            self._test_DistributedDataParallelCPU()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_DistributedDataParallelCPU_grad_is_view(self):
+            self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_DistributedDataParallel_requires_grad(self):
+            # a module without gradients shouldn't be accepted
+            self.assertRaises(
+                RuntimeError, lambda: nn.parallel.DistributedDataParallel(nn.Module())
+            )
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_zero_output_features(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+                    self.net2 = nn.Linear(10, 0)
+
+            model = ToyModel().to(self.rank)
+            nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
+
+        @skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
+        def test_ddp_create_graph(self):
+            class Model(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.p = nn.Parameter(torch.tensor(1.0))
+
+                def forward(self):
+                    return self.p.pow(2)
+
+            model = Model()
+            ddp_model = torch.nn.parallel.DistributedDataParallel(model)
+            for _ in range(6):
+                # Verify DDP doesn't throw when ran with create_graph=True.
+                # Although we do warn about potential issues, please see
+                # https://github.com/pytorch/pytorch/issues/63929 for details.
+                ddp_model().backward(create_graph=True)
+                # grad tensors should require grad.
+                self.assertTrue(
+                    all(param.requires_grad for param in ddp_model.parameters())
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_DistributedDataParallel_non_default_stream(self):
+            stream = torch.cuda.Stream(self.rank)
+            rank = self.rank
+            with torch.cuda.stream(stream):
+                net = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
+                )
+                for i in range(1000):
+                    # Clear gradients manually
+                    grad = net.module.weight.grad
+                    if grad is not None:
+                        grad.requires_grad_(False)
+                        grad.zero_()
+                    # Forward + BW
+                    batch = torch.tensor([rank]).float().cuda(rank)
+                    loss = net(batch).sum()
+                    loss.backward()
+                    # For each worker, the gradient on the weight should be worker_rank.
+                    grad = net.module.weight.grad
+                    avg = grad.clone()
+                    # All-reducing the gradient averages should give us the gradient
+                    # average. If not, then one of the workers has not correctly
+                    # written back the averaged gradient before this all-reduce call.
+                    dist.all_reduce(avg)
+                    world_size = int(os.environ["WORLD_SIZE"])
+                    avg.div_(world_size)
+                    expected_grad = sum(i for i in range(world_size)) / world_size
+                    self.assertEqual(
+                        avg[0, 0],
+                        expected_grad,
+                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_comm_hook_logging(self):
+            hooks = [
+                default.allreduce_hook,
+                default.fp16_compress_hook,
+                powerSGD.powerSGD_hook,
+                powerSGD.batched_powerSGD_hook,
+                quantization_hooks.quantization_pertensor_hook,
+                quantization_hooks.quantization_perchannel_hook,
+            ]
+
+            cpp_builtin_hooks = [
+                dist.BuiltinCommHookType.ALLREDUCE,
+                dist.BuiltinCommHookType.FP16_COMPRESS,
+            ]
+
+            for hook in hooks:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                    device_ids=[self.rank],
+                )
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                # Hook not registered yet, so should be empty
+                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+                ddp_model.register_comm_hook(None, hook)
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                self.assertEqual(ddp_logging_data.get("comm_hook"), hook.__qualname__)
+
+            for hook in cpp_builtin_hooks:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                    device_ids=[self.rank],
+                )
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                # Hook not registered yet, so should be empty
+                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+                ddp_model._register_builtin_comm_hook(hook)
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                self.assertEqual(ddp_logging_data.get("comm_hook"), str(hook))
+
+            # No hook registered
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = ddp_model._get_ddp_logging_data()
+            # Hook not registered yet, so should be empty
+            self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+            # After second forward pass, hook should still be empty string
+            for _ in range(2):
+                inp = torch.ones(1, 1, device=self.rank)
+                loss = ddp_model(inp).sum()
+                loss.backward()
+
+            ddp_logging_data = ddp_model._get_ddp_logging_data()
+            # Note: DETAIL debug mode logs DDP logging data to stdout and
+            # thus accesses std::map, which fills in a default value for the
+            # type if it didn't exist.
+            self.assertEqual(ddp_logging_data.get("comm_hook", ""), "")
+
+        def _test_ddp_hook_with_optimizer_parity(
+            self,
+            grad_as_bucket_view,
+            static_graph,
+            optim_cls,
+            optimize_subset,
+            *functional_optim_args,
+            **functional_optim_kwargs,
+        ):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            models_to_test = [
+                (LargeNet(), torch.randn(1, 1000).cuda()),
+            ]
+            if HAS_TORCHVISION:
+                models_to_test.append(
+                    (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
+                )
+            for model, inp in models_to_test:
+                # Enable determinism in cudnn operators
+                with torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                ):
+                    # Create DDP model that runs optimizer in fused fashion.
+                    ddp_model_with_optimizer_hook = (
+                        torch.nn.parallel.DistributedDataParallel(
+                            copy.deepcopy(model).cuda(),
+                            device_ids=[self.rank],
+                            gradient_as_bucket_view=grad_as_bucket_view,
+                            static_graph=static_graph,
+                        )
+                    )
+
+                    # Create DDP model with no hook that does optimizer after
+                    # backward.
+                    ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel(
+                        copy.deepcopy(model).cuda(),
+                        device_ids=[self.rank],
+                        gradient_as_bucket_view=grad_as_bucket_view,
+                        static_graph=static_graph,
+                    )
+                    hook_params = ddp_model_with_optimizer_hook.parameters()
+                    no_hook_params = ddp_model_with_no_hook.parameters()
+                    if optimize_subset:
+                        hook_params = list(hook_params)
+                        no_hook_params = list(no_hook_params)
+                        self.assertGreater(len(hook_params), 0)
+                        hook_params = [hook_params[0]]
+                        no_hook_params = [no_hook_params[0]]
+
+                    # Register a fused optimizer that will run optimizer in step
+                    # with allreduce.
+
+                    if optimize_subset:
+                        # API where optim_params is specified.
+                        ddp_model_with_optimizer_hook._register_fused_optim(
+                            optim_cls,
+                            *functional_optim_args,
+                            optim_params=hook_params,
+                            **functional_optim_kwargs,
+                        )
+                    else:
+                        # API where optim_params is omitted
+                        ddp_model_with_optimizer_hook._register_fused_optim(
+                            optim_cls,
+                            *functional_optim_args,
+                            **functional_optim_kwargs,
+                        )
+
+                    optimizer_no_hook = optim_cls(
+                        no_hook_params,
+                        *functional_optim_args,
+                        **functional_optim_kwargs,
+                    )
+
+                    # Verify parameters are equal initially.
+                    for hook_param, allreduce_param in zip(
+                        ddp_model_with_optimizer_hook.parameters(),
+                        ddp_model_with_no_hook.parameters(),
+                        strict=True,
+                    ):
+                        self.assertEqual(hook_param, allreduce_param)
+
+                    # Save old parameters to later verify optimizer modified them.
+                    opt_hook_init_params = copy.deepcopy(
+                        list(ddp_model_with_optimizer_hook.parameters())
+                    )
+
+                    # Run optimizer with hook model.
+                    for _ in range(6):
+                        ddp_model_with_optimizer_hook.zero_grad()
+                        out = ddp_model_with_optimizer_hook(inp)
+                        loss = out.sum()
+                        loss.backward()
+
+                    dist.barrier()
+
+                    # Run regular model.
+                    for _ in range(6):
+                        ddp_model_with_no_hook.zero_grad()
+                        out = ddp_model_with_no_hook(inp)
+                        loss = out.sum()
+                        loss.backward()
+                        optimizer_no_hook.step()
+
+                    dist.barrier()
+
+                    # Now verify parameters are equal.
+                    for hook_param, allreduce_param in zip(
+                        ddp_model_with_optimizer_hook.parameters(),
+                        ddp_model_with_no_hook.parameters(),
+                        strict=True,
+                    ):
+                        self.assertEqual(hook_param, allreduce_param)
+
+                    # Verify optimizer modified appropriate parameter set,
+                    # otherwise they'd be trivially equal above.
+                    if optimize_subset:
+                        self.assertNotEqual(
+                            opt_hook_init_params[0],
+                            next(iter(ddp_model_with_optimizer_hook.parameters())),
+                        )
+                        # Untouched params should be equal
+                        self.assertEqual(
+                            opt_hook_init_params[1:],
+                            list(ddp_model_with_optimizer_hook.parameters())[1:],
+                        )
+                    else:
+                        self.assertNotEqual(
+                            opt_hook_init_params,
+                            list(ddp_model_with_optimizer_hook.parameters()),
+                        )
+                    dist.barrier()
+
+        """
+        # Commenting out the following 3 tests as they cause Sandcastle jobs to fail
+        # Failure signature:
+        # AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw
+
+        from torch.testing._internal.common_utils import parametrize
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("grad_as_bucket_view", [True, False])
+        @parametrize("static_graph", [True, False])
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_adamw(
+            self,
+            grad_as_bucket_view,
+            static_graph,
+            optimize_subset,
+        ):
+            adamw_lr = 1e-2
+            adamw_betas = (0.9, 0.99)
+            adamw_eps = 1e-6
+            self._test_ddp_hook_with_optimizer_parity(
+                grad_as_bucket_view,
+                static_graph,
+                torch.optim.AdamW,
+                optimize_subset,
+                adamw_lr,
+                betas=adamw_betas,
+                eps=adamw_eps,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset):
+            adam_lr = 1e-2
+            adam_betas = (0.9, 0.99)
+            adam_eps = 1e-6
+            self._test_ddp_hook_with_optimizer_parity(
+                True,  # grad as bucket view
+                False,  # static graph
+                torch.optim.Adam,
+                optimize_subset,
+                adam_lr,
+                betas=adam_betas,
+                eps=adam_eps,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset):
+            sgd_lr = 1e-2
+            sgd_momentum = 0.9
+            sgd_weight_decay = 0.01
+            # Not testing grad_as_bucket_view and static_graph as they are
+            # tested in AdamW test above.
+            self._test_ddp_hook_with_optimizer_parity(
+                True,  # grad as bucket view
+                False,  # static_graph
+                torch.optim.SGD,
+                optimize_subset,
+                sgd_lr,
+                momentum=sgd_momentum,
+                weight_decay=sgd_weight_decay,
+            )
+        """
+
+        @skip_if_lt_x_gpu(2)
+        def test_get_data_parallel_params(self):
+            torch.cuda.set_device(self.rank)
+            model = TwoLinLayerNet().cuda()
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            params_to_ignore = ["a.weight"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, params_to_ignore
+            )
+            torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
+            dp_params = (
+                torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
+                    model, named_params=True
+                )
+            )
+            for name, _ in dp_params:
+                self.assertNotEqual(f"module.{params_to_ignore[0]}", name)
+
+            # test named_params=False, just check if returns the expected
+            # no of parameters.
+            num_ddp_params = len(list(model.parameters())) - 1
+            count = 0
+            dp_params = (
+                torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
+                    model, named_params=False
+                )
+            )
+            for _ in dp_params:
+                count += 1
+            self.assertEqual(count, num_ddp_params)
+
+        def _test_ddp_apply_optim_in_backward(
+            self,
+            optim_cls,
+            optim_kwargs,
+            init_before,
+            gradient_as_bucket_view=True,
+        ):
+            # Need to seed to ensure inputs are unique across rank. Otherwise,
+            # allreduce won't have any effect.
+            torch.manual_seed(self.rank)
+            torch.cuda.manual_seed(self.rank)
+            torch.cuda.set_device(self.rank)
+
+            # Test a simple linear as well as a ResNet model.
+            models_to_test = [
+                nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
+            ]
+            if HAS_TORCHVISION:
+                models_to_test.append(torchvision.models.resnet50().cuda())
+
+            for j, model in enumerate(models_to_test):
+                model_optim_in_bwd = copy.deepcopy(model)
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    gradient_as_bucket_view=gradient_as_bucket_view,
+                )
+                optim = optim_cls(model.parameters(), **optim_kwargs)
+                if init_before:
+                    _apply_optimizer_in_backward(
+                        optimizer_class=optim_cls,
+                        params=model_optim_in_bwd.parameters(),
+                        optimizer_kwargs=optim_kwargs,
+                    )
+                model_optim_in_bwd = nn.parallel.DistributedDataParallel(
+                    model_optim_in_bwd,
+                    device_ids=[self.rank],
+                    gradient_as_bucket_view=gradient_as_bucket_view,
+                )
+                if not init_before:
+                    _apply_optimizer_in_backward(
+                        optimizer_class=optim_cls,
+                        params=model_optim_in_bwd.parameters(),
+                        optimizer_kwargs=optim_kwargs,
+                    )
+
+                for p1, p2 in zip(
+                    model.parameters(), model_optim_in_bwd.parameters(), strict=True
+                ):
+                    self.assertEqual(p1, p2, "Parameters not initially equal!")
+                # Enable determinism in cudnn operators
+                with torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                ):
+                    for i in range(8):
+                        inp = (
+                            torch.randn(1, 3, 1000, 1000, device="cuda")
+                            if j == 1
+                            else torch.randn(10, 3, device="cuda")
+                        )
+                        model(inp).sum().backward()
+                        optim.step()
+                        model_optim_in_bwd(
+                            inp
+                        ).sum().backward()  # runs optimizer as well
+                        for p1, p2 in zip(
+                            model.parameters(),
+                            model_optim_in_bwd.parameters(),
+                            strict=True,
+                        ):
+                            self.assertEqual(
+                                p1, p2, f"Params not equal at iteration {i}"
+                            )
+                            self.assertTrue(
+                                p2.grad is None,
+                                f"Optim in backward grad is not None at {i}",
+                            )
+
+                        # set_to_none for regular optimizer to match in backward
+                        # case.
+                        optim.zero_grad(set_to_none=True)
+
+        @skipIfRocm
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward(self):
+            for optim_cls, init_before in itertools.product(
+                [torch.optim.SGD, torch.optim.Adam], [True, False]
+            ):
+                with self.subTest(optim_cls=optim_cls):
+                    self._test_ddp_apply_optim_in_backward(
+                        optim_cls=optim_cls,
+                        optim_kwargs={"lr": 0.03},
+                        init_before=init_before,
+                    )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self):
+            for init_before in [True, False]:
+                self._test_ddp_apply_optim_in_backward(
+                    optim_cls=torch.optim.SGD,
+                    optim_kwargs={"lr": 0.03},
+                    init_before=init_before,
+                    gradient_as_bucket_view=False,
+                )
+
+        @skipIfRocmArch(MI200_ARCH)
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward_ignored_params(self):
+            torch.cuda.set_device(self.rank)
+            for init_before in [True, False]:
+                with self.subTest(init_before=init_before):
+                    torch.manual_seed(self.rank)
+                    torch.cuda.manual_seed(self.rank)
+                    model = TwoLinLayerNet()
+                    # Parameters to ignore are in the format {module_name}.{param_name}
+                    params_to_ignore = ["a.weight"]
+                    torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                        model, params_to_ignore
+                    )
+                    if init_before:
+                        _apply_optimizer_in_backward(
+                            optimizer_class=torch.optim.SGD,
+                            params=model.parameters(),
+                            optimizer_kwargs={"lr": 0.03},
+                        )
+                    net = torch.nn.parallel.DistributedDataParallel(
+                        model.cuda(self.rank),
+                        device_ids=[self.rank],
+                    )
+                    if not init_before:
+                        _apply_optimizer_in_backward(
+                            optimizer_class=torch.optim.SGD,
+                            params=model.parameters(),
+                            optimizer_kwargs={"lr": 0.03},
+                        )
+                    inp = torch.randn(1, 10)
+                    a, b = net(inp)
+                    (a.transpose(0, 1) @ b).sum().backward()
+                    # a.weight did not go through allreduce, so optimizer acted on local
+                    # gradient, which should be different across ranks. Remaining params
+                    # should be equal.
+                    models = [None for _ in range(dist.get_world_size())]
+                    dist.all_gather_object(models, model)
+                    rank0_model, remainder = models[0], models[1:]
+                    for m in remainder:
+                        self.assertNotEqual(rank0_model.a.weight, m.a.weight)
+                        self.assertEqual(
+                            list(rank0_model.b.parameters()), list(m.b.parameters())
+                        )
+                        self.assertEqual(rank0_model.a.bias, m.a.bias)
+
+        def _get_fp16_config(self) -> _MixedPrecision:
+            return _MixedPrecision(
+                param_dtype=torch.float16,
+                reduce_dtype=torch.float16,
+                buffer_dtype=torch.float16,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_ignored_params(self):
+            rank = self.rank
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            torch.cuda.set_device(rank)
+            model = TwoLinLayerNet()
+            model.register_buffer("buffer", torch.ones(5))
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            to_ignore = ["a.weight", "buffer"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model,
+                to_ignore,
+            )
+            mp_config = self._get_fp16_config()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.to(rank),
+                device_ids=[rank],
+                mixed_precision=mp_config,
+                gradient_as_bucket_view=True,
+            )
+            to_ignore = [f"module.{name}" for name in to_ignore]
+            expected_ignored = len(to_ignore)
+            n_ignored = 0
+            # ignored params should not have _mp_param or _fp_param fields.
+            for n, p in itertools.chain(net.named_parameters(), net.named_buffers()):
+                if n in to_ignore:
+                    n_ignored += 1
+                    self.assertFalse(hasattr(p, "_mp_param"))
+                    self.assertFalse(hasattr(p, "_fp_param"))
+                else:
+                    self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
+                    self.assertEqual(torch.float32, p._fp_param.dtype)
+
+            self.assertEqual(expected_ignored, n_ignored)
+
+        def _test_ddp_native_mixed_precision(
+            self, gradient_as_bucket_view, set_grad_to_none
+        ):
+            rank = self.rank
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            torch.cuda.set_device(rank)
+            inp = torch.randn(10, 1)
+            mp_config = self._get_fp16_config()
+
+            class MyModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.m = torch.nn.Linear(1, 5)
+                    self.register_buffer("buffer", torch.randn(1, 2))
+                    self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False)
+
+                def forward(self_, x):  # noqa: B902
+                    params = self_.m.parameters()
+                    for p in params:
+                        self.assertEqual(mp_config.param_dtype, p.dtype)
+
+                    self.assertEqual(self_.buffer.dtype, mp_config.buffer_dtype)
+
+                    self.assertEqual(mp_config.param_dtype, x.dtype)
+                    return self_.m(x) + self_.p
+
+            m = MyModel()
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                m.to(rank),
+                device_ids=[rank],
+                mixed_precision=mp_config,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+            # Buffers are casted in constructor.
+            self.assertEqual(net.module.buffer.dtype, mp_config.buffer_dtype)
+            # Each param should have an mp_param in the lower precision, and
+            # an fp_param in the higher precision.
+            for p in net.parameters():
+                self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
+                self.assertEqual(torch.float32, p._fp_param.dtype)
+
+            for _ in range(6):
+                loss = net(inp).sum()
+                loss.backward()
+                # Verify gradient synchronization and params and grads are fp32.
+                for n, param in net.named_parameters():
+                    self.assertEqual(param.dtype, torch.float32)
+                    if param.grad is None:
+                        assert n == "module.p"  # Only param that doesn't require grad
+                    else:
+                        self.assertEqual(param.grad.dtype, torch.float32)
+                        tensor_list = [
+                            torch.zeros_like(param.grad)
+                            for _ in range(dist.get_world_size(net.process_group))
+                        ]
+                        dist.all_gather(tensor_list, param.grad)
+                        g, rest = tensor_list[0], tensor_list[1:]
+                        self.assertEqual(g.dtype, torch.float32)
+                        for g_ in rest:
+                            self.assertEqual(g_.dtype, torch.float32)
+                            self.assertEqual(g, g_)
+                net.zero_grad(set_to_none=set_grad_to_none)
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(
+            self,
+        ):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=False,
+                set_grad_to_none=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_grad_as_bucket_view_no_set_grad_none(self):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True,
+                set_grad_to_none=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_grad_as_bucket_view_set_grad_to_none(self):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True, set_grad_to_none=True
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(
+            self,
+        ):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True, set_grad_to_none=True
+            )
+
+        def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100):
+            rank = self.rank
+            m = torch.nn.Linear(1, 5)
+            try:
+                process_group = state.process_group
+            except AttributeError:
+                process_group = state
+
+            net_with_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(m).to(rank),
+                device_ids=[rank],
+                process_group=process_group,
+            )
+            net_with_hook.register_comm_hook(state=state, hook=hook)
+            net_without_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(m).to(rank),
+                device_ids=[rank],
+                process_group=process_group,
+            )
+            for i in range(100):
+                # Clear gradients manually.
+                for g in [
+                    net_without_hook.module.weight.grad,
+                    net_with_hook.module.weight.grad,
+                ]:
+                    if g is not None:
+                        g.requires_grad_(False)
+                        g.zero_()
+                # Forward + BW
+                batch = torch.tensor([rank]).float().cuda(rank)
+                loss = net_without_hook(batch).sum()
+                loss.backward()
+                # For each worker, the gradient on the weight should be worker_rank.
+                grad = net_without_hook.module.weight.grad
+                avg = grad.clone()
+                expected_grad = (
+                    sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
+                )
+                loss_hook = net_with_hook(batch).sum()
+                loss_hook.backward()
+                grad_hook = net_with_hook.module.weight.grad
+                avg_hook = grad_hook.clone()
+
+                if i < num_validated_iters:
+                    # Verify hook grad with expected.
+                    self.assertEqual(
+                        avg_hook[0, 0].item(),
+                        expected_grad,
+                        msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
+                    )
+                    # Verify hook grad with vanilla allreduce
+                    self.assertEqual(
+                        avg_hook[0, 0],
+                        avg[0, 0],
+                        msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_allreduce(self):
+            self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_allreduce_process_group(self):
+            # process_group is passed in to both DDP and comm. hook
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            gpus = [rank_to_GPU[int(r)][0] for r in range(world_size)]
+            process_group = torch.distributed.new_group(gpus)
+            self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_powerSGD(self):
+            for warm_start in [True, False]:
+                powersgd_state = powerSGD.PowerSGDState(
+                    process_group=None,
+                    matrix_approximation_rank=1,
+                    start_powerSGD_iter=2,
+                    warm_start=warm_start,
+                )
+                self._test_ddp_hook_parity(
+                    state=powersgd_state, hook=powerSGD.powerSGD_hook
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_post_localSGD(self):
+            # Although we start run local SGD at iteration 10, since we still use the global process group to run it,
+            # the post-LocalSGD actually still allreduces gradients globally for the remaining iterations.
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None, subgroup=dist.group.WORLD, start_localSGD_iter=10
+            )
+            self._test_ddp_hook_parity(
+                state=state, hook=post_localSGD.post_localSGD_hook
+            )
+            # Only validate the warmup iterations before local SGD is applied,
+            # because when `post_local_gradient_allreduce` is disabled, the gradients will not be synchronized at all.
+            # Note that in practice a model averager has to be applied to run model averaging,
+            # so local gradient averaging is not necessary.
+            start_localSGD_iter = 10
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None,
+                subgroup=dist.group.WORLD,
+                start_localSGD_iter=start_localSGD_iter,
+                post_local_gradient_allreduce=False,
+            )
+            self._test_ddp_hook_parity(
+                state=state,
+                hook=post_localSGD.post_localSGD_hook,
+                num_validated_iters=start_localSGD_iter,
+            )
+
+            # When `subgroup` is None, it is equivalent to the subgroup on the each node.
+            # For this single-node test environment, the intra-node process group is equivalent to
+            # the global process group.
+            if self.world_size == dist.get_world_size():
+                state = post_localSGD.PostLocalSGDState(
+                    process_group=None, subgroup=None, start_localSGD_iter=10
+                )
+                self._test_ddp_hook_parity(
+                    state=state, hook=post_localSGD.post_localSGD_hook
+                )
+
+            # Since we start local SGD later than the total number of 100 iterations,
+            # no local SGD actually is executed, and we don't even need to provide a subgroup for this case.
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None, subgroup=None, start_localSGD_iter=1000
+            )
+            self._test_ddp_hook_parity(
+                state=state, hook=post_localSGD.post_localSGD_hook
+            )
+
+        def _prepare_single_device_module(
+            self,
+            rank,
+            process_group,
+            devices,
+            device_ids,
+            global_batch_size,
+            gradient_as_bucket_view=False,
+        ):
+            model = Net()
+            device = devices[0] if devices else torch.device(f"cuda:{rank:d}")
+            ddp_model = DistributedDataParallel(
+                copy.deepcopy(model).to(device),
+                device_ids=device_ids,
+                process_group=process_group,
+                bucket_cap_mb=0.001,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+
+            model.to(device)
+
+            input = torch.randn(global_batch_size, 2).to(device)
+            target = torch.randn(global_batch_size, 4).to(device)
+
+            return model, ddp_model, input, target
+
+        def _prepare_cpu_module(
+            self,
+            process_group,
+            global_batch_size,
+            gradient_as_bucket_view=False,
+        ):
+            model = Net()
+            ddp_model = DistributedDataParallel(
+                copy.deepcopy(model),
+                process_group=process_group,
+                bucket_cap_mb=0.001,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+            input = torch.randn(global_batch_size, 2)
+            target = torch.randn(global_batch_size, 4)
+            return model, ddp_model, input, target
+
+        def _test_accumulate_gradients_no_sync(
+            self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False
+        ):
+            """
+            This is the recommended way to implement accumulate grads.
+            If ``ddp_comm_hook`` input was specified, it will also register that hook
+            to the ``ddp_model``. The hook fed into this function should not change
+            the resulting gradients.
+            """
+            _group, group_id, rank = self._init_global_test()
+            world_size = get_world_size()
+
+            # FIXME: Add testing for gloo/CUDA
+            if BACKEND == "mpi" or BACKEND == "gloo":
+                global_batch_size = world_size
+                local_batch_size = 1
+                model, ddp_model, input, target = self._prepare_cpu_module(
+                    group_id, global_batch_size, gradient_as_bucket_view
+                )
+
+            if BACKEND == "nccl":
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                int_devices = rank_to_GPU[rank][:1]
+                devices = [torch.device("cuda:" + str(i)) for i in int_devices]
+                global_batch_size = world_size
+                local_batch_size = len(devices)
+                model, ddp_model, input, target = self._prepare_single_device_module(
+                    rank,
+                    group_id,
+                    devices,
+                    devices,
+                    global_batch_size,
+                    gradient_as_bucket_view,
+                )
+
+            if ddp_comm_hook is not None:
+                ddp_model.register_comm_hook(group_id, ddp_comm_hook)
+
+            def step_model(model, input, target):
+                model.train()
+                output = model(input)
+                loss = F.mse_loss(output, target.to(output.device))
+                loss.backward()
+
+            # ensure accumulate grads works with no_grad => no grads are accumulated.
+            with torch.no_grad():
+                with ddp_model.no_sync():
+                    ddp_model.train()
+                    ddp_model(input)
+
+            # check two model parameters over num_iters iterations
+            for iteration in range(num_iters):
+                step_model(model, input, target)
+
+                ddp_input = input[
+                    rank * local_batch_size : (rank + 1) * local_batch_size
+                ]
+                ddp_target = target[
+                    rank * local_batch_size : (rank + 1) * local_batch_size
+                ]
+
+                if iteration % 2 == 0:
+                    # accumulate grads locally
+                    with ddp_model.no_sync():
+                        step_model(ddp_model, ddp_input, ddp_target)
+                else:
+                    # sync grads
+                    step_model(ddp_model, ddp_input, ddp_target)
+
+                for i, j in zip(
+                    model.parameters(), ddp_model.parameters(), strict=True
+                ):
+                    if not i.requires_grad:
+                        continue
+                    if iteration % 2 == 0:
+                        self.assertNotEqual(i.grad, j.grad)
+                    else:
+                        self.assertEqual(i.grad, j.grad)
+
+                # Shuffle the input so that DDP input is different
+                torch.manual_seed(1337 + iteration)
+                input = input[torch.randperm(global_batch_size)]
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync(self):
+            """
+            Runs _test_accumulate_gradients_no_sync using default inputs
+            """
+            self._test_accumulate_gradients_no_sync()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_grad_is_view(self):
+            """
+            Runs _test_accumulate_gradients_no_sync using default inputs
+            """
+            self._test_accumulate_gradients_no_sync(gradient_as_bucket_view=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_allreduce_hook(self):
+            """
+            Runs multiple iterations on _test_accumulate_gradients_no_sync
+            using allreduce hook and validates whether future result was properly
+            passed as gradients in reducer.
+            """
+
+            world_size = get_world_size()
+
+            def allreduce_hook(
+                group_id: object, bucket: dist.GradBucket
+            ) -> torch.futures.Future[torch.Tensor]:
+                tensors = [bucket.buffer() / world_size]
+                return (
+                    group_id.allreduce(tensors)
+                    .get_future()
+                    .then(lambda fut: fut.value()[0])
+                )
+
+            self._test_accumulate_gradients_no_sync(
+                num_iters=4, ddp_comm_hook=allreduce_hook
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
+            """
+            Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce
+            hook that also uses then callbacks. In first then callback result is multiplied
+            by 2, and the second callback divides the result by 2 * world_size. It validates
+            whether final result was properly passed as gradients in reducer.
+            """
+
+            world_size = get_world_size()
+
+            def allreduce_with_then_hook(
+                group_id: object, bucket: dist.GradBucket
+            ) -> torch.futures.Future[torch.Tensor]:
+                fut = group_id.allreduce([bucket.buffer()]).get_future()
+
+                def mult(fut):
+                    # Multiply the result by 2.
+                    return 2 * fut.wait()[0]
+
+                def div(fut):
+                    # Divide the result by 2 * world_size.
+                    return fut.wait() / (2 * world_size)
+
+                return fut.then(mult).then(div)
+
+            self._test_accumulate_gradients_no_sync(
+                num_iters=4, ddp_comm_hook=allreduce_with_then_hook
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_get_future(self):
+            def mult(fut):
+                return [t * 3 for t in fut.wait()]
+
+            def add(fut):
+                return [t + 1 for t in fut.wait()]
+
+            group, group_id, rank = self._init_global_test()
+            input = _build_tensor(3, 2)
+            if BACKEND == "nccl":
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                device_id = rank_to_GPU[rank][0]
+                input = input.to(device_id)
+            fut = group_id.allreduce([input]).get_future()
+            res = fut.then(mult).then(add).wait()
+            expected = _build_tensor(3, 2 * len(group) * 3 + 1)
+
+            self.assertEqual(res[0], expected)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel(self):
+            _group, _group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            gpus = list(rank_to_GPU[rank])
+
+            for use_bucket_view, static_graph in itertools.product(
+                (False, True), (False, True)
+            ):
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+                # test set static graph twice
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                    set_static_graph_twice=True,
+                )
+
+                # test output_device
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    output_device=torch.device("cuda"),
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+                # test device_ids
+                gpus_list = [torch.device("cuda:" + str(i)) for i in gpus]
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus_list,
+                    rank=rank,
+                    output_device=torch.device("cuda"),
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+        def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
+            torch.manual_seed(31415)
+            # Creates model and optimizer in default precision
+            model = Net().cuda()
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
+
+            # Creates a GradScaler once at the beginning of training.
+            scaler = GradScaler()
+
+            ddp_model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            # verify grads are none before training
+            for p in ddp_model.parameters():
+                self.assertTrue(p is not None)
+                self.assertTrue(p.grad is None)
+
+            for idx in range(20):
+                optimizer.zero_grad()
+                # Runs the forward pass with autocasting.
+                with autocast():
+                    output = ddp_model(input)
+                    loss = loss_fn(output, target)
+
+                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
+                # Backward passes under autocast are not recommended.
+                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
+                scaler.scale(loss).backward()
+
+                # verify grads are not none and are valid during training
+                for p in ddp_model.parameters():
+                    if p.requires_grad:
+                        self.assertTrue(p.grad is not None)
+                        self.assertFalse(p.grad.isnan().any())
+                        self.assertFalse(p.grad.isinf().any())
+
+                # scaler.step() first unscales the gradients of the optimizer's assigned params.
+                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
+                # otherwise, optimizer.step() is skipped.
+                scaler.step(optimizer)
+
+                # Updates the scale for next iteration.
+                scaler.update()
+
+                # Shuffle the input so that DDP input is different
+                torch.manual_seed(1337 + idx)
+                input = input[torch.randperm(dist.get_world_size() * 2)]
+
+            return ddp_model
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_with_amp_and_grad_is_view(self):
+            torch.cuda.set_device(self.rank)
+            ddp_model_grad_not_view = self._test_DistributedDataParallel_with_amp(
+                grad_is_view=False
+            )
+            ddp_model_grad_is_view = self._test_DistributedDataParallel_with_amp(
+                grad_is_view=True
+            )
+            for i, j in zip(
+                ddp_model_grad_not_view.parameters(),
+                ddp_model_grad_is_view.parameters(),
+                strict=True,
+            ):
+                self.assertEqual(i, j)
+
+        def _test_DistributedDataParallel_SyncBatchNorm(
+            self,
+            gpu_subset,
+            rank,
+            local_bs,
+            global_bs,
+            offset,
+            output_device=None,
+            affine=True,
+        ):
+            # Run a simple end to end DDP model, use result of single node model
+            # as baseline
+
+            # cpu training setup
+            model = BatchNormNet() if affine else BatchNormNet(affine=False)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpu_subset[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpu_subset[0])
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP, device_ids=gpu_subset
+            )
+
+            # test serializable/unserializable
+            with tempfile.NamedTemporaryFile() as tmp:
+                if sys.platform == "win32":
+                    torch.save(model_DDP, tmp)
+                    tmp.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp, weights_only=False)
+                else:
+                    torch.save(model_DDP, tmp.name)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp.name, weights_only=False)
+
+            # data initialization
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 4)
+            loss = nn.MSELoss()
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_cpu.cuda(gpu_subset[0]),
+                target.cuda(gpu_subset[0]),
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+                offset,
+                dist.get_world_size(),
+                5 if affine else 2,
+            )
+            self._barrier()
+
+        def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
+            learning_rate = 0.03
+
+            DDP_NET = Net()
+            net = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(),
+                device_ids=[self.rank],
+                gradient_as_bucket_view=grad_is_view,
+            )
+            averager = create_averager()
+            opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
+
+            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(),
+                device_ids=[self.rank],
+                gradient_as_bucket_view=grad_is_view,
+            )
+            # Process group cannot be pickled in some environments,
+            # so cannot deep copy an averager. See:
+            # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496
+            averager2 = create_averager()
+            post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager2
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            for _ in range(20):
+                self._perform_a_train_step(opt, net, loss_fn, input, target)
+                averager.average_parameters(net.parameters())
+
+                self._perform_a_train_step(
+                    post_localSGD_opt,
+                    net_using_post_localSGD_opt,
+                    loss_fn,
+                    input,
+                    target,
+                )
+                for p1, p2 in zip(
+                    net.parameters(),
+                    net_using_post_localSGD_opt.parameters(),
+                    strict=True,
+                ):
+                    self.assertEqual(p1.data, p2.data)
+
+            # Also check if the built-in step counters are the same to prevent a bug like #74737.
+            self.assertEqual(averager.step, averager2.step)
+
+        def _create_periodic_model_averager(self):
+            return averagers.PeriodicModelAverager(period=4, warmup_steps=10)
+
+        def _create_post_localSGD_optimizer(self, net, learning_rate, averager):
+            return post_localSGD_optimizer.PostLocalSGDOptimizer(
+                optim=torch.optim.SGD(net.parameters(), lr=learning_rate),
+                averager=averager,
+            )
+
+        def _perform_a_train_step(self, optimizer, net, loss_fn, input, target):
+            optimizer.zero_grad()
+            output = net(input)
+            loss = loss_fn(output, target)
+            loss.backward()
+            optimizer.step()
+
+        def _test_post_localSGD_optimizer_step_reload(
+            self, create_averager, chkpt_file
+        ):
+            learning_rate = 0.03
+
+            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
+                Net().cuda(), device_ids=[self.rank]
+            )
+
+            averager = create_averager()
+            post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager
+            )
+
+            averager2 = create_averager()
+            dummy_post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager2
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            for _ in range(20):
+                self._perform_a_train_step(
+                    post_localSGD_opt,
+                    net_using_post_localSGD_opt,
+                    loss_fn,
+                    input,
+                    target,
+                )
+
+            if self.rank == 0:
+                torch.save(
+                    {"optimizer_state_dict": post_localSGD_opt.state_dict()}, chkpt_file
+                )
+
+            dist.barrier()
+            map_location = {"cuda:0": f"cuda:{self.rank:d}"}
+            checkpoint = torch.load(chkpt_file, map_location=map_location)
+            dummy_post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"])
+
+            # Check that we didn't hit the trivial case
+            self.assertNotEqual(averager2.step, 0)
+            # Check if dummy averager was initialized to a correct value
+            self.assertEqual(averager.step, averager2.step)
+
+            # Remove 'step' entry from a checkpoint.
+            # And make sure it is not in the state dictionary
+            del checkpoint["optimizer_state_dict"]["step"]
+            self.assertNotIn("step", checkpoint["optimizer_state_dict"])
+
+            # Check if checkpoint without a 'step' entry invokes a warning
+            with self.assertWarnsRegex(
+                expected_warning=UserWarning,
+                expected_regex="Loaded state dict does not contain a step counter for an averager. "
+                "Setting step counter to 0.",
+            ):
+                dummy_post_localSGD_opt.load_state_dict(
+                    checkpoint["optimizer_state_dict"]
+                )
+
+            self.assertEqual(averager2.step, 0)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_periodic_model_averager,
+                grad_is_view=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_grad_is_view(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_periodic_model_averager,
+                grad_is_view=True,
+            )
+
+        def _create_hierarchical_model_averager(self):
+            period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
+            return hierarchicalSGD.HierarchicalModelAverager(
+                period_group_size_dict=period_group_size_dict, warmup_steps=4
+            )
+
+        @skip_if_lt_x_gpu(4)
+        @skip_if_odd_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_hierarchical_model_averager,
+                grad_is_view=False,
+            )
+
+        @skip_if_lt_x_gpu(4)
+        @skip_if_odd_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(
+            self,
+        ):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_hierarchical_model_averager,
+                grad_is_view=True,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_step_reload(self):
+            torch.cuda.set_device(self.rank)
+            with _rank_temp_file() as tmp_file:
+                self._test_post_localSGD_optimizer_step_reload(
+                    self._create_periodic_model_averager, tmp_file
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
+            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+                torch.channels_last
+            )
+            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+                torch.channels_last_3d
+            )
+
+        def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+            self, memory_format
+        ):
+            _group, _group_id, rank = self._init_global_test()
+            num_processes = dist.get_world_size()
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(num_processes * 2)
+
+            model = nn.SyncBatchNorm(2, momentum=0.99)
+            model_gpu = copy.deepcopy(model).cuda(rank)
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_gpu, device_ids=[rank]
+            )
+
+            shapes = [global_bs, 2, 4, 4] + (
+                [] if memory_format is torch.channels_last else [4]
+            )
+
+            input_gpu = (
+                torch.randn(*shapes, dtype=torch.float)
+                .cuda(rank)
+                .to(memory_format=memory_format)
+            )
+            target_gpu = (
+                torch.randn(*shapes, dtype=torch.float)
+                .cuda(rank)
+                .to(memory_format=memory_format)
+            )
+            loss = nn.MSELoss()
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_gpu,
+                target_gpu,
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+                bs_offset,
+                dist.get_world_size(),
+                memory_format=memory_format,
+            )
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(world_size * 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+            )
+
+            # test output_device
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                output_device=torch.device("cuda"),
+            )
+
+            # test device_ids
+            gpus = [torch.device("cuda:" + str(i)) for i in gpus]
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                output_device=torch.device("cuda"),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(world_size * 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                affine=False,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
+            _group, _group_id, rank = self._init_global_test()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            model = nn.BatchNorm1d(2)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpus[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpus[0])
+            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
+
+            local_bs = len(gpus) * 2
+            global_bs = dist.get_world_size() * local_bs
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 2)
+            loss = nn.MSELoss()
+
+            # disabling cudnn.
+            # SyncBatchNorm goes through native_batch_norm kernel, this avoids the
+            # numerical issue created by the divergent code path.
+            with torch.backends.cudnn.flags(False):
+                # check two model parameters over 5 iterations
+                self._test_DDP_niter(
+                    model_gpu,
+                    model_DDP,
+                    input_cpu.cuda(gpus[0]),
+                    target.cuda(gpus[0]),
+                    loss,
+                    local_bs,
+                    rank,
+                    global_bs,
+                    True,
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        @require_world_size(2)
+        def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self):
+            _group, _group_id, rank = self._init_global_test()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            model = nn.BatchNorm1d(2)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpus[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpus[0])
+            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
+
+            local_bs = 1
+            global_bs = dist.get_world_size()
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 2)
+            loss = nn.MSELoss()
+
+            # disabling cudnn.
+            # SyncBatchNorm goes through native_batch_norm kernel, this avoids the
+            # numerical issue created by the divergent code path.
+            with torch.backends.cudnn.flags(False):
+                # check two model parameters over 5 iterations
+                self._test_DDP_niter(
+                    model_gpu,
+                    model_DDP,
+                    input_cpu.cuda(gpus[0]),
+                    target.cuda(gpus[0]),
+                    loss,
+                    local_bs,
+                    rank,
+                    global_bs,
+                    True,
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
+            self,
+        ):
+            ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+            _group, _group_id, rank = self._init_global_test()
+            model = nn.parallel.DistributedDataParallel(
+                ONLY_SBN_NET.cuda(rank), device_ids=[rank]
+            )
+
+            input_var = []
+            for i in range(dist.get_world_size()):
+                input_var_rank = torch.cat(
+                    [
+                        torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)),
+                        torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)),
+                    ],
+                    dim=1,
+                )
+                input_var.append(input_var_rank)
+
+            all_input_var = torch.cat(
+                [
+                    x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1)
+                    for x in input_var
+                ],
+                dim=1,
+            ).cuda(rank)
+
+            for _ in range(100):
+                y = model(input_var[rank].cuda(rank))
+                y.mean().backward()
+
+            running_mean, running_var = (
+                model.module.running_mean,
+                model.module.running_var,
+            )
+            torch.testing.assert_close(running_mean, all_input_var.mean(1))
+            torch.testing.assert_close(running_var, all_input_var.var(1))
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
+            _group, _group_id, rank = self._init_global_test()
+            # only do single GPU per process
+            gpus = [rank]
+
+            # cpu training setup
+            num_processes = dist.get_world_size()
+            local_bs = rank + 2
+            bs_offset = int((rank + 3) * rank / 2)
+            global_bs = int((num_processes + 3) * num_processes / 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_half(self):
+            _group, _group_id, rank = self._init_global_test()
+
+            model = BatchNormNet()
+            model = model.half()
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[rank]
+            )
+            inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank))
+            # Check that forward/backward do not error with dtype mismatch
+            out = model(inp)
+            self.assertEqual(out.dtype, torch.float16)
+            out.sum().backward()
+            for param in model.parameters():
+                self.assertEqual(param.grad.dtype, torch.float16)
+
+        def _test_ddp_logging_data(self, is_gpu):
+            rank = dist.get_rank()
+            model_DDP = Net()
+            if is_gpu:
+                model_DDP = nn.parallel.DistributedDataParallel(
+                    model_DDP.cuda(rank), device_ids=[rank]
+                )
+            else:
+                model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
+
+            # dummy data initialization
+            local_bs = 2
+            batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
+            if is_gpu:
+                input = input.cuda(rank)
+                target = target.cuda(rank)
+
+            model_DDP._set_ddp_runtime_logging_sample_rate(2)
+
+            for idx in range(20):
+                offset = rank * local_bs
+
+                # DDP training, DDP scatters subsets of input to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    1,
+                )
+
+                self._model_step_with_zero_grad(model_DDP)
+
+                # Verify DDP logging data is sampled as expected
+                # If it has ran more than 10 iterations and this is
+                # the sampled iteration for measuring run time stats,
+                # the run time stats for this idx-th iteration will not
+                # be zeros.
+                ddp_logging_data = model_DDP._get_ddp_logging_data()
+                if idx > 0 and (idx < 10 or idx % 2 == 0):
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("forward_compute_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_compute_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_comm_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_compute_time"),
+                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_comm_time"),
+                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
+                    )
+                    self.assertEqual(ddp_logging_data.get("iteration"), idx)
+                elif idx > 0:
+                    # if the idx-th iteration is not sampled to set runtime stats,
+                    # ddp_logging_data.iteration will not be updated to current
+                    # iteration.
+                    self.assertNotEqual(ddp_logging_data.get("iteration"), idx)
+
+                # Shuffle the input so that DDP input is different
+                input = input[torch.randperm(batch_size)]
+
+            return model_DDP
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_ddp_logging_data_cpu(self):
+            def parse_env(var):
+                return os.environ.get(var, "N/A")
+
+            dist.set_debug_level(dist.DebugLevel.INFO)
+            _, group_id, _ = self._init_global_test()
+            model_DDP = self._test_ddp_logging_data(is_gpu=False)
+
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            self.assertEqual(ddp_logging_data.get("world_size"), dist.get_world_size())
+            self.assertEqual(ddp_logging_data.get("rank"), dist.get_rank())
+            self.assertEqual(ddp_logging_data.get("module_name"), "Net")
+            self.assertEqual(ddp_logging_data.get("device_ids"), "")
+            # output_device is -1 in default if it is not set, e.g.
+            # output_device of CPU training is -1.
+            self.assertEqual(ddp_logging_data.get("output_device"), -1)
+            self.assertEqual(ddp_logging_data.get("broadcast_buffers"), 1)
+            self.assertEqual(ddp_logging_data.get("bucket_cap_bytes"), 25 * 1024 * 1024)
+            self.assertEqual(ddp_logging_data.get("find_unused_parameters"), 0)
+            self.assertEqual(ddp_logging_data.get("gradient_as_bucket_view"), 0)
+            self.assertEqual(
+                ddp_logging_data.get("backend_name"), dist.get_backend(group_id)
+            )
+            self.assertEqual(ddp_logging_data.get("iteration"), 18)
+            params = list(model_DDP.parameters())
+            num_params = 0
+            param_size = 0
+            params = list(filter(lambda parameter: parameter.requires_grad, params))
+            for p in params:
+                num_params += 1
+                param_size += p.numel() * p.element_size()
+            self.assertEqual(ddp_logging_data.get("dtypes"), "float")
+            self.assertEqual(
+                ddp_logging_data.get("total_parameter_size_bytes"), param_size
+            )
+            self.assertEqual(ddp_logging_data.get("num_parameter_tensors"), num_params)
+            self.assertEqual(ddp_logging_data.get("bucket_sizes"), str(param_size))
+            self.assertEqual(
+                ddp_logging_data.get("master_port"), parse_env("MASTER_PORT")
+            )
+            self.assertEqual(
+                ddp_logging_data.get("master_addr"), parse_env("MASTER_ADDR")
+            )
+            self.assertEqual(
+                ddp_logging_data.get("torch_distributed_debug"),
+                parse_env("TORCH_DISTRIBUTED_DEBUG"),
+            )
+            self.assertEqual(
+                ddp_logging_data.get("cuda_visible_devices"),
+                parse_env("CUDA_VISIBLE_DEVICES"),
+            )
+            if ddp_logging_data.get("backend_name") == "gloo":
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_socket_ifname"),
+                    parse_env("GLOO_SOCKET_IFNAME"),
+                )
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_device_transport"),
+                    parse_env("GLOO_DEVICE_TRANSPORT"),
+                )
+                default_gloo_threads = 2
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_num_threads"),
+                    default_gloo_threads,
+                )
+
+            self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_debug"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None)
+            # test runtime logging fields
+            # Note: DETAIL debug mode logs DDP logging data to stdout and
+            # thus accesses std::map, which fills in a default value for the
+            # type if it didn't exist.
+            self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
+            self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
+            self.assertEqual(
+                ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
+            )
+            grad_ready_order = ddp_logging_data.get(
+                "prev_iteration_grad_ready_order_indices"
+            )
+            expected_order = list(reversed([str(x) for x in range(3)]))
+            self.assertEqual(grad_ready_order, ", ".join(expected_order))
+            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
+            self.assertEqual(bucket_indices, " ".join(expected_order))
+            # It is hard to test accurate latency, but it can test whether the latency is
+            # a valid value and in the expected range.
+            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"), 1
+            )
+            self.assertGreaterEqual(ddp_logging_data.get("avg_backward_comm_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_comm_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            # Test host-side times are roughly in the order that we expect
+            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
+            bwd_comp_start_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_start"
+            )
+            bwd_comp_end_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_end"
+            )
+            bwd_comm_start_host_side_time = ddp_logging_data.get(
+                "backward_comm_time_start"
+            )
+            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
+            self.assertGreaterEqual(
+                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
+
+            # test larger net with mixed data types, verify multiple bucket sizes
+            model = LargeNet()
+            model.float()
+            model.fc1.double()
+            model_DDP = nn.parallel.DistributedDataParallel(model, bucket_cap_mb=1.5)
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            params = list(model_DDP.parameters())
+            self.assertEqual(
+                ddp_logging_data.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024)
+            )
+            bucket_sizes = [
+                params[1].numel() * params[1].element_size(),
+                params[0].numel() * params[0].element_size(),
+            ]
+            self.assertEqual(
+                ddp_logging_data.get("bucket_sizes"),
+                ", ".join(str(x) for x in bucket_sizes),
+            )
+            self.assertEqual(ddp_logging_data.get("dtypes"), "double, float")
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_ddp_logging_data_gpu(self):
+            _group, _group_id, rank = self._init_global_test()
+            model_DDP = self._test_ddp_logging_data(is_gpu=True)
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            self.assertEqual(ddp_logging_data.get("device_ids"), str(rank))
+            self.assertEqual(ddp_logging_data.get("output_device"), rank)
+            grad_ready_order = ddp_logging_data.get(
+                "prev_iteration_grad_ready_order_indices"
+            )
+            expected_order = list(reversed([str(x) for x in range(3)]))
+            self.assertEqual(grad_ready_order, ", ".join(expected_order))
+            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
+            self.assertEqual(bucket_indices, " ".join(expected_order))
+            # test runtime logging fields
+            # It is hard to test accurate latency, but it can test whether the latency is
+            # a valid value and in the expected range.
+            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), 1
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_comm_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            # Test host-side times are roughly in the order that we expect
+            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
+            bwd_comp_start_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_start"
+            )
+            bwd_comp_end_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_end"
+            )
+            bwd_comm_start_host_side_time = ddp_logging_data.get(
+                "backward_comm_time_start"
+            )
+            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
+            self.assertGreaterEqual(
+                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_static_graph_api_cpu(self):
+            model_DDP = nn.parallel.DistributedDataParallel(Net())
+            expected_err = "should be called before training loop starts"
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                local_bs = 2
+                _batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
+                offset = dist.get_rank() * local_bs
+
+                # DDP training, DDP scatters subsets of input to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    1,
+                )
+                model_DDP._set_static_graph()
+
+            # Verify error was logged in ddp_logging_data.
+            verify_ddp_error_logged(model_DDP, expected_err)
+
+        @skipIfNoTorchVision
+        def test_SyncBatchNorm_process_group(self):
+            # When adopting `convert_sync_batchnorm` to convert a `nn.modules`,
+            # it need to recursively pass the `process_group` in the module when the `SyncBatchNorm`
+            # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
+
+            process_ids = 0
+            process_group = torch.distributed.new_group([process_ids])
+            res50_model = torchvision.models.resnet50()
+            res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(
+                copy.deepcopy(res50_model), process_group
+            )
+            process_group_sync = res50_model_sync.layer1[0].bn1.process_group
+            self.assertEqual(process_group_sync, process_group)
+
+        def _run_reduction_test(
+            self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
+        ):
+            if reduction_fn is not dist.all_reduce and dst is None:
+                raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
+            if dst is not None:
+                reduction_fn(tensor, dst, op)
+                # Only destination rank tensor is expected to have final result.
+                if dist.get_rank() == dst:
+                    self.assertEqual(tensor, expected_tensor)
+            else:
+                reduction_fn(tensor, op)
+                self.assertEqual(tensor, expected_tensor)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_allreduce(self):
+            torch.cuda.set_device(self.rank)
+            # Run all_reduce with PRODUCT
+            element = self.rank % 2 == 0
+            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+                input_tensor = torch.tensor([element, element]).to(self.rank)
+                self._run_reduction_test(
+                    input_tensor, torch.tensor([False, False]).to(self.rank), op
+                )
+                # Ensure that all ranks contributing True (cast to 1) results in the
+                # correct reduction.
+                input_tensor = torch.tensor([True, True]).to(self.rank)
+                expected_tensor = input_tensor.clone()
+                self._run_reduction_test(input_tensor, expected_tensor, op)
+
+            # Run all_reduce with SUM
+            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+                input_tensor = torch.tensor([element, element]).to(self.rank)
+                self._run_reduction_test(
+                    input_tensor, torch.tensor([True, True]).to(self.rank), op
+                )
+            # TODO: NCCL backend does not work correctly for bitwise reduction ops
+            # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
+            # these once it is supported.
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_allgather(self):
+            torch.cuda.set_device(self.rank)
+            inp = {0: [True, True], 1: [False, True]}
+            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+            # Preserve a copy of the tensor to compare against after allgather.
+            input_tensor_copy = input_tensor.clone()
+            tensor_list = [
+                torch.tensor([False, False]).to(self.rank)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, input_tensor)
+
+            self.assertEqual(len(tensor_list), dist.get_world_size())
+            for i, t in enumerate(tensor_list):
+                expected = torch.tensor(inp[i % 2]).to(self.rank)
+                self.assertEqual(t, expected)
+            # Ensure that the input tensor is not modified, since this collective
+            # does not modify its input.
+            self.assertEqual(input_tensor_copy, input_tensor)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_nccl_backend_bool_reduce(self):
+            torch.cuda.set_device(self.rank)
+            inp = {0: [True, True], 1: [False, False]}
+            # Run reduce() with product op
+            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+                # make sure rank 0 gets False if WORLD_SIZE=1 to match expected tensor
+                input_tensor = torch.tensor(inp[(self.rank + 1) % 2]).to(self.rank)
+                expected = torch.tensor([False, False]).to(self.rank)
+                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
+                # Ensure that all ranks contributing True (cast to 1) results in the
+                # correct reduction.
+                input_tensor = torch.tensor([True, True]).to(self.rank)
+                expected_tensor = input_tensor.clone()
+                self._run_reduction_test(
+                    input_tensor, expected_tensor, op, dist.reduce, dst=0
+                )
+
+            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+                input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+                expected = (
+                    torch.tensor([True, True]).to(self.rank)
+                    if self.rank == 0
+                    else input_tensor.clone()
+                )
+                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_broadcast(self):
+            tensor_size = 10
+            bcast_tensor = torch.tensor(
+                [
+                    (random.random() < 0.5 if self.rank == 0 else False)
+                    for _ in range(tensor_size)
+                ]
+            ).to(self.rank)
+            dist.broadcast(bcast_tensor, src=0)
+            # Now allgather and ensure the tensors are equal.
+            tensor_list = [
+                torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, bcast_tensor)
+            expected = tensor_list[0]
+            for tensor in tensor_list[1:]:
+                self.assertEqual(tensor, expected)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_DistributedSampler_padding(self):
+            # Tests padding of distributed sampler.
+            world_size = dist.get_world_size()
+
+            # Simulates the 'casual' dataset size
+            dataset_size = 100 + world_size + 1
+            dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)]
+
+            # Simulates the 'tiny' dataset size
+            dataset_tiny_size = max(world_size // 2 - 1, 1)
+            dataset_tiny = [
+                torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size)
+            ]
+
+            # Specifying drop_last=True will cause the tail of the data to be dropped.
+            dist_sampler = DistributedSampler(dataset=dataset, drop_last=True)
+            local_num_samples, local_dataset_size = (
+                dist_sampler.num_samples,
+                dist_sampler.total_size,
+            )
+            # The effective dataset size should be the greatest integer that is <=
+            # dataset_size that is divisible by the world_size. This is to ensure each
+            # rank processes the same number of samples.
+            effective_dataset_size = (
+                math.ceil((dataset_size - world_size) / world_size)
+                if dataset_size % world_size != 0
+                else dataset_size / world_size
+            )
+            self.assertEqual(local_num_samples, effective_dataset_size)
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler))
+            self.assertEqual(len(indices_list), local_num_samples)
+
+            def validate_global_samples(local_num_samples):
+                # Ensure that each rank processes the same number of samples.
+                world_samples = [
+                    torch.LongTensor([0]).to(self.rank) for _ in range(world_size)
+                ]
+                dist.all_gather(
+                    world_samples, torch.tensor([local_num_samples]).to(self.rank)
+                )
+                world_samples = [sample.item() for sample in world_samples]
+                self.assertEqual(len(set(world_samples)), 1)
+
+            validate_global_samples(local_num_samples)
+
+            # drop_last=False is the default and will add additional indices to be sampled,
+            # increasing the effective dataset size.
+            dist_sampler_added_samples = DistributedSampler(dataset=dataset)
+            local_num_samples, local_dataset_size = (
+                dist_sampler_added_samples.num_samples,
+                dist_sampler_added_samples.total_size,
+            )
+            # The effective dataset size is the smallest integer that is >= dataset_size
+            # and divisible by the world size.
+            self.assertEqual(local_num_samples, math.ceil(dataset_size / world_size))
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler_added_samples))
+            self.assertEqual(len(indices_list), local_num_samples)
+
+            # Ensure that each rank processes the same number of samples.
+            validate_global_samples(local_num_samples)
+
+            # Ensure additional samples are padded even when
+            # the extremely small dataset is given.
+            dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny)
+            local_num_samples, local_dataset_size = (
+                dist_sampler_added_samples_tiny.num_samples,
+                dist_sampler_added_samples_tiny.total_size,
+            )
+            self.assertEqual(
+                local_num_samples, math.ceil(dataset_tiny_size / world_size)
+            )
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler_added_samples_tiny))
+            self.assertEqual(len(indices_list), local_num_samples)
+            validate_global_samples(local_num_samples)
+
+        def _test_allgather_object(self, subgroup=None):
+            # Only set device for NCCL backend since it must use GPUs.
+
+            gather_objects = create_collectives_object_test_list()
+
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                # Case where rank != GPU device.
+                next_rank = (self.rank + 1) % int(self.world_size)
+                torch.cuda.set_device(next_rank)
+
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
+
+            output_gathered = [None for _ in range(dist.get_world_size())]
+            dist.all_gather_object(
+                output_gathered,
+                gather_objects[self.rank % len(gather_objects)],
+                group=subgroup,
+            )
+
+            for i, val in enumerate(output_gathered):
+                expected = gather_objects[i % len(gather_objects)]
+                self.assertEqual(val, expected)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        def test_all_gather_object_default_pg(self):
+            return self._test_allgather_object()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        def test_all_gather_object_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_allgather_object(subgroup=subgroup)
+
+        def _test_gather_object(self, pg=None):
+            # Ensure stateful objects can be gathered
+            gather_objects = create_collectives_object_test_list()
+            my_rank = dist.get_rank(pg)
+
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                # Case where rank != GPU device.
+                next_rank = (self.rank + 1) % int(self.world_size)
+                torch.cuda.set_device(next_rank)
+
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=my_rank)))
+
+            output_gathered = [None for _ in range(dist.get_world_size(pg))]
+            gather_on_rank = 0
+            dist.gather_object(
+                gather_objects[self.rank % len(gather_objects)],
+                object_gather_list=output_gathered
+                if my_rank == gather_on_rank
+                else None,
+                dst=gather_on_rank,
+                group=pg,
+            )
+            if my_rank != gather_on_rank:
+                self.assertEqual(
+                    output_gathered, [None for _ in range(dist.get_world_size())]
+                )
+            else:
+                for i, val in enumerate(output_gathered):
+                    expected = gather_objects[i % len(gather_objects)]
+                    self.assertEqual(val, expected)
+
+            # Validate errors when objects can't be pickled.
+            class Bar:
+                pass
+
+            b = Bar()
+            gather_objects = [b for _ in range(dist.get_world_size())]
+            with self.assertRaises(AttributeError):
+                dist.all_gather_object(
+                    [None for _ in range(dist.get_world_size())],
+                    gather_objects[self.rank],
+                    group=pg,
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        @require_exact_world_size(4)
+        def test_gather_object(self):
+            return self._test_gather_object()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        @require_exact_world_size(4)
+        def test_gather_object_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_gather_object(subgroup)
+
+        def validate_net_equivalence(self, net):
+            # Helper to validate synchronization of nets across ranks.
+            net_module_states = list(net.module.state_dict().values())
+            # Check that all tensors in module's state_dict() are equal.
+            for t in net_module_states:
+                tensor_list = [
+                    torch.zeros_like(t) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, t)
+                for tensor in tensor_list:
+                    self.assertEqual(tensor, t)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_sync_module_states(self):
+            # Test that after calling _sync_module_states, models across ranks
+            # are the same and are equal to the model on the input rank.
+            dim = 2
+            rank = self.rank
+            rank_to_broadcast = 1
+            # Seed to ensure that ranks are initialized with different initial models.
+            torch.manual_seed(rank)
+            model = nn.Linear(dim, dim, bias=False)
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+            )
+            new_model = nn.Linear(dim, dim, bias=False).cuda(rank)
+            net.module = copy.deepcopy(new_model)
+            # Assert params are different
+            net_module_states = list(net.module.state_dict().values())
+            for t in net_module_states:
+                tensor_list = [
+                    torch.zeros_like(t) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, t)
+                for i, tensor in enumerate(tensor_list):
+                    if i == rank:
+                        self.assertEqual(t, tensor)
+                    else:
+                        # tensor from another rank should be different.
+                        self.assertNotEqual(t, tensor)
+
+            _sync_module_states(
+                module=net.module,
+                process_group=net.process_group,
+                broadcast_bucket_size=net.broadcast_bucket_size,
+                src=rank_to_broadcast,
+                params_and_buffers_to_ignore=net.parameters_to_ignore,
+            )
+            # Now all model params should be the same.
+            self.validate_net_equivalence(net)
+            # Since the network params were broadcast from rank_to_broadcast, validate that
+            # they are the same as new_model on rank_to_broadcast.
+            if rank == rank_to_broadcast:
+                expected_states = new_model.state_dict().values()
+                for t, expected in zip(net_module_states, expected_states, strict=True):
+                    self.assertEqual(t, expected)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_grad_div_uneven_inputs(self):
+            # Test gradient division during training with join() API. If
+            # divide_by_initial_world_size=False, we scale by the effective world
+            # size when allreducing grads.
+            dim = 5
+            batch = 1
+            grad_scale = 50
+            rank = self.rank
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.ones(batch, dim, device=self.rank) * grad_scale
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+            )
+            n_iters = 3
+            if self.rank > 0:
+                n_iters += 2
+
+            with net.join(divide_by_initial_world_size=False):
+                for _ in range(n_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+                    # The grad is always expected_grad, since we divide by the number
+                    # of currently active processes and inactive processes contribute
+                    # zero gradient. If we kept dividing by static initial world
+                    # size as processes leave, the grad would be smaller.
+                    expected_grad = torch.ones(dim, dim, device=self.rank) * grad_scale
+                    param = next(iter(net.parameters()))
+                    self.assertEqual(expected_grad, param.grad)
+                    # Avoid accumulating grads so that it's the same every iteration
+                    net.zero_grad()
+                    torch.cuda.synchronize(device=self.rank)
+
+            # If divide_by_initial_world_size=True (default), we always scale grads
+            # by the initial world_size.
+            with net.join(divide_by_initial_world_size=True):
+                for i in range(n_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+                    effective_ws = dist.get_world_size()
+                    if i >= 3:
+                        effective_ws -= 1
+                    expected_grad = (
+                        torch.ones(dim, dim, device=self.rank)
+                        * grad_scale
+                        * effective_ws
+                    ) / dist.get_world_size()
+                    param = next(iter(net.parameters()))
+                    self.assertEqual(expected_grad, param.grad)
+                    # Avoid accumulating grad so that it's the same every iteration.
+                    net.zero_grad()
+                    torch.cuda.synchronize(device=self.rank)
+
+        def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None):
+            """Runs DDP based model training and captures profiles.
+            This test will do two profiler runs.
+            1. An initial basic run to check if profiler events are correctly captured.
+            2. A second profiling pass after running some iterations of DDP, to check robustness of thread local state.
+
+            args
+                profiler_ctx : Profiler context manager for pass 1
+                profiler_ctx2 : Profiler context manager for pass 2.
+                    This can be left out as None, in which case a deepcopy
+                    of profiler_ctx is used.
+            Returns:
+                prof: Instantiated profiler object that can be used for post analysis.
+            """
+            batch = 3
+            dim = 10
+            num_iters = 6
+            torch.cuda.set_device(self.rank)
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.rand(batch, dim, device=self.rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            if profiler_ctx2 is None:
+                profiler_ctx2 = copy.deepcopy(profiler_ctx)
+
+            with profiler_ctx as prof:
+                for _ in range(num_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+
+            all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
+            events = get_profiling_event(
+                all_reduce_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            event_count = sum(e.count for e in events)
+            self.assertEqual(event_count, num_iters)
+            for event in events:
+                self.assertTrue(event.is_async)
+                self.assertEqual(event.name, all_reduce_event_name)
+
+            broadcast_event_name = f"{dist.get_backend()}:broadcast"
+            broadcast_events = get_profiling_event(
+                broadcast_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            event_count = sum(e.count for e in broadcast_events)
+            # Broadcast is called during rebuild_buckets
+            self.assertGreaterEqual(event_count, 1)
+            for event in broadcast_events:
+                self.assertEqual(event.name, broadcast_event_name)
+
+            # Run DDP with profiling for a few iterations, then enable profiling
+            # for a single pass, and ensure it is recorded. This tests that the
+            # thread local state is correctly updated.
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            for _ in range(3):
+                loss = net(inp).sum()
+                loss.backward()
+            # Now enable the profiler.
+            with profiler_ctx2 as prof:
+                loss = net(inp).sum()
+                loss.backward()
+
+            events = get_profiling_event(
+                all_reduce_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            self.assertGreaterEqual(len(events), 1)
+            self.assertGreaterEqual(events[0].count, 1)
+            self.assertEqual(events[0].name, all_reduce_event_name)
+            for event in events:
+                self.assertTrue(event.is_async)
+            # Ensure searching unused parameters was profiled
+            events = get_profiling_event("search_unused_parameters", prof)
+            self.assertEqual(len(events), 1)
+
+            return prof
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle("Currently failing in NVIDIA internal CI")
+        def test_ddp_profiling_autograd_profiler(self):
+            autograd_profiler_ctx = torch.autograd.profiler.profile()
+            return self._test_ddp_profiling(profiler_ctx=autograd_profiler_ctx)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_ddp_profiling_torch_profiler(self):
+            cpu_act = torch.profiler.ProfilerActivity.CPU
+            cuda_act = torch.profiler.ProfilerActivity.CUDA
+            torch_profiler_ctx = torch.profiler.profile(activities=[cpu_act, cuda_act])
+            prof = self._test_ddp_profiling(profiler_ctx=torch_profiler_ctx)
+
+            if dist.get_backend() != "nccl":
+                return
+
+            # Note comment out the "os.remove(trace_file)" in `get_profiler_nccl_meta()`
+            # to debug any mismatches.
+            nccl_meta_events = get_profiler_nccl_meta(prof)
+            self.assertGreater(len(nccl_meta_events), 0)
+
+            nccl_meta = self._sanity_check_profiler_nccl_meta(nccl_meta_events)
+
+            # additionally check the specific collectives in this test case
+            self.assertEqual(len(nccl_meta["allreduce"]), 2)
+            self.assertEqual(len(nccl_meta["wait"]), 1)
+
+            # check allreduce message sizes
+            a0 = nccl_meta["allreduce"][0]
+            self.assertEqual(a0["Out msg nelems"], 100, msg=f"{a0}")
+            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
+            a1 = nccl_meta["allreduce"][1]
+            self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}")
+            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
+
+        def _validate_execution_trace_nccl(self, et_file: str) -> None:
+            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+            We test for basic fields in these nodes in the Execution Trace.
+            """
+            with open(et_file) as f:
+                et = json.load(f)
+            pg_cfg_node = [
+                n for n in et["nodes"] if n["name"] == "## process_group:init ##"
+            ]
+            self.assertGreaterEqual(len(pg_cfg_node), 1)
+            nccl_meta_nodes = [
+                n for n in et["nodes"] if n["name"] == "record_param_comms"
+            ]
+            self.assertEqual(len(nccl_meta_nodes), 3)
+            per_coll_meta = defaultdict(list)
+
+            # Sanity check NCCL metadata nodes
+            for n in nccl_meta_nodes:
+                attrs_list = n.get("attrs", [])
+                self.assertGreater(len(attrs_list), 0)
+                attrs = {a["name"]: a["value"] for a in attrs_list}
+
+                collname = attrs.get("collective_name", "")
+                self.assertNotEqual(collname, "")
+                self.assertNotEqual(attrs.get("dtype", ""), "")
+
+                per_coll_meta[collname].append(attrs)
+                if collname == "wait":
+                    continue
+
+                self.assertEqual(attrs["pg_name"], "0")  # yes this is a string
+                self.assertEqual(attrs["pg_desc"], "default_pg")
+                self.assertEqual(attrs["pg_size"], 2)
+
+                self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0)
+                self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0)
+                self.assertTrue("in_split_size" in attrs)
+                self.assertTrue("out_split_size" in attrs)
+                self.assertEqual(attrs.get("global_rank_start", -1), 0)
+                self.assertEqual(attrs.get("global_rank_stride", -1), 1)
+
+            # print(per_coll_meta)
+            self.assertEqual(len(per_coll_meta["allreduce"]), 2)
+            self.assertEqual(len(per_coll_meta["wait"]), 1)
+
+            # check allreduce message sizes
+            a0 = per_coll_meta["allreduce"][0]
+            self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}")
+            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
+            a1 = per_coll_meta["allreduce"][1]
+            self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}")
+            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.")
+        def test_ddp_profiling_execution_trace(self):
+            self.assertEqual(dist.get_backend(), "nccl")
+            # Create a temp file to save execution trace data
+            with TemporaryFileName("w+t", suffix=".et.json") as et_file:
+                et = ExecutionTraceObserver().register_callback(et_file)
+
+                # first profiler context need not have ET
+                torch_profiler_ctx1 = torch.profiler.profile(
+                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                )
+                # collect ET in second profiler pass
+                torch_profiler_ctx2 = torch.profiler.profile(
+                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                    execution_trace_observer=et,
+                )
+                self._test_ddp_profiling(
+                    profiler_ctx=torch_profiler_ctx1,
+                    profiler_ctx2=torch_profiler_ctx2,
+                )
+
+                print(f"Execution trace saved at {et_file}")
+                self._validate_execution_trace_nccl(et_file)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_join_model_equivalence(self):
+            # Verifies equivalence with model training locally and with DDP under
+            # the join context manager.
+            batch = 3
+            dim = 10
+            learning_rate = 0.03
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.rand(batch, dim, device=self.rank)
+            local_model = copy.deepcopy(model)
+            local_model = local_model.cuda(self.rank)
+            rank_to_iter_mapping = {
+                rank: 2 * (rank + 1) for rank in range(dist.get_world_size())
+            }
+            # run local model
+            local_iters = sum(rank_to_iter_mapping.values())
+            local_optim = torch.optim.SGD(local_model.parameters(), lr=learning_rate)
+            for _ in range(local_iters):
+                local_optim.zero_grad()
+                out = local_model(inp)
+                loss = out.sum()
+                loss.backward()
+                local_optim.step()
+
+            # run DDP model with join API
+            num_iters = rank_to_iter_mapping[self.rank]
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank), device_ids=[self.rank]
+            )
+            ddp_optim = torch.optim.SGD(
+                model.parameters(), lr=learning_rate * dist.get_world_size()
+            )
+            with net.join():
+                for _ in range(num_iters):
+                    ddp_optim.zero_grad()
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                    torch.cuda.synchronize(device=self.rank)
+                    ddp_optim.step()
+
+            # Validate model state dicts are equal
+            for (_, local_tensor), (_, dist_tensor) in zip(
+                local_model.state_dict().items(),
+                net.module.state_dict().items(),
+                strict=True,
+            ):
+                self.assertEqual(local_tensor, dist_tensor)
+
+        def _run_uneven_inputs_test(
+            self,
+            test_case,
+            iteration_mapping,
+            find_unused_params,
+        ):
+            model = test_case.model
+            inp = test_case.inp
+            rank = self.rank
+            sync_interval = test_case.sync_interval
+            torch.cuda.set_device(rank)
+            # Ensure all outstanding GPU work is completed so this test runs independently.
+            dist.barrier()
+            # Bucket_cap_mb is intentionally low to test allreduce scheduling when
+            # there are many buckets.
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank),
+                device_ids=[rank],
+                bucket_cap_mb=1,
+                find_unused_parameters=find_unused_params,
+            )
+            # Register hook if specified
+            if test_case.hook is not None:
+                net.register_comm_hook(test_case.state, test_case.hook)
+                print(f"registered hook {test_case.hook}")
+
+            # Determine num iters for this rank via the passed in mapping.
+            num_iters = iteration_mapping[rank]
+            # If we throw when earliest rank terminates, we should ensure
+            # that we iterate for that minimum number of times.
+            num_iters_tensor = torch.tensor(
+                [num_iters], device=torch.cuda.current_device()
+            )
+            dist.all_reduce(num_iters_tensor, op=dist.ReduceOp.MIN)
+            min_num_iters = num_iters_tensor.item()
+            total_iters = 0
+            if test_case.throw_on_early_termination:
+                if min_num_iters == num_iters:
+                    # Early termination rank(s)
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
+                    )
+                else:
+                    # Non early termination rank
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError,
+                        "Detected at least one rank that exhausted inputs.",
+                    )
+            else:
+                exception_ctx = nullcontext()
+            with exception_ctx:
+                with net.join(
+                    throw_on_early_termination=test_case.throw_on_early_termination
+                ):
+                    for i in range(num_iters):
+                        # Use model.no_sync() to disable grad synchronization every
+                        # sync_interval.
+                        if i % sync_interval != 0:
+                            context = net.no_sync()
+                        else:
+                            context = nullcontext()
+                        with context:
+                            if isinstance(inp, tuple):
+                                loss = net(*inp).sum()
+                            else:
+                                loss = net(inp).sum()
+                            loss.backward()
+                            self._model_step(net)
+                            # Ensure completion of GPU kernels (including allreduce). If the
+                            # join API is not properly implemented, then this should hang
+                            # since the allreduce will hang.
+                            torch.cuda.synchronize(device=rank)
+                        total_iters += 1
+            if test_case.throw_on_early_termination:
+                # Ensure we iterated min_num_iters times.
+                self.assertEqual(total_iters, min_num_iters)
+            else:
+                # Ensure we iterated at least min_num_iters times.
+                self.assertGreaterEqual(total_iters, min_num_iters)
+
+            # Ensure completion of all GPU kernels.
+            torch.cuda.synchronize(device=rank)
+            # When throwing on early rank termination, we do not
+            # broadcast model state from an authoritative rank. All models
+            # should already be in sync.
+            if not test_case.throw_on_early_termination:
+                self.assertTrue(net._authoritative_rank)
+                # All ranks should have agreed on the same authoritative_rank!
+                final_rank_tensor = torch.tensor(
+                    [net._authoritative_rank], device=self.rank
+                )
+                tensor_list = [
+                    torch.zeros_like(final_rank_tensor)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, final_rank_tensor)
+                max_rank = dist.get_world_size() - 1
+                self.assertSetEqual(
+                    {max_rank}, {tensor.item() for tensor in tensor_list}
+                )
+                # Ensure that all models are the same across ranks after all have joined.
+                self.validate_net_equivalence(net)
+                # Ensure that running with DDP uneven inputs was logged.
+                ddp_logging_data = net._get_ddp_logging_data()
+                self.assertTrue(ddp_logging_data.get("join_uneven_inputs"))
+                dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_inputs_stop_iteration_sync_bn(self):
+            # Tests that uneven inputs join handler correctly throws StopIteration
+            # for models with SyncBN or general collective comm when
+            # throw_on_early_termination=True.
+            class ModelWithComm(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(2, 40, bias=False)
+
+                def forward(self, x):
+                    x = self.lin(x)
+                    dist.all_reduce(x)
+                    return x
+
+            torch.cuda.set_device(self.rank)
+            model_bn = BatchNormNet()
+            model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
+                copy.deepcopy(model_bn)
+            ).cuda(self.rank)
+            comm_model = ModelWithComm().cuda(self.rank)
+            model_input = torch.randn(10, 2).cuda(torch.cuda.current_device())
+
+            for model in [model_bn, comm_model]:
+                model = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                )
+                min_num_iters = 5
+                if self.rank != 0:
+                    # Early termination rank(s)
+                    num_iters = min_num_iters
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
+                    )
+                else:
+                    # Non early termination rank
+                    num_iters = min_num_iters * 2
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError,
+                        "Detected at least one rank that exhausted inputs.",
+                    )
+                n = 0
+                with exception_ctx:
+                    with model.join(throw_on_early_termination=True):
+                        for _ in range(num_iters):
+                            loss = model(model_input).sum()
+                            loss.backward()
+                            self._model_step(model)
+                            n += 1
+
+                self.assertEqual(n, min_num_iters)
+                # Verify model equivalence
+                self.validate_net_equivalence(model)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_inputs(self):
+            dim = 1000
+            batch = 1
+            # Create a variety of models to run uneven input tests on.
+            large_model = nn.Sequential(
+                nn.Conv2d(1, 20, 5),
+                nn.ReLU(),
+                nn.Conv2d(20, 32, 5),
+                nn.ReLU(),
+                nn.Conv2d(32, 256, 5),
+                nn.ReLU(),
+            )
+            small_model = nn.Linear(dim, dim, bias=False)
+            bn_net = BatchNormNet()
+
+            class UnusedParamModule(nn.Module):
+                def __init__(self, unused_params_rank):
+                    super().__init__()
+                    self.t0 = Task()
+                    self.t1 = Task()
+                    self.unused_params_rank = unused_params_rank
+
+                def task_parameters(self):
+                    return (self.t0.p, self.t1.p)
+
+                def forward(self, x, rank):
+                    return (
+                        self.t1(self.t0(x))
+                        if rank != self.unused_params_rank
+                        else self.t1(x)
+                    )
+
+            unjoined_rank_with_unused_params_model = UnusedParamModule(1)
+            joined_rank_with_unused_params_model = UnusedParamModule(0)
+
+            rank = self.rank
+            models_to_test = [
+                # Network with batchnorm
+                DDPUnevenTestInput(
+                    name="batch_norm_net",
+                    model=bn_net,
+                    inp=torch.ones(batch, 2, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="large_conv_model",
+                    model=large_model,
+                    inp=torch.ones(batch, batch, dim, dim, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="small_model",
+                    model=small_model,
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+                # Unused parameter test where rank that does not join early has unused params
+                DDPUnevenTestInput(
+                    name="unjoined_rank_with_unused_params_model",
+                    model=unjoined_rank_with_unused_params_model,
+                    inp=(torch.ones(batch, 2, device=rank), rank),
+                    sync_interval=1,
+                ),
+                # Unused parameter test where rank that does join early has unused params
+                DDPUnevenTestInput(
+                    name="joined_rank_with_unused_params_model",
+                    model=joined_rank_with_unused_params_model,
+                    inp=(torch.ones(batch, 2, device=rank), rank),
+                    sync_interval=1,
+                ),
+            ]
+
+            # Test models that have hook installed.
+            models_with_hook = [
+                DDPUnevenTestInput(
+                    name="small_model_allreduce_hook",
+                    model=small_model,
+                    hook=default.allreduce_hook,
+                    state=None,
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="small_model_power_sgd_hook",
+                    model=small_model,
+                    hook=powerSGD.powerSGD_hook,
+                    state=powerSGD.PowerSGDState(
+                        process_group=None,
+                        matrix_approximation_rank=1,
+                        # Config so that powerSGD runs immediately instead of
+                        # allreduce.
+                        start_powerSGD_iter=1,
+                        warm_start=False,
+                        use_error_feedback=False,
+                    ),
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+            ]
+            models_to_test.extend(models_with_hook)
+
+            # Add resnet model if we have torchvision installed.
+            if HAS_TORCHVISION:
+                resnet_model = torchvision.models.resnet50()
+                models_to_test.append(
+                    DDPUnevenTestInput(
+                        name="resnet_model",
+                        model=resnet_model,
+                        inp=torch.ones(1, 3, 1000, 1000),
+                        sync_interval=1,
+                    )
+                )
+
+            # Test with no_sync every 2, 3, 4, ... iterations.
+            models_with_sync = []
+            for i, test_input in enumerate(models_to_test):
+                models_with_sync.append(
+                    DDPUnevenTestInput(
+                        name=test_input.name,
+                        model=test_input.model,
+                        inp=test_input.inp,
+                        sync_interval=i + 2,
+                    )
+                )
+
+            throw_on_early_term_tests = []
+            for test_input in models_to_test:
+                throw_on_early_term_tests.append(
+                    DDPUnevenTestInput(
+                        name=test_input.name,
+                        model=test_input.model,
+                        inp=test_input.inp,
+                        sync_interval=test_input.sync_interval,
+                        throw_on_early_termination=True,
+                    )
+                )
+
+            models_to_test.extend(models_with_sync)
+            models_to_test.extend(throw_on_early_term_tests)
+
+            # 0 iteration tests for when one process does not train model at all, so
+            # we must shadow the broadcast calls made when rebuilding buckets.
+            baseline_num_iters = [0, 5]
+            iteration_offsets = [2, 3, 10]
+            num_uneven_ranks = [1]
+            if dist.get_world_size() > 2:
+                num_uneven_ranks.append(2)
+            iteration_mappings = []
+            # Generate rank : num_iters mappings for various uneven input scenarios.
+            # This includes cases where rank 0 joins early and all other ranks join
+            # later, and scenarios where multiple ranks join early, but at different
+            # iterations, and later ranks join later.
+            for num_early_join_ranks in num_uneven_ranks:
+                for baseline_iter in baseline_num_iters:
+                    for offset in iteration_offsets:
+                        mapping = dict.fromkeys(
+                            range(num_early_join_ranks), baseline_iter
+                        )
+                        # if num_early_join_ranks > 1, ranks > 0 that will join early
+                        # iterate offset//2 more times than rank 0, to test nodes
+                        # depleting inputs at different times.
+                        if num_early_join_ranks > 1:
+                            for rank in mapping:
+                                if rank > 0:
+                                    mapping[rank] += offset // 2
+                        mapping.update(
+                            dict.fromkeys(
+                                range(num_early_join_ranks, dist.get_world_size()),
+                                baseline_iter + offset,
+                            )
+                        )
+                        iteration_mappings.append(mapping)
+
+            for test_case, iteration_mapping in itertools.product(
+                models_to_test, iteration_mappings
+            ):
+                if self.rank == 0:
+                    print(
+                        f"""Running test: {test_case.name} sync interval
+                        {test_case.sync_interval} with iteration mapping
+                        {iteration_mapping}"""
+                    )
+                self._run_uneven_inputs_test(
+                    test_case,
+                    iteration_mapping,
+                    find_unused_params=("unused_params_model" in test_case.name),
+                )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_input_join_disable(self):
+            # tests that if net.join() with enable=False is specified, DDP works as
+            # expected with even inputs.
+            torch.manual_seed(self.rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1).cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = torch.ones(1) * self.rank
+            n_iters = 5
+            world_size = dist.get_world_size()
+            with net.join(enable=False):
+                for _ in range(n_iters):
+                    # Clear grads
+                    grad = net.module.weight.grad
+                    if grad is not None:
+                        grad.requires_grad_(False)
+                        grad.zero_()
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                    # Validate gradients to ensure that we divide by the correct
+                    # world_size when join mode is disabled.
+                    expected_grad = sum(i for i in range(world_size)) / world_size
+                    self.assertEqual(net.module.weight.grad.item(), expected_grad)
+
+            join_config = net._join_config
+            self.assertFalse(join_config.enable)
+            self.validate_net_equivalence(net)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_input_exception(self):
+            # Tests that exceptions during training are correctly propagated by the
+            # context manager.
+            error_str = "Intentional error"
+
+            class ExceptionModule(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.param = nn.Parameter(torch.ones(1, requires_grad=True))
+
+                def forward(self, _):
+                    raise ValueError(error_str)
+
+            exception_module = ExceptionModule()
+            net = torch.nn.parallel.DistributedDataParallel(
+                exception_module.cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = torch.ones(1)
+            with self.assertRaisesRegex(ValueError, error_str):
+                with net.join():
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+
+        def _test_broadcast_object_list(self, group=None):
+            gather_objects = create_collectives_object_test_list()
+
+            # Only set device for NCCL backend since it must use GPUs.
+            # Case where rank != GPU device.
+            next_rank = (self.rank + 1) % int(self.world_size)
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                torch.cuda.set_device(next_rank)
+
+            src_rank = 0
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
+
+            if IS_FBCODE:
+                # Create Tensor with > 2^31 Bytes storage requirements
+                # Only on FBCODE as testing OOMs in OSS
+                gather_objects.append(Foo(torch.randn(3, 178956971)))
+            objects = (
+                gather_objects
+                if self.rank == src_rank
+                else [None for _ in gather_objects]
+            )
+
+            # Single object test with device specified. Backend="gloo", device=cpu
+            if backend != "nccl":
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device("cpu")
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test with device specified. Backend="gloo", device=current_device+1
+            # The test is gated by the fact GPU count is the same as world size to avoid the case
+            # when backend is gloo but there is no multiple GPU devices.
+            if backend != "nccl" and torch.cuda.device_count() == int(self.world_size):
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test with device specified. Backend="nccl", device=current_device+1
+            if backend == "nccl" and torch.cuda.device_count() == int(self.world_size):
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test: backward compatibility with device unspecified
+            single_obj_list = [objects[0]]
+            if self.rank != src_rank:
+                self.assertNotEqual(single_obj_list[0], gather_objects[0])
+            dist.broadcast_object_list(single_obj_list, src=0, group=group)
+            self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Multiple input objects test
+            if self.rank != src_rank:
+                self.assertNotEqual(objects, gather_objects)
+            dist.broadcast_object_list(objects, src=0, group=group)
+            self.assertEqual(objects, gather_objects)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL"])
+        @unittest.skip(
+            "Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
+        )
+        def test_broadcast_object_list(self):
+            return self._test_broadcast_object_list()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL"])
+        def _test_broadcast_object_list_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_broadcast_object_list(subgroup)
+
+        def _test_ddp_ignore_params_arg(self, static_graph=False):
+            class TestModel(nn.Module):
+                def __init__(self, rank):
+                    self.rank = rank
+                    super().__init__()
+                    self.fc1 = nn.Linear(1, 1, bias=False)
+                    # Proxy that will be materialized to another architecture later.
+                    # (after wrapping model with DDP)
+                    if self.rank == 0:
+                        self.fc2 = nn.Linear(1, 10, bias=False)
+                    else:
+                        self.fc2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.fc1(x)
+                    x = self.fc2(x)
+                    return x
+
+            device_id = self.rank
+            # Ensure the test works for both find_unused_parameter and broadcast_buffer settings.
+            for find_unused, broadcast_buffers in itertools.product(
+                [False, True], [False, True]
+            ):
+                model = TestModel(self.rank).float().to(device_id)
+                # Note that the model can have different shape buffers if we pass
+                # them in to be ignored as well.
+                model.fc2.register_buffer(
+                    "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank)
+                )
+                proxy_params = list(model.fc2.parameters())
+                model_fc2_name = next(
+                    module_name
+                    for module_name, module in model.named_modules()
+                    if module is model.fc2
+                )
+                proxy_param_names = [
+                    f"{model_fc2_name}.{param_name}"
+                    for param_name, _ in model.fc2.named_parameters()
+                ]
+                proxy_buffer_names = [
+                    f"{model_fc2_name}.{buf_name}"
+                    for buf_name, _ in model.fc2.named_buffers()
+                ]
+                # Specify that we should ignore proxy_params since it will be
+                # materialized later.
+                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                    model, proxy_param_names + proxy_buffer_names
+                )
+                ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[device_id],
+                    find_unused_parameters=find_unused,
+                    broadcast_buffers=broadcast_buffers,
+                    static_graph=static_graph,
+                )
+                # Materialize new params. These are not registered in DDP and thus
+                # don't have autograd hooks installed on them.
+                ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
+
+                # local model with the new materialized parameters.
+                local_model = copy.deepcopy(ddp.module).cuda(self.rank)
+
+                inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1)
+                for _ in range(6):
+                    ddp(inp).sum().backward()
+
+                    local_model(inp).sum().backward()
+                    # materialized param grad is not touched by DDP, so its grad should
+                    # be the same as if running locally.
+                    for materialized_param, local_param in zip(
+                        ddp.module.fc2.parameters(),
+                        local_model.fc2.parameters(),
+                        strict=True,
+                    ):
+                        self.assertEqual(materialized_param.grad, local_param.grad)
+
+                    # fc1 parameter grad should still be different, due to allreduce.
+                    for synced_param, local_param in zip(
+                        ddp.module.fc1.parameters(),
+                        local_model.fc1.parameters(),
+                        strict=True,
+                    ):
+                        self.assertFalse(synced_param.grad == local_param.grad)
+
+                    # Proxy module grad should not be touched
+                    for proxy_param in proxy_params:
+                        self.assertTrue(proxy_param.grad is None)
+
+                # Synchronize since we run multiple iterations of this test, to
+                # isolate failure hangs.
+                torch.cuda.synchronize(device=self.rank)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_ignore_params_arg(self):
+            self._test_ddp_ignore_params_arg(static_graph=False)
+            self._test_ddp_ignore_params_arg(static_graph=True)
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_unused_params_rebuild_buckets_exception(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10, bias=False)
+                    self.net2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    return self.net1(x)
+
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                ToyModel().cuda(self.rank), device_ids=[self.rank]
+            )
+            for i in range(2):
+                inp = torch.rand(1, 10)
+                if i > 0:
+                    # On 2nd iteration, this will fail during rebuild_buckets,
+                    # but we should report an error regarding unused parameters
+                    # since that is the underlying root cause.
+                    try:
+                        ddp(inp).sum().backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(ddp, msg)
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["net2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(
+                            True, "DDP unused parameters error not raised."
+                        )
+                else:
+                    ddp(inp).sum().backward()
+
+            dist.barrier()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_shared_grad_acc_unused_params(self):
+            # When find_unused_parameters=True, ensure we mark unused parameters
+            # even if they share gradient accumulators.
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    # net1, bias, and net1.bias are all unused params.
+                    self.net1 = nn.Linear(10, 5, bias=False)
+                    self.bias = nn.Parameter(torch.zeros(5))
+                    # net1.bias and self.bias are names for the same underlying
+                    # parameter, so they share the same grad acc. This caused
+                    # the bug reported in https://github.com/pytorch/pytorch/issues/41324.
+                    self.net1.bias = self.bias
+                    self.net2 = nn.Linear(10, 5)
+
+                def forward(self, x):
+                    return self.net2(x).sum()
+
+            torch.cuda.set_device(self.rank)
+            model = ToyModel().to(torch.cuda.current_device())
+            for static in [True, False]:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    find_unused_parameters=True,
+                    static_graph=static,
+                )
+                inp = torch.randn(20, 10, device=self.rank)
+                for _ in range(6):
+                    loss = ddp_model(inp)
+                    # To test https://github.com/pytorch/pytorch/issues/61982
+                    loss /= 10
+                    loss.backward()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_device(self):
+            expected_len = 2
+
+            class TensorWrapper:
+                __slots__ = ["t", "moved_to_gpu"]
+
+                def __init__(self, t):
+                    self.t = t
+                    self.moved_to_gpu = False
+
+            # Handlers for specific types of validation we want to do based on
+            # the input type.
+
+            def tuple_and_list_validator(x):
+                self.assertTrue(len(x), expected_len)
+                self.assertEqual(1, len({t.device for t in x}))
+                self.assertEqual(x[0].device.index, self.rank)
+                return x[0] + x[1]
+
+            def namedtuple_validator(x):
+                self.assertEqual(x._fields, EXPECTED_FIELDS)
+                self.assertEqual(x.a.device.index, x.b.device.index)
+                self.assertEqual(x.a.device.index, self.rank)
+                return x.a + x.b
+
+            def custom_type_validator(x):
+                self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
+                x.t = x.t.to(self.rank)
+                x.moved_to_gpu = True
+                return x.t
+
+            def dict_validator(x):
+                self.assertTrue(EXPECTED_FIELDS[0] in x)
+                self.assertTrue(EXPECTED_FIELDS[1] in x)
+                self.assertEqual(1, len({t.device for t in x.values()}))
+                self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
+                return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]
+
+            validators = {
+                TensorWrapper: custom_type_validator,
+                tuple: tuple_and_list_validator,
+                list: tuple_and_list_validator,
+                TestNamedTupleInput_0: namedtuple_validator,
+                TestNamedTupleInput_1: namedtuple_validator,
+                dict: dict_validator,
+            }
+
+            class ToyModel(torch.nn.Module):
+                def __init__(self_):  # noqa: B902
+                    super().__init__()
+                    self_.lin = nn.Linear(10, 10, bias=False)
+
+                def forward(self_, x, expected_type):  # noqa: B902
+                    # Similar to scatter, the recursive to in the single-device
+                    # case does not move tensors if they are in a custom type.
+                    self.assertTrue(isinstance(x, expected_type))
+                    fwd_tensor = validators[expected_type](x)
+                    return self_.lin(fwd_tensor)
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel().to(self.rank), device_ids=[self.rank]
+            )
+
+            def train_iter(inp, input_type):
+                for _ in range(4):
+                    out = model(inp, input_type)
+                    out.sum().backward()
+
+            # CPU tuple input, should be moved to the proper device before call
+            # to forward.
+            inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
+            train_iter(inp, tuple)
+
+            # List CPU input, should be moved to proper device before call to
+            # forward.
+            inp = [torch.randn(10, 10) for _ in range(expected_len)]
+            train_iter(inp, list)
+            # Custom type containing tensor. The type is maintained, but the
+            # device is not propagated (which is what happens with scatter too)
+            inp = TensorWrapper(torch.randn(10, 10))
+            train_iter(inp, TensorWrapper)
+            # NamedTuple input. The type should be maintained and tensor inputs
+            # should be moved to the correct device as in scatter.
+            batch = 5
+            dim = 10
+            a = torch.rand(batch, dim)
+            b = torch.rand(batch, dim)
+
+            inp = TestNamedTupleInput_0(a, b)
+            train_iter(inp, type(inp))
+
+            inp = TestNamedTupleInput_1(a, b)
+            train_iter(inp, type(inp))
+
+            # dictionary input.
+            inp = {
+                EXPECTED_FIELDS[0]: a,
+                EXPECTED_FIELDS[1]: b,
+            }
+            train_iter(inp, type(inp))
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_namedtuple(self):
+            batch = 5
+            dim = 10
+
+            a = torch.rand(batch, dim, device=self.rank)
+            b = torch.rand(batch, dim, device=self.rank)
+
+            class NamedTupleModule(torch.nn.Module):
+                def __init__(self_):  # noqa: B902
+                    super().__init__()
+                    self_.lin = nn.Linear(10, 1)
+
+                def forward(self_, input, expected_type):  # noqa: B902
+                    # Without NamedTuple support, this would be of type tuple.
+                    self.assertTrue(
+                        isinstance(input, expected_type),
+                        f"Expected type {expected_type} but got {type(input)}",
+                    )
+                    self.assertEqual(input._fields, EXPECTED_FIELDS)
+                    self.assertEqual(a, input.a)
+                    self.assertEqual(b, input.b)
+                    return self_.lin(torch.mul(input.a, input.b))
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = TestNamedTupleInput_0(a, b)
+            # The following would fail if DDP does not propagate NamedTuples correctly.
+            model(inp, type(inp))
+
+            inp = TestNamedTupleInput_1(a, b)
+            model(inp, type(inp))
+
+        @require_backend_is_available({"gloo"})
+        def test_grads_same_across_ranks_with_no_sync(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            if world_size < 2:
+                self.skipTest("This test requires at least two ranks.")
+
+            class SimpleConditionalModel(nn.Module):
+                # if rank is 0, uses nn1 on the first pass and nn2 on the second pass.
+                # else, uses nn3 on the first pass and nn4 on the second pass.
+
+                def __init__(self, rank):
+                    super().__init__()
+
+                    self.rank = rank
+                    self.nn1 = nn.Linear(1, 1)
+                    self.nn2 = nn.Linear(1, 1)
+                    self.nn3 = nn.Linear(1, 1)
+                    self.nn4 = nn.Linear(1, 1)
+                    self.state = 0
+
+                def forward(self, input):
+                    if self.state == 0:
+                        self.state = 1
+                        if self.rank == 0:
+                            return self.nn1(input)
+                        else:
+                            return self.nn3(input)
+                    else:
+                        self.state = 0
+                        if self.rank == 0:
+                            return self.nn2(input)
+                        else:
+                            return self.nn4(input)
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                SimpleConditionalModel(rank), find_unused_parameters=True
+            )
+            mse_loss = nn.MSELoss()
+            grad_accumulation = 2
+
+            for microbatch_idx in range(grad_accumulation):
+                if microbatch_idx < grad_accumulation - 1:
+                    context = model.no_sync
+                else:
+                    context = nullcontext
+
+                with context():
+                    input = torch.rand((1,))
+                    output = model.forward(input)
+                    target = torch.rand((1,))
+
+                    loss = mse_loss(output, target)
+                    loss.backward()
+
+            self.assertTrue(
+                not any(p.grad is None for p in model.parameters()),
+                "Gradients can't be None for any model parameter.",
+            )
+            grads = torch.cat([p.grad.view(-1) for p in model.parameters()])
+
+            # Gather all gradients to rank 0.
+            if rank == 0:
+                gathered_grads = [torch.zeros_like(grads) for _ in range(world_size)]
+            else:
+                gathered_grads = []
+
+            dist.gather(grads, gather_list=gathered_grads, dst=0)
+            if rank == 0:
+                for g in gathered_grads[1:]:
+                    self.assertTrue(
+                        torch.allclose(gathered_grads[0], g),
+                        "Gradients are not the same for all ranks.",
+                    )
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_control_flow_same_across_ranks(self):
+            # Control flow that is the same across ranks.
+            batch = 20
+            dim = 10
+
+            world_size = dist.get_world_size()
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            random_input = torch.randn(batch, dim, device=self.rank)
+            ones_input = torch.ones(batch, dim, device=self.rank)
+            for i in range(6):
+                if i % 2 == 0:
+                    out = model(random_input)
+                else:
+                    out = model(ones_input)
+                loss = out.sum()
+                loss.backward()
+                # On even iterations, 2nd param goes unused, on odd iterations,
+                # it is used.
+                local_used_map = model.reducer._get_local_used_map()
+                if i % 2 == 0:
+                    expected = torch.tensor(
+                        [world_size, 0], device=self.rank, dtype=torch.int32
+                    )
+                else:
+                    expected = torch.tensor(
+                        [world_size, world_size], device=self.rank, dtype=torch.int32
+                    )
+
+                # Validate parameter usage.
+                variable_usage_tensor = local_used_map
+                self.assertEqual(variable_usage_tensor, expected)
+
+            # Validate appropriate error message when DDP is used with
+            # find_unused_parameters=False.
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            for i in range(2):
+                if i == 0:
+                    loss = model(random_input).sum()
+                    loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(model, msg)
+                        # 2nd linear layer is unused
+                        unused_param_index = 1
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["lin2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP error not raised")
+
+            dist.barrier()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_invalid_static_graph(self):
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                static_graph=True,
+            )
+            random_input = torch.randn(20, 10, device=self.rank)
+            ones_input = torch.ones(20, 10, device=self.rank)
+            # unused parameter in the first iteration got used
+            # in second iteration.
+            expected_err = "Your training graph has changed in this iteration"
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                for i in range(2):
+                    if i % 2 == 0:
+                        out = model(random_input)
+                    else:
+                        out = model(ones_input)
+                    loss = out.sum()
+                    loss.backward()
+
+            verify_ddp_error_logged(model, expected_err)
+
+            # used parameter in the first iteration got unused
+            # in second iteration.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Expected to have finished reduction in the prior iteration "
+                "before starting a new one. This error indicates that your "
+                "training graph has changed in this iteration, "
+                "e.g., one parameter is used in first iteration, "
+                "but then got unused in the second iteration. "
+                "this is not compatible with static_graph set to True.\n"
+                "Parameter indices which did not receive grad for",
+            ):
+                for i in range(2):
+                    if i % 2 != 0:
+                        out = model(random_input)
+                    else:
+                        out = model(ones_input)
+                    loss = out.sum()
+                    loss.backward()
+
+            verify_ddp_error_logged(model, "Expected to have finished reduction")
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_control_flow_different_across_ranks(self):
+            # Control flow that is different across ranks.
+            batch = 20
+            dim = 10
+
+            class ToyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.lin1 = nn.Linear(10, 10, bias=False)
+                    self.lin2 = nn.Linear(10, 10, bias=False)
+                    self.rank = rank
+
+                def forward(self, x):
+                    # Control-flow that is rank and input dependent for the
+                    # model.
+                    use_second_layer = (
+                        torch.equal(x, torch.ones(batch, dim, device=x.device))
+                        and self.rank == 1
+                    )
+
+                    if use_second_layer:
+                        return self.lin2(F.relu(self.lin1(x)))
+                    else:
+                        return F.relu(self.lin1(x))
+
+            world_size = dist.get_world_size()
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel(self.rank).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            random_input = torch.randn(batch, dim, device=self.rank)
+            ones_input = torch.ones(batch, dim, device=self.rank)
+            for i in range(6):
+                if i % 2 == 0:
+                    out = model(random_input)
+                else:
+                    out = model(ones_input)
+                loss = out.sum()
+                loss.backward()
+                # On even iterations, 2nd param goes unused, on odd iterations,
+                # it is used only on rank 1.
+                local_used_map = model.reducer._get_local_used_map()
+
+                if i % 2 == 0:
+                    expected = torch.tensor(
+                        [world_size, 0], device=self.rank, dtype=torch.int32
+                    )
+                else:
+                    expected = torch.tensor(
+                        [world_size, 1], device=self.rank, dtype=torch.int32
+                    )
+
+                variable_usage_tensor = local_used_map
+                # Validate parameter usage. On odd iterations, 2nd param is only
+                # used on rank 1.
+                self.assertEqual(variable_usage_tensor, expected)
+
+            # Validate appropriate error message when DDP is used with
+            # find_unused_parameters=False.
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel(self.rank).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            for i in range(2):
+                if i == 0:
+                    loss = model(random_input).sum()
+                    loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(model, msg)
+                        unused_param_index = 1
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["lin2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP error not raised")
+
+            dist.barrier()
+
+        @require_backend_is_available({"gloo"})
+        def test_scatter_object_list(self):
+            src_rank = 0
+            collectives_object_test_list = create_collectives_object_test_list()
+            scatter_list = (
+                collectives_object_test_list
+                if self.rank == src_rank
+                else [None for _ in collectives_object_test_list]
+            )
+            world_size = dist.get_world_size()
+            scatter_list = scatter_list[:world_size]
+            i = 0
+            while len(scatter_list) < world_size:
+                scatter_list.append(scatter_list[i])
+                i += 1
+
+            output_obj_list = [None]
+            dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
+            self.assertEqual(
+                output_obj_list[0],
+                collectives_object_test_list[
+                    self.rank % len(collectives_object_test_list)
+                ],
+            )
+            # Ensure errors are raised upon incorrect arguments.
+            with self.assertRaisesRegex(
+                ValueError,
+                "Expected argument scatter_object_output_list to be a list of size at least 1.",
+            ):
+                dist.scatter_object_list([], scatter_list, src=src_rank)
+
+        def _generate_sparse_tensors_for_bucket_assignment_test(self):
+            tensors = [
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+            ]
+
+            tensors_sparse = [t.to_sparse() for t in tensors]
+            return tensors_sparse
+
+        def _test_compute_bucket_assignment_by_size(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            # We never actually use the rest of the model - we only need its logger.
+            net = EmbeddingNetDifferentParams(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # if we don't pass a logger then we can only check that an exception was thrown.
+            expected_err = "No support for sparse tensors."
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                tensors_sparse = (
+                    self._generate_sparse_tensors_for_bucket_assignment_test()
+                )
+                if use_logger:
+                    dist._compute_bucket_assignment_by_size(
+                        tensors_sparse, [400], logger=net.logger
+                    )
+                else:
+                    dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
+            if use_logger:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=False)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=True)
+
+        def _test_verify_model_across_rank(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            net = EmbeddingNetDifferentParams(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # Modify the model so that the number of parameters are different for each rank.
+            # This will cause a RuntimeError to be thrown below in _verify_param_shape_across_processes,
+            # so we can check if the correct error is thrown and is logged.
+            # We can't do this in the constructor above otherwise the logger will
+            # not be properly initialized.
+            net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
+
+            # if we pass a logger we can verify that it was logged
+            caught = 0
+            try:
+                if use_logger:
+                    _verify_param_shape_across_processes(
+                        net.process_group, list(net.parameters()), net.logger
+                    )
+                else:
+                    _verify_param_shape_across_processes(
+                        net.process_group, list(net.parameters())
+                    )
+            except Exception:
+                caught = 1
+
+            # As long as there is one rank catching the exception
+            t = torch.Tensor([caught])
+            dist.all_reduce(t, group=group_gloo)
+            self.assertGreater(t, 0)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_verify_model_across_rank_with_logger(self):
+            self._test_verify_model_across_rank(use_logger=True)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_verify_model_across_rank_without_logger(self):
+            self._test_verify_model_across_rank(use_logger=False)
+
+        def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
+            caught = 0
+            try:
+                net = torch.nn.parallel.DistributedDataParallel(
+                    net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
+                )
+            except Exception:
+                caught = 1
+
+            # As long as there is one rank catching the exception
+            t = torch.Tensor([caught])
+            dist.all_reduce(t, group=group_gloo)
+            self.assertGreater(t, 0)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_model_diff_shape_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=10)
+            )
+            torch.cuda.set_device(self.rank)
+            # Creates network with different sized embedding table on different
+            # ranks. This should throw an error during DDP init.
+            net = EmbeddingNetDifferentParams(self.rank)
+            self._run_test_ddp_model_with_diff_params(net, group_to_use, group_gloo)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_model_diff_num_params_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=10)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Creates network with diff # of param across ranks, reducer should
+            # recognize this and throw appropriate error.
+            net = EmbeddingNetDifferentParams(
+                self.rank, diff_num_params=(self.rank == 1)
+            )
+
+            self._run_test_ddp_model_with_diff_params(
+                net,
+                group_to_use,
+                group_gloo,
+            )
+
+        def _test_output_unused_in_loss(self, module_cls, gradient_as_bucket_view):
+            model = module_cls()
+            local_net = copy.deepcopy(model)
+            net = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+
+            # Tests that certain parameters not getting gradient since the
+            # output is unused in loss computation is supported. Specifically,
+            # checks that the grads remain unchanged and are the same as local
+            # training.
+            inp = torch.randn(10, 10)
+
+            # Ensure that if a param is not used in loss computation, its
+            # gradient is untouched, i.e. if it is None before it is None after,
+            # not zero.
+            if module_cls == DictOutputModule:
+                a, b = local_net(inp)["predictions"]
+                a_dist, b_dist = net(inp)["predictions"]
+            else:
+                a, b = local_net(inp)
+                a_dist, b_dist = net(inp)
+
+            loss_dist = b_dist.sum()
+            loss_dist.backward()
+
+            # Ensure that gradient corresponding to parameter "a" was not
+            # touched, i.e. it is None and matches the local grad.
+            if module_cls == DictOutputModule:
+                self.assertTrue(net.module.module.a.weight.grad is None)
+                self.assertEqual(
+                    net.module.module.a.weight.grad, local_net.module.a.weight.grad
+                )
+            else:
+                self.assertTrue(net.module.a.weight.grad is None)
+                self.assertEqual(net.module.a.weight.grad, local_net.a.weight.grad)
+
+            saved_a_local_grad = None
+            saved_a_dist_grad = None
+            net.zero_grad()
+            local_net.zero_grad()
+            for i in range(6):
+                if module_cls == DictOutputModule:
+                    a, b = local_net(inp)["predictions"]
+                    a_dist, b_dist = net(inp)["predictions"]
+                else:
+                    a, b = local_net(inp)
+                    a_dist, b_dist = net(inp)
+                if i < 2:
+                    # Use both params in loss computation. Later, "a" will go
+                    # unused and we check to ensure DDP supports this and
+                    # gradients remain the same as local training.
+                    t = a @ b
+                    t_dist = a_dist @ b_dist
+                    loss = t.sum()
+                    loss_dist = t_dist.sum()
+                else:
+                    # Model output "a" unused in loss.
+                    loss = b.sum()
+                    loss_dist = b_dist.sum()
+                loss.backward()
+                loss_dist.backward()
+                if i == 1:
+                    # Save grads to compare with them in next iterations.
+                    if module_cls == DictOutputModule:
+                        saved_a_local_grad = local_net.module.a.weight.grad
+                        saved_a_dist_grad = net.module.module.a.weight.grad
+                    else:
+                        saved_a_local_grad = local_net.a.weight.grad
+                        saved_a_dist_grad = net.module.a.weight.grad
+                    self.assertEqual(saved_a_local_grad, saved_a_dist_grad)
+                elif i >= 2:
+                    # parameter "a" of both models should be the same and not change
+                    if module_cls == DictOutputModule:
+                        self.assertEqual(
+                            net.module.module.a.weight.grad, saved_a_dist_grad
+                        )
+                        self.assertEqual(
+                            local_net.module.a.weight.grad, saved_a_local_grad
+                        )
+                    else:
+                        self.assertEqual(net.module.a.weight.grad, saved_a_dist_grad)
+                        self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
+
+                # Verify grads are the same
+                for local_param, dist_param in zip(
+                    local_net.parameters(), net.parameters(), strict=True
+                ):
+                    local_grad = local_param.grad
+                    dist_grad = dist_param.grad
+                    self.assertEqual(local_grad, dist_grad)
+
+            dist.barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_output_unused_in_loss_tuple_module(self):
+            module_cls = UnusedParamTwoLinLayerNet
+            for grad_as_bucket_view in [True, False]:
+                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_output_unused_in_loss_dict_module(self):
+            module_cls = DictOutputModule
+            for grad_as_bucket_view in [True, False]:
+                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_undefined_grad_parity_unused_parameters(self):
+            # TODO: enable this for general training use cases:
+            # https://github.com/pytorch/pytorch/issues/58511.
+            x = torch.ones(1, 2).to(self.rank)
+            net = Net().to(self.rank)
+            local_net = copy.deepcopy(net)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            out = net(x).sum()
+            local_out = local_net(x).sum()
+            # Simulates undefined gradients.
+            torch._C._functions.UndefinedGrad()(out).backward()
+            torch._C._functions.UndefinedGrad()(local_out).backward()
+            for (dist_param_name, dist_param), (local_param_name, local_param) in zip(
+                net.named_parameters(), local_net.named_parameters(), strict=True
+            ):
+                dist_grad = dist_param.grad
+                local_grad = local_param.grad
+                self.assertEqual(
+                    dist_grad,
+                    local_grad,
+                    f"""DDP param {dist_param_name} with grad {dist_grad}
+                    does not match local param {local_param_name} with grad
+                    {local_grad}""",
+                )
+
+        def _test_different_graph_across_ranks(
+            self, find_unused_parameters=False, static_graph=False
+        ):
+            class ToyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.lin1 = nn.Linear(10, 10, bias=False)
+                    self.lin2 = nn.Linear(10, 10, bias=False)
+                    self.rank = rank
+
+                def forward(self, x):
+                    if self.rank == 0:
+                        return self.lin2(F.relu(self.lin1(x)))
+                    else:
+                        return F.relu(self.lin1(x))
+
+            torch.manual_seed(31415)
+            torch.cuda.set_device(self.rank)
+            model = ToyModel(self.rank).cuda(self.rank)
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=find_unused_parameters,
+                gradient_as_bucket_view=True,
+                static_graph=static_graph,
+            )
+            random_input = torch.randn(20, 10, device=self.rank)
+            for _ in range(10):
+                out = ddp_model(random_input)
+                loss = out.sum()
+                loss.backward()
+            return ddp_model
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_different_graph_across_ranks(self):
+            base_model = self._test_different_graph_across_ranks(
+                find_unused_parameters=True
+            )
+            self.assertFalse(
+                base_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
+            )
+            static_model = self._test_different_graph_across_ranks(static_graph=True)
+            self.assertTrue(
+                static_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
+            )
+            for i, j in zip(
+                base_model.parameters(), static_model.parameters(), strict=True
+            ):
+                self.assertEqual(i, j)
+
+        @require_backend_is_available({"gloo"})
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
+        )
+        def test_monitored_barrier_gloo(self):
+            tensors = [torch.ones(10) * self.rank]
+            # Kick off some allreduce work on all ranks
+            for _ in range(10):
+                dist.all_reduce(torch.cat(tensors))
+            # Run monitored barrier and ensure it passes
+            timeout = timedelta(seconds=2)
+            dist.monitored_barrier(timeout=timeout)
+            # Check monitored_barrier success with wait_all_ranks=True
+            for _ in range(10):
+                dist.all_reduce(torch.cat(tensors))
+            dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
+            # All ranks besides 1 call into barrier, rank 0 should report failure
+            # while others report gloo error.
+            failed_rank = 1
+            src_rank = 0
+            if self.rank == src_rank:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
+                ):
+                    dist.monitored_barrier(timeout=timeout)
+            elif self.rank != failed_rank:
+                # Other ranks should not pass barrier since rank 0 failed.
+                err_regex = (
+                    f"Rank {self.rank} successfully reached monitoredBarrier,"
+                    f" but received errors while waiting for send/recv from rank"
+                    f" {src_rank}"
+                )
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout)
+
+            # We need a barrier since otherwise failed_rank exits too early
+            # and cause a timeout.
+            self._barrier(timeout=30)
+
+        @require_backend_is_available({"gloo"})
+        def test_monitored_barrier_gloo_subgroup(self):
+            # Tests that monitored_barrier works as expected on non-default
+            # process groups.
+            failed_rank = 1
+            timeout = 0.1
+            subgroup = dist.new_group(ranks=[0, 1])
+
+            if self.rank == failed_rank:
+                return
+
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
+                ):
+                    dist.monitored_barrier(subgroup, timeout)
+            else:
+                # Other ranks call into monitored_barrier, but this should be a
+                # noop because they are not part of the subgroup. Verify that
+                # there are no errors here.
+                dist.monitored_barrier(subgroup, timeout)
+
+        def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
+            # tests expected behavior when nonzero rank hangs.
+            nccl_pg = dist.new_group(
+                ranks=list(range(int(self.world_size))),
+                # provide sufficient timeout so communicators
+                # can be initialized in ctor.
+                timeout=timedelta(seconds=15),
+                backend=dist.Backend.NCCL,
+            )
+            gloo_pg = dist.new_group(
+                ranks=list(range(int(self.world_size))),
+                backend=dist.Backend.GLOO,
+            )
+            tensors = [torch.ones(10, device=self.rank) * self.rank]
+            # Let all ranks call allreduce first to set up communicators etc.
+            # Directly simulating error here will run into store issue described
+            # in https://github.com/pytorch/pytorch/issues/54524.
+            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
+            # All ranks besides 0 call into allreduce. This is to simulate a
+            # desync across the world, where some ranks call into
+            # monitored_barrier() and others are stuck in collective comm. In
+            # practice, we don't need TORCH_NCCL_BLOCKING_WAIT, but we use it in this
+            # test to ensure it exits cleanly.
+            if self.rank != 0:
+                # Can get different errors here depending on whether gloo-based
+                # wrapper PG is enabled or not, since with wrapper pg, it will
+                # fail in a collective synchronization check and not actually
+                # call into the nccl pg.
+                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
+                    err_regex = "Timed out waiting"
+                else:
+                    err_regex = "caught collective operation timeout"
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1))
+            else:
+                # Rank 0 should report first (in order) timed out rank or all ranks
+                # depending on wait_all_ranks flag passed into monitored_barrier.
+                if wait_all_ranks:
+                    rank_str = ", ".join(
+                        [str(i) for i in range(1, int(self.world_size))]
+                    )
+                    err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
+                else:
+                    expected_first_fail_rank = 1
+                    err_regex = f"Rank {expected_first_fail_rank} failed to pass monitoredBarrier"
+                monitored_barrier_timeout_seconds = timedelta(seconds=0.1)
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    gloo_pg.monitored_barrier(
+                        monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
+                    )
+
+            self._barrier(timeout=30)
+
+        @with_nccl_blocking_wait
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_monitored_barrier_allreduce_hang(self):
+            # tests expected behavior when nonzero rank hangs and we want to
+            # report first timed out rank.
+            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=False)
+
+        @with_nccl_blocking_wait
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_monitored_barrier_allreduce_hang_wait_all_ranks(self):
+            # Need to disable TORCH_NCCL_DUMP_ON_TIMEOUT otherwise this test times out
+            os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "0"
+            # tests expected behavior when nonzero rank hangs and we want to
+            # report all timed out ranks.
+            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=True)
+
+        @require_backend_is_available({"gloo"})
+        def test_monitored_barrier_gloo_rank_0_timeout(self):
+            # tests error when rank 0 exhausts its given timeout.
+            process_group = dist.new_group(ranks=list(range(int(self.world_size))))
+            timeout = timedelta(seconds=0)
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {self.rank} timed out in monitoredBarrier"
+                ):
+                    process_group.monitored_barrier(timeout)
+
+        @require_backend_is_available({"gloo"})
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
+        )
+        def test_monitored_barrier_failure_order(self):
+            # Ensure that the first (in sorted order) rank is reported when
+            # multiple ranks fail to pass the monitored_barrier.
+            # TODO(#54879): Provide ability to wait and report all failed ranks
+            expected_first_failed_rank = 2
+            timeout = timedelta(seconds=2)
+            src_rank = 0
+            if self.rank == src_rank:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {expected_first_failed_rank}"
+                ):
+                    dist.monitored_barrier(timeout=timeout)
+            elif self.rank == 1:
+                err_regex = (
+                    f"Rank {self.rank} successfully reached monitoredBarrier,"
+                    f" but received errors while waiting for send/recv from rank"
+                    f" {src_rank}"
+                )
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout)
+
+        @require_backend_is_available({"gloo"})
+        @skip_if_small_worldsize
+        def test_monitored_barrier_wait_all_ranks(self):
+            # Tests simple case where > 1 rank does not call into monitored
+            # barrier and verifies all ranks are reported by rank 0.
+            if self.rank == 0:
+                timeout = timedelta(seconds=0.1)
+                rank_str = ", ".join([str(i) for i in range(1, int(self.world_size))])
+                err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["INFO"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_build_debug_param_to_name_mapping(self):
+            model = TwoLinLayerNet()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            expected_mapping = {0: "a.weight", 1: "b.weight"}
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertDictEqual(expected_mapping, param_to_name_mapping)
+
+            # Test when DDP is used with ignored parameters.
+            model = TwoLinLayerNet()
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            params_to_ignore = ["a.weight"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, params_to_ignore
+            )
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            expected_mapping = {0: "b.weight"}
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertDictEqual(expected_mapping, param_to_name_mapping)
+
+            # Test errors are raised when DDP and module parameters mismatch.
+            # This generally indicates a bug with DDP and is not expected to
+            # happen in user applications.
+            model = TwoLinLayerNet()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            net_params, _ = net._build_params_for_reducer()
+            if self.rank == 0:
+                print(type(net_params[0]))
+
+            net_params.extend(
+                [
+                    torch.nn.Parameter(torch.ones(1)),
+                    torch.nn.Parameter(torch.ones(1)),
+                ]
+            )
+
+            with self.assertRaisesRegex(ValueError, "Expected param to name mapping"):
+                net._build_debug_param_to_name_mapping(net_params)
+
+            net_params = net_params[:-3]
+            with self.assertRaisesRegex(ValueError, "Param with name"):
+                net._build_debug_param_to_name_mapping(net_params)
+
+            net_params.extend(
+                [
+                    torch.nn.Parameter(torch.ones(1)),
+                    torch.nn.Parameter(torch.ones(1)),
+                ]
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @with_dist_debug_levels(levels=["INFO"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_build_debug_param_to_name_mapping_requires_grad(self):
+            class Net(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    # Is not tracked by DDP and should not show up in param to
+                    # name mapping.
+                    self.lin.bias.requires_grad_(False)
+
+                def forward(self, x):
+                    return self.lin(x)
+
+            model = Net()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank), device_ids=[self.rank]
+            )
+            expected_mapping = {
+                0: "lin.weight",
+            }
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertEqual(param_to_name_mapping, expected_mapping)
+
+        def _test_ddp_multiple_nested_unused_params_error(self, ignore_sparse):
+            debug_mode_off = dist.get_debug_level() == dist.DebugLevel.OFF
+
+            class SubModule(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.embedding_net = EmbeddingNetDifferentParams(0)
+                    self.lin = TwoLinLayerNet()
+                    self.bn = BatchNormNet()
+                    self.lin_layer = nn.Linear(4, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.bn(x)
+                    x = self.lin_layer(x)
+                    x = self.lin.a(x)  # self.lin.b param unused
+                    # EmbeddingNetDifferentParams entirely unused: self.embedding_net.embedding and
+                    # self.embedding_net.lin unused.
+                    return x
+
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.sub_module = SubModule()
+
+                def forward(self, x):
+                    return self.sub_module(x)
+
+            model = MyModel()
+            sparse_embedding_fqns = []
+            if ignore_sparse:
+                for module_name, module in model.named_modules():
+                    if module == model.sub_module.embedding_net.embedding:
+                        for parameter_name, _param in module.named_parameters(
+                            recurse=False
+                        ):
+                            fqn = f"{module_name}.{parameter_name}"
+                            sparse_embedding_fqns.append(fqn)
+
+                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                    model, sparse_embedding_fqns
+                )
+                unused_modules = [
+                    model.sub_module.embedding_net.lin,
+                    model.sub_module.lin.b,
+                ]
+            else:
+                unused_modules = list(model.sub_module.embedding_net.modules()) + [
+                    model.sub_module.lin.b,
+                ]
+
+            expected_unused_param_fqns = []
+            used_param_fqns = []  # Validate that these don't mistakenly show up.
+            fqn_to_param_index = {}
+            index = 0
+            for module_name, module in model.named_modules():
+                for parameter_name, _param in module.named_parameters(recurse=False):
+                    fqn = f"{module_name}.{parameter_name}"
+                    fqn_to_param_index[fqn] = index
+                    if fqn not in sparse_embedding_fqns:
+                        index += 1
+                    if module in unused_modules:
+                        expected_unused_param_fqns.append(fqn)
+                    else:
+                        if (
+                            not ignore_sparse
+                            or module != model.sub_module.embedding_net.embedding
+                        ):
+                            used_param_fqns.append(fqn)
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            batch, dim = 10, 2
+            inp = torch.ones(batch, dim)
+            for i in range(2):
+                if i == 0:
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                else:
+                    try:
+                        out = net(inp)
+                        loss = out.sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        e = str(e)
+
+                        unused_param_substr = e[e.find("did not receive grad") :]
+                        # Validate that each unused param fully qualified name
+                        # shows up in error logs. We do this instead of
+                        # constructing a joined string since order of parameters
+                        # can be different in Reducer. In addition, validate
+                        # param indices show up as well.
+                        for unused_param_fqn in expected_unused_param_fqns:
+                            self.assertTrue(
+                                unused_param_fqn in unused_param_substr
+                                or debug_mode_off
+                            )
+                            self.assertTrue(
+                                str(fqn_to_param_index[unused_param_fqn])
+                                in unused_param_substr,
+                                f"Did not find index {fqn_to_param_index[unused_param_fqn]} for {unused_param_fqn}",
+                            )
+
+                        # Validate that used param fqns don't show up in error
+                        # logs.
+                        for used_param_fqn in used_param_fqns:
+                            self.assertFalse(used_param_fqn in unused_param_substr)
+                        # Validate that ignored param fqns don't show up as unused
+                        # (since DDP does not track them)
+                        for sparse_param_fqn in sparse_embedding_fqns:
+                            self.assertFalse(sparse_param_fqn in unused_param_substr)
+                    else:
+                        self.assertTrue(False, "Expected error was not raised!")
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_multiple_nested_unused_params_error(self):
+            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=False)
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_multiple_nested_unused_params_err_ignore_params(self):
+            # Tests unused parameter reporting when DDP is configured to ignore
+            # certain parameters.
+            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_inference(self):
+            # tests that DDP module can be run on a single node with no_grad
+            # or eval setting and there is no hang.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            model = Net().cuda()
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            syncbn_model = nn.SyncBatchNorm(
+                2, momentum=0.99, track_running_stats=False
+            ).cuda()
+            local_syncbn_model = copy.deepcopy(syncbn_model)
+            syncbn_model = torch.nn.parallel.DistributedDataParallel(
+                syncbn_model, device_ids=[rank]
+            )
+            inp = torch.randn(10, 2, device=rank)
+            inp_syncbn = torch.randn(10, 2, 4, 4, device=rank)
+            tests = [
+                (model, local_model, inp),
+                (syncbn_model, local_syncbn_model, inp_syncbn),
+            ]
+            for test in tests:
+                test_model, test_local_model, test_inp = test
+                if self.rank == 0:
+                    test_model.eval()
+                    test_local_model.eval()
+                    for _ in range(6):
+                        self.assertEqual(
+                            test_model(test_inp), test_local_model(test_inp)
+                        )
+
+            # Barrier since only rank 0 runs inference. Test should be
+            # much faster than 30s, but this is to avoid flakiness.
+            self._barrier(timeout=30)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        @unittest.skip(
+            "Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
+        )
+        def test_ddp_sync_bn_training_vs_eval(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            # Need to set track_running_stats=False, when track_running_stats=True,
+            # bn_training is False and sync could not occur in eval model.
+            model = nn.SyncBatchNorm(2, momentum=0.99, track_running_stats=False).cuda(
+                rank
+            )
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+            # Test sync occurs in training mode.
+            with torch.autograd.profiler.profile() as prof:
+                for _ in range(6):
+                    inp = torch.randn(10, 2, 4, 4).cuda(rank)
+                    out = model(inp)
+                    loss = out.sum()
+                    loss.backward()
+
+            # SyncBN allgathers stats across all ranks, so verify call to
+            # all_gather in profiler.
+            if BACKEND == "nccl":
+                all_gather_calls = get_profiling_event("_all_gather_base", prof)
+            else:
+                all_gather_calls = get_profiling_event("all_gather", prof)
+            self.assertNotEqual([], all_gather_calls)
+
+            # Only do inference on one rank. If SyncBN did collective stats sync,
+            # this would hang/error.
+            model_inference = model.module
+            if self.rank == 0:
+                model_inference.eval()
+                with torch.autograd.profiler.profile() as prof:
+                    for _ in range(6):
+                        inp = torch.randn(10, 2, 4, 4).cuda(rank)
+                        out = model_inference(inp)
+                        loss = out.sum()
+                        loss.backward()
+
+                # Ensure sync does not occur in eval() mode.
+                if BACKEND == "nccl":
+                    all_gather_calls = get_profiling_event("_all_gather_base", prof)
+                else:
+                    all_gather_calls = get_profiling_event("all_gather", prof)
+                self.assertEqual([], all_gather_calls)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_python_error_logged(self):
+            # Most python exceptions in DDP are raised during init before
+            # reducer is constructed, so we don't have a logger in those cases.
+            # However, the below is one example where a python error is thrown
+            # after reducer is constructed.
+            model = TwoLinLayerNet().cuda(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            expected_err = "must be callable"
+            with self.assertRaisesRegex(TypeError, expected_err):
+                model.register_comm_hook({}, {})
+
+            verify_ddp_error_logged(model, expected_err)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_static_graph_nested_types(self):
+            # Tests for static graph training when outputs are not just tensors
+            # but can be (nested) tuple, list, dict, etc.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+
+            class NestedOutputModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(100, 1, bias=False)
+
+                def forward(self, inp, output_type):
+                    if output_type == "tuple":
+                        return (
+                            self.lin(inp),
+                            (
+                                self.lin(inp),
+                                self.lin(inp),
+                            ),
+                        )
+                    elif output_type == "list":
+                        return [
+                            self.lin(inp),
+                            [
+                                self.lin(inp),
+                                self.lin(inp),
+                            ],
+                        ]
+                    elif output_type == "dict":
+                        return {
+                            "a": self.lin(inp),
+                            "b": {
+                                "c": self.lin(inp),
+                            },
+                        }
+
+            def get_loss(model_output):
+                loss = 0.0
+                if isinstance(model_output, torch.Tensor):
+                    return model_output.sum()
+                elif isinstance(model_output, dict):
+                    for value in model_output.values():
+                        loss += get_loss(value)
+                elif isinstance(model_output, (tuple, list)):
+                    for x in model_output:
+                        loss += get_loss(x)
+                else:
+                    raise ValueError(f"Unknown model output type {type(model_output)}")
+                return loss
+
+            model = NestedOutputModule().cuda(rank)
+            model_static_graph = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            model_static_graph = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+                static_graph=True,
+            )
+            inp = torch.randn(10, 100)
+            type_mapping = {
+                "list": list,
+                "tuple": tuple,
+                "dict": dict,
+            }
+            for output_type in type_mapping:
+                for _ in range(6):
+                    out = model(inp, output_type=output_type)
+                    loss = get_loss(out)
+                    loss.backward()
+                    self._model_step(model)
+                    out_static = model_static_graph(inp, output_type=output_type)
+                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
+                    loss_static = get_loss(out_static)
+                    loss_static.backward()
+                    self._model_step(model_static_graph)
+                    for p, p_static in zip(
+                        model.parameters(), model_static_graph.parameters(), strict=True
+                    ):
+                        self.assertEqual(p, p_static)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_returns_tensor_with_no_grad(self):
+            # Tests case where module returns tensor that does not require grad.
+            torch.cuda.set_device(self.rank)
+
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(10, 10, bias=False)
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.fc2(F.relu(self.fc1(x)))
+                    y = x.clone()
+                    x = x.detach()
+                    assert not x.requires_grad
+                    return (x, y)
+
+            model = MyModel().to(self.rank)
+            inp = torch.randn(1, 10, device=self.rank)
+            for find_unused, static_graph in itertools.product(
+                [True, False], [True, False]
+            ):
+                ddp = DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    output_device=self.rank,
+                    find_unused_parameters=find_unused,
+                    static_graph=static_graph,
+                )
+                for _ in range(6):
+                    out = ddp(inp)
+                    self.assertFalse(out[0].requires_grad)
+                    o = (out[0] + out[1]).sum()
+                    o.backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_detect_ddp_is_actually_static(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10, bias=False)
+                    self.net2 = nn.Linear(10, 10)
+
+                def forward(self, x, find_unused, dynamic):
+                    if find_unused:
+                        if dynamic:
+                            return self.net2(self.net1(x))
+                        else:
+                            return self.net2(x)
+                    else:
+                        return self.net2(self.net1(x))
+
+            # Set of unused parameters don't change across iterations
+            torch.cuda.set_device(self.rank)
+            model = ToyModel().cuda()
+            for find_unused in [True, False]:
+                ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    find_unused_parameters=find_unused,
+                )
+                inp = torch.randn(1, 10, device="cuda")
+                for _ in range(6):
+                    out = ddp(inp, find_unused=find_unused, dynamic=False)
+                    loss = out.sum()
+                    loss.backward()
+                    self.assertTrue(ddp.reducer._ddp_graph_static())
+
+            # Set of unused parameters dynamically change
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            inp = torch.randn(1, 10, device="cuda")
+            for i in range(6):
+                out = ddp(inp, find_unused=True, dynamic=i % 2 == 0)
+                loss = out.sum()
+                loss.backward()
+            self.assertFalse(ddp.reducer._ddp_graph_static())
+
+        def _test_ddp_new_tensor_in_fwd(self, static_graph):
+            # Test from https://github.com/pytorch/pytorch/issues/60733
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(10, 10, bias=False)
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+                    self.device = self.fc1.weight.device
+
+                def __init_opt(self):
+                    opt = torch.randn(1, 10, device=self.device)
+                    return opt
+
+                def forward(self, x, opt_1, opt_2, opt_nested):
+                    x = F.relu(self.fc1(x))
+                    x = self.fc2(x)
+                    if opt_1 is None:
+                        opt_1 = self.__init_opt()
+                    if opt_2 is None:
+                        opt_2 = self.__init_opt()
+                    if opt_nested is None or not torch.is_tensor(opt_nested):
+                        opt_nested = self.__init_opt()
+                    # Test multiple tensors as well as newly created tensors
+                    # within a struct.
+                    return x, opt_1, opt_2, {"tensor": opt_nested}
+
+            model = MyModel().to(self.rank)
+            for find_unused in [True, False]:
+                ddp = DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    output_device=self.rank,
+                    broadcast_buffers=False,
+                    find_unused_parameters=find_unused,
+                    static_graph=static_graph,
+                )
+
+                opt = [None for _ in range(3)]
+                for i in range(2):
+                    ddp.zero_grad()
+                    x = torch.randn(1, 10, device=self.rank)
+                    out, opt[0], opt[1], opt[2] = ddp(
+                        x, opt_1=opt[0], opt_2=opt[1], opt_nested=opt[2]
+                    )
+                    for i in range(len(opt)):
+                        if torch.is_tensor(opt[i]):
+                            self.assertEqual(opt[i].grad_fn, None)
+                        else:
+                            self.assertEqual(opt[i]["tensor"].grad_fn, None)
+                    out.mean().backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_new_tensor_in_fwd(self):
+            return self._test_ddp_new_tensor_in_fwd(static_graph=False)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_new_tensor_in_fwd_static_graph(self):
+            return self._test_ddp_new_tensor_in_fwd(static_graph=True)
+
+        def _test_ddp_buffer_hook_allreduce(self, return_futures):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            def buffer_comm_hook(ddp, named_buffers):
+                buffers = [buffer for (_, buffer) in named_buffers.items()]
+                futs = [
+                    dist.all_reduce(
+                        buffer, group=ddp.process_group, async_op=True
+                    ).get_future()
+                    for buffer in buffers
+                ]
+                if return_futures:
+                    return futs
+                else:
+                    torch.futures.collect_all(futs).wait()
+
+            hook_pre_fwd = (
+                torch.nn.parallel.distributed._BufferCommHookLocation.PRE_FORWARD
+            )
+            hook_post_fwd = (
+                torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD
+            )
+            for hook_run_location in [
+                hook_pre_fwd,
+                hook_post_fwd,
+            ]:
+                model = NetWithBuffers().cuda(rank)
+                model_ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                )
+                model_ddp._register_buffer_comm_hook(
+                    model_ddp, buffer_comm_hook, hook_run_location
+                )
+                model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    broadcast_buffers=False,
+                )
+                inp = torch.randn(2, 10, device=rank)
+                for _ in range(2):
+                    loss_hook = model_ddp(inp).sum()
+                    # Since buffer reduction is done pre-forward, simulate it for
+                    # no hook case here.
+                    # Simulate allreduce appropriately depending on hook location.
+                    if hook_run_location == hook_pre_fwd:
+                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
+                        for tensor in model_no_hook_buffers:
+                            dist.all_reduce(tensor)
+
+                    loss_no_hook = model_ddp_no_hook(inp).sum()
+                    if hook_run_location == hook_post_fwd:
+                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
+                        for tensor in model_no_hook_buffers:
+                            dist.all_reduce(tensor)
+                    torch.cuda.synchronize()
+
+                    # if return_futures, they are only awaited on by DDP
+                    # at the end of the backwards pass for maximum overlap.
+                    if not return_futures:
+                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                    loss_hook.backward()
+                    loss_no_hook.backward()
+                    # Note that when custom hooks return futures, this
+                    # comparison is not expected to work when hook run location
+                    # is pre-forward pass. This is because the hook does async
+                    # communication and forward pass modifies the buffer without
+                    # appropriate synchronization. Therefore, if returning
+                    # futures from custom buffer hooks, it is advised to set
+                    # hook run location to post forward.
+                    if return_futures and hook_run_location == hook_post_fwd:
+                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_buffer_hook_allreduce_return_future(self):
+            self._test_ddp_buffer_hook_allreduce(return_futures=True)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_buffer_hook_allreduce(self):
+            self._test_ddp_buffer_hook_allreduce(return_futures=False)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer_via_hook(self):
+            # test that _distributed_broadcast_coalesced via registered hook is
+            # equivalent to DDP's default broadcast coalesced.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            def buffer_comm_hook(ddp, named_buffers):
+                # named_buffers is a Dict[str, Tensor] representing a mapping
+                # from buffer name to buffer.
+                buffers = [buffer for (_, buffer) in named_buffers.items()]
+                ddp._default_broadcast_coalesced(buffers)
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp._register_buffer_comm_hook(model_ddp, buffer_comm_hook)
+            model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model),
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for _ in range(2):
+                loss_hook = model_ddp(inp).sum()
+                loss_no_hook = model_ddp_no_hook(inp).sum()
+                self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                loss_hook.backward()
+                loss_no_hook.backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_remove_autograd_hooks(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    self.error = True
+                    self.fc1 = nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp):
+                    if self.error:
+                        return self.fc1(SimulateError.apply(inp))
+                    else:
+                        return self.fc1(inp)
+
+            # Run with error to trigger backward pass that marks fc1 as being marked
+            # ready. If we don't remove autograd hooks before running below it would
+            # fail on the old autograd hook.
+            model = MyModel(self.rank)
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            model_ddp1 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+
+            with self.assertRaises(RuntimeError):
+                model_ddp1(input).sum().backward()
+
+            # Remove autograd hooks on old instance.
+            model_ddp1._remove_autograd_hooks()
+
+            # Try another DDP instance without error now.
+            model.error = False
+            model_ddp2 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp2(input).sum().backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @unittest.skip(
+            "Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751"
+        )
+        def test_ddp_has_finalized(self):
+            @dataclass
+            class MyClass:
+                obj: torch.Tensor
+
+            class MyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.rank = rank
+                    self.fc1 = nn.Linear(1024, 1024).cuda(rank)
+                    self.fc2 = nn.Linear(1024, 2 * 1024).cuda(rank)
+
+                def forward(self, inp):
+                    if self.rank == 0:
+                        return self.fc1(inp), MyClass(self.fc2(inp))
+                    else:
+                        return self.fc1(inp), self.fc2(inp)
+
+            model = MyModel(self.rank)
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=(1024 * 4 / 1024 / 1024),  # One bucket per parameter.
+            )
+
+            if self.rank == 0:
+                out1, _ = ddp(input)
+                out1.sum().backward()
+            else:
+                out1, out2 = ddp(input)
+                (out1.sum() + out2.sum()).backward()
+
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp._check_reducer_finalized()
+
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp(input)
+            else:
+                ddp._check_reducer_finalized()
+                ddp(input)
+
+        """
+        # The set of "test_ddp_update_process_group..." below failed after
+        # upgrading CI from 2 GPUs to 4 GPUs.
+        # Commented out for now.
+        # Test purpose needs better documentation.
+
+        def _run_ddp_update_process_group(self, new_pg):
+            def get_num_torch_recompiles():
+                guard_failures = torch._dynamo.utils.guard_failures
+                num_recompiles = [len(guard_failures[code]) for code in guard_failures]
+                return 0 if len(num_recompiles) == 0 else max(num_recompiles)
+
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    # 4MB for multiple buckets.
+                    self.fc1 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc2 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc3 = torch.nn.Linear(1024, 1024).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc3(self.fc2(self.fc1(inp)))
+
+
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+            model = torch.compile(ddp)
+
+            def run_iteration():
+                # Run regular iteration.
+                out = model(input, error=False)
+                out.sum().backward()
+                torch.cuda.synchronize()
+
+                # Run with error.
+                with self.assertRaises(RuntimeError):
+                    out = model(input, error=True)
+                    out.sum().backward()
+                torch.cuda.synchronize()
+
+            run_iteration()
+            assert 0 == get_num_torch_recompiles()
+
+            if new_pg:
+                # Now reduce world_size and run iteration.
+                group_size_2 = dist.new_group(ranks=[0, 1])
+                ddp._update_process_group(group_size_2)
+                if self.rank in [0, 1]:
+                    run_iteration()
+
+                # Increase the world size and run iteration.
+                group_size_3 = dist.new_group(ranks=[1, 2, 3])
+                ddp._update_process_group(group_size_3)
+                if self.rank in [1, 2, 3]:
+                    run_iteration()
+
+                # Back to default size.
+                ddp._update_process_group(_get_default_group())
+                run_iteration()
+            else:
+                # Create default pg of smaller size.
+                dist.destroy_process_group()
+
+                if self.rank in [1, 2, 3]:
+                    dist.init_process_group(
+                        init_method=self.init_method,
+                        backend=BACKEND,
+                        world_size=3,
+                        rank=self.rank - 1,
+                        timeout=timedelta(seconds=default_pg_timeout),
+                    )
+                    ddp._update_process_group(_get_default_group())
+                    run_iteration()
+                    dist.destroy_process_group()
+
+                # Need a barrier here to ensure ranks 1, 2 and 3 are done.
+                self._barrier(wait_for=4)
+
+                # Need to init pg again for "_barrier" to succeed.
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=BACKEND,
+                    world_size=4,
+                    rank=self.rank,
+                    timeout=timedelta(seconds=default_pg_timeout),
+                )
+
+            # Validate no more recompiles.
+            assert 0 == get_num_torch_recompiles()
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_new_group(self):
+            self._run_ddp_update_process_group(new_pg=True)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_default_group(self):
+            self._run_ddp_update_process_group(new_pg=False)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_grad_undefined(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    self.fc1 = torch.nn.Linear(10, 10).cuda(device)
+                    self.fc2 = torch.nn.Linear(10, 10).cuda(device)
+                    self.fc3 = torch.nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc2(self.fc1(inp))
+
+
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+
+            try:
+                ddp(input, True).sum().backward()
+            except RuntimeError:
+                ddp._update_process_group(_get_default_group())
+
+            # Reset grads.
+            for param in ddp.parameters():
+                param.grad = None
+
+            # Run ddp again.
+            ddp(input, False).sum().backward()
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_no_find_unused(self):
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(10, 10).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            ddp._update_process_group(_get_default_group())
+        """
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            class NetWithBuffers(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.a = nn.Linear(10, 10, bias=False)
+                    self.b = nn.Linear(10, 1, bias=False)
+                    self.register_buffer("buffer", torch.randn(1, 2))
+
+                def forward(self, x):
+                    return self.b(self.a(x))
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for _ in range(2):
+                if rank == 0:
+                    model_ddp.module.buffer = model_ddp.module.buffer + 1
+                loss = model_ddp(inp).sum()
+                loss.backward()
+                # Ensure all buffers are synchronized.
+                bufs = [
+                    torch.empty_like(model_ddp.module.buffer)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(bufs, model_ddp.module.buffer)
+                rank_0_buf = bufs[0]
+                for buf in bufs[1:]:
+                    self.assertEqual(rank_0_buf, buf)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_static_graph_multi_forward(self):
+            class Net(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+
+                def forward(self, x):
+                    return self.relu(self.lin(x))
+
+            torch.cuda.set_device(self.rank)
+            torch.manual_seed(42 << 1337 % (self.rank + 1))
+            model = Net().cuda(self.rank)
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], static_graph=True
+            )
+            inp = torch.ones(2, 10, device="cuda")
+            for _ in range(3):
+                model.zero_grad()
+                local_model.zero_grad()
+                a = model(inp)
+                b = model(inp)
+                loss = a.sum() + b.sum()
+                loss.backward()
+                # Grads should be equal to a local model that ran through inp
+                # `world_size` times and averaged grads
+                if self.rank == 0:
+                    inp_clone = inp.clone()
+                    iters = dist.get_world_size()
+                    for _ in range(iters):
+                        a = local_model(inp_clone)
+                        b = local_model(inp_clone)
+                        loss = a.sum() + b.sum()
+                        loss.backward()
+
+                    for p in local_model.parameters():
+                        p.grad.data = p.grad / iters
+
+                    for p_ddp, p_local in zip(
+                        model.parameters(), local_model.parameters(), strict=True
+                    ):
+                        self.assertTrue(
+                            torch.allclose(p_ddp.grad, p_local.grad),
+                            f"{p_ddp.grad} vs {p_local.grad}",
+                        )
+
+            dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_sync_bn_logged(self):
+            model = BatchNormNet()
+            rank = self.rank
+            # single gpu training setup
+            model_gpu = model.cuda(rank)
+            no_sync_bn = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model_gpu),
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = no_sync_bn._get_ddp_logging_data()
+            sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
+            self.assertFalse(sync_bn_logged)
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
+            model_DDP = torch.nn.parallel.DistributedDataParallel(
+                model_DDP,
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
+            self.assertTrue(sync_bn_logged)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_stateless_api_with_ddp(self):
+            class MockModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.l1 = torch.nn.Linear(1, 1)
+                    buffer = torch.ones(1)
+                    self.register_buffer("buffer", buffer)
+
+                def forward(self, x):
+                    return self.l1(x) + self.buffer
+
+            device = self.rank
+            module = MockModule().to(device)
+            module = torch.nn.parallel.DistributedDataParallel(
+                module, device_ids=[device]
+            )
+            x = torch.rand((1, 1)).to(device)
+            weight = torch.tensor([[1.0]], device=device, requires_grad=True)
+            bias = torch.tensor([0.0], device=device, requires_grad=True)
+            buffer = torch.tensor([0.0], device=device)
+            parameters = {
+                "module.l1.weight": weight,
+                "module.l1.bias": bias,
+                "module.buffer": buffer,
+            }
+            prev_weight = module.module.l1.weight.clone()
+            prev_buffer = module.module.buffer.clone()
+
+            res = torch.func.functional_call(module, parameters, x)
+            self.assertEqual(x, res)
+            # check that the weight remain unmodified
+            cur_weight = module.module.l1.weight
+            cur_buffer = module.module.buffer
+            self.assertEqual(cur_weight, prev_weight)
+            self.assertEqual(cur_buffer, prev_buffer)
+            # run a backward pass and check the gradients
+            res.backward()
+            self.assertIsNotNone(weight.grad)
+            self.assertIsNotNone(bias.grad)
+            # Gradient was not calculated for the module stated and buffers
+            self.assertIsNone(buffer.grad)
+            self.assertIsNone(module.module.l1.weight.grad)
+            self.assertIsNone(module.module.l1.bias.grad)
+            self.assertIsNone(module.module.buffer.grad)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_forward_backward_hook(self):
+            class DummyTestModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    torch.manual_seed(0)
+                    self.fc = nn.Linear(2, 2)
+
+                def forward(self, x):
+                    return self.fc(x)
+
+            def relu_hook(module, input):
+                return nn.functional.relu(input[0])
+
+            def gelu_hook(module, _input, output):
+                return nn.functional.gelu(output)
+
+            def celu_hook(module, _input, output):
+                return (nn.functional.celu(output[0]),)
+
+            local_model = DummyTestModel()
+            ddp_model = DummyTestModel()
+            local_model.fc.register_forward_pre_hook(relu_hook)
+            local_model.fc.register_forward_hook(gelu_hook)
+            ddp_model.fc.register_forward_pre_hook(relu_hook)
+            ddp_model.fc.register_forward_hook(gelu_hook)
+            local_model.fc.register_backward_hook(celu_hook)
+            ddp_model.fc.register_backward_hook(celu_hook)
+            ddp_model = DistributedDataParallel(
+                ddp_model.to(self.rank), device_ids=[self.rank]
+            )
+            input_data = torch.rand(5, 2)
+            output_local = local_model(input_data)
+            output_ddp = ddp_model(input_data.to(self.rank))
+            self.assertEqual(output_local, output_ddp)
+            output_local.sum().backward()
+            output_ddp.sum().backward()
+            ddp_grads = [p.grad for p in ddp_model.parameters()]
+            self.assertEqual(ddp_grads[0], local_model.fc.weight.grad)
+            self.assertEqual(ddp_grads[1], local_model.fc.bias.grad)
+
+        def _test_hook_pickling(self, hook, hook_state):
+            torch.manual_seed(0)
+            learning_rate = 0.01
+            chkpt_file = tempfile.gettempdir() + "/checkpoint.pt"
+            rank = self.rank
+
+            input = torch.randn(7, 1, device=rank)
+            target = torch.randn(7, 5, device=rank)
+            net = torch.nn.Linear(1, 5).to(rank)
+            ddp_model = DistributedDataParallel(copy.deepcopy(net), device_ids=[rank])
+            dummy_ddp_model = DistributedDataParallel(
+                copy.deepcopy(net), device_ids=[rank]
+            )
+            optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate)
+            ddp_model.register_comm_hook(hook_state, hook)
+            ddp_model.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                out = ddp_model(input)
+                loss = F.mse_loss(out, target)
+                loss.backward()
+                optimizer.step()
+
+            state = {
+                "state_dict": ddp_model.state_dict(),
+                "comm_hook": hook,
+                "comm_hook_state": hook_state,
+            }
+
+            if rank == 0:
+                with self.assertLogs("torch.distributed") as captured:
+                    torch.save(state, chkpt_file)
+
+                # Check that the logger has only one entry
+                self.assertEqual(len(captured.records), 1)
+                # Check that the logger has an expected entry
+                self.assertEqual(
+                    captured.records[0].getMessage(),
+                    "NOTE: Process group is not serializable and excluded from a saved state.",
+                )
+
+            dist.barrier()
+            map_location = {"cuda:0": f"cuda:{rank:d}"}
+            with self.assertLogs("torch.distributed") as captured:
+                checkpoint = torch.load(chkpt_file, map_location=map_location)
+
+            # Check that the logger has only one entry
+            self.assertEqual(len(captured.records), 1)
+            # Check that the logger has an expected entry
+            self.assertEqual(
+                captured.records[0].getMessage(),
+                "NOTE: Process group will be set to a default group (i.e. the world size).\
+                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded.",
+            )
+
+            dummy_ddp_model.load_state_dict(checkpoint["state_dict"])
+            dummy_hook = checkpoint["comm_hook"]
+            dummy_hook_state = checkpoint["comm_hook_state"]
+            dummy_optimizer = torch.optim.SGD(
+                dummy_ddp_model.parameters(), lr=learning_rate
+            )
+
+            # Check that loaded function is correct
+            self.assertEqual(dummy_hook.__qualname__, hook.__qualname__)
+
+            # Check that all slots' keys were restored correctly
+            self.assertEqual(hook_state.__slots__, dummy_hook_state.__slots__)
+
+            # Check that all slots' attributes are restored correctly
+            # Excluding ``process_group`` and ``rng``.
+            for entry in dummy_hook_state.__slots__:
+                if entry != "process_group" and entry != "rng":
+                    self.assertEqual(
+                        getattr(dummy_hook_state, entry), getattr(hook_state, entry)
+                    )
+
+            # Check that ``process_group`` was set to default
+            self.assertEqual(dummy_hook_state.process_group, _get_default_group())
+
+            # Check that a random state was restored properly:
+            # ``np.random.RandomState.get_state`` returns a tuple with entries:
+            # ``bit_generator`` - str,
+            # ``state.key`` - ndarray dtype[uint32],
+            # ``state.pos`` - int,
+            # ``has_gauss`` - int,
+            # ``gauss`` - float
+            #  (refer to https://github.com/numpy/numpy/blob/266aad7478bc7fbcc55eea7f942a0d373b838396/numpy/random/mtrand.pyi)
+            # To make sure random state was restored properly, all entries should equal the original
+            for entry1, entry2 in zip(
+                hook_state.rng.get_state(),
+                dummy_hook_state.rng.get_state(),
+                strict=True,
+            ):
+                np.testing.assert_array_equal(entry1, entry2)
+
+            dummy_ddp_model.register_comm_hook(dummy_hook_state, dummy_hook)
+            dummy_ddp_model.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                dummy_optimizer.zero_grad()
+                out_origin = ddp_model(input)
+                out_dummy = dummy_ddp_model(input)
+                loss_origin = F.mse_loss(out_origin, target)
+                loss_dummy = F.mse_loss(out_dummy, target)
+                loss_origin.backward()
+                loss_dummy.backward()
+                optimizer.step()
+                dummy_optimizer.step()
+
+            # Check that gradients after 10 epochs are the same
+            for orig_param, dummy_param in zip(
+                ddp_model.parameters(), dummy_ddp_model.parameters(), strict=True
+            ):
+                self.assertEqual(orig_param.grad, dummy_param.grad)
+
+            dist.barrier()
+            if rank == 0:
+                os.remove(chkpt_file)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        @skip_but_pass_in_sandcastle_if(True, "Skipped due to flakiness")
+        def test_ddp_hook_pickling_powerSGD(self):
+            hook = powerSGD.powerSGD_hook
+            powersgd_state = powerSGD.PowerSGDState(
+                process_group=None,
+                matrix_approximation_rank=1,
+                start_powerSGD_iter=4,
+            )
+            self._test_hook_pickling(hook, powersgd_state)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_device_mesh_initialization(self):
+            """
+            Test DDP with device_mesh initialization.
+            """
+            world_size = int(os.environ["WORLD_SIZE"])
+
+            from torch.distributed.device_mesh import init_device_mesh
+
+            device_mesh = init_device_mesh("cuda", (world_size,))
+
+            pg = _get_default_group()
+
+            torch.cuda.set_device(self.rank)
+            model = TwoLinLayerNet().cuda()
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model, device_mesh=device_mesh
+            )
+            self.assertEqual(ddp_model.device_mesh, device_mesh)
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Cannot specify both process_group and device_mesh arguments.",
+            ):
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, process_group=pg, device_mesh=device_mesh
+                )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Only 1D device mesh is supported,"
+            ):
+                device_mesh = init_device_mesh("cuda", (2, world_size // 2))
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, device_mesh=device_mesh
+                )
+
+        @skip_if_lt_x_gpu(2)
+        @require_world_size(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_compile_static_graph(self):
+            "Tests that DDP works with torch compile when static_graph=True"
+            model = torch.nn.Linear(10, 10).cuda(self.rank)
+            model_clone = copy.deepcopy(model)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            ddp_static = torch.nn.parallel.DistributedDataParallel(
+                model_clone, device_ids=[self.rank], static_graph=True
+            )
+            ddp = torch.compile(ddp)
+            ddp_static = torch.compile(ddp_static)
+            input = torch.rand(10, 10).cuda(self.rank)
+            # verify output and gradient parity
+            for _ in range(6):
+                out_ddp = ddp(input).sum()
+                out_ddp_static = ddp_static(input).sum()
+                self.assertEqual(out_ddp, out_ddp_static)
+                out_ddp.backward()
+                out_ddp_static.backward()
+                for p1, p2 in zip(
+                    ddp.parameters(), ddp_static.parameters(), strict=True
+                ):
+                    self.assertEqual(p1.grad, p2.grad)
+
+        @skip_if_lt_x_gpu(2)
+        @require_world_size(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_sink_noclone(self):
+            "Tests that we can configure DDP to avoid clone"
+
+            class OpPatcher(TorchDispatchMode):
+                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                    func_packet = func._overloadpacket
+                    if func_packet == torch.ops.aten.clone:
+                        raise RuntimeError("clone encountered!")
+                    kwargs = kwargs if kwargs else {}
+                    return func(*args, **kwargs)
+
+            class MyModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc = torch.nn.Linear(10, 10)
+
+                def forward(self, input):
+                    return self.fc(input)
+
+            model = MyModel().cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            ddp._set_ddp_sink_clone(False)
+            input = torch.rand(10, 10).cuda(self.rank)
+
+            with OpPatcher():
+                ddp(input).sum().backward()
+
+        def _test_skip_all_reduce_unused_parameters(
+            self,
+            find_unused_parameters=False,
+            static_graph=False,
+            skip_all_reduce_unused_params=False,
+        ):
+            class LargeNet(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(100, 5000, bias=False)
+                    # fc2 is unused
+                    self.fc2 = nn.Linear(100, 100, bias=False)
+
+                def forward(self, x):
+                    y = self.fc1(x)
+                    return y
+
+            torch.manual_seed(31415)
+            torch.cuda.set_device(self.rank)
+            model = LargeNet().cuda(self.rank)
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                find_unused_parameters=find_unused_parameters,
+                static_graph=static_graph,
+                bucket_cap_mb=1.5,
+                skip_all_reduce_unused_params=skip_all_reduce_unused_params,
+            )
+            random_input = torch.randn(20, 100, device=self.rank)
+            for _ in range(10):
+                out = ddp_model(random_input)
+                loss = out.sum()
+                loss.backward()
+            return ddp_model
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_skip_all_reduce_unused_parameters(self):
+            base_model = self._test_skip_all_reduce_unused_parameters(
+                find_unused_parameters=True, static_graph=False
+            )
+            test_model_1 = self._test_skip_all_reduce_unused_parameters(
+                find_unused_parameters=True,
+                static_graph=False,
+                skip_all_reduce_unused_params=True,
+            )
+
+            self.assertEqual(
+                base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2
+            )
+            self.assertEqual(
+                test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1
+            )
+
+            for i, j in zip(
+                base_model.parameters(), test_model_1.parameters(), strict=True
+            ):
+                self.assertEqual(i, j)
+
+
+instantiate_parametrized_tests(DistributedTest._DistTestBase)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9c2cc7ee52093b555d94e5f4426fcbb6721b47
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py
@@ -0,0 +1,32 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed as dist
+from torch._C._distributed_c10d import FakeProcessGroup
+
+
+class FakeStore(dist.Store):
+    """
+    A fake store is a fake Key-Value store simply for initialization usage
+    the of fake process group, one can either use FakeStore or HashStore.
+    """
+
+
+def _create_fake_pg(common_opts, backend_opts):
+    """
+    A fake process group (not related to FakeTensor) is a process group which
+    doesn't actually do any communication, it just hallucinates some
+    communication.  You can run a single rank with a fake process group
+    without needing multiple processes (simulates per-rank behavior)
+
+    NOTE: This is not a real process group, and it would produce wrong results
+    for every collective. It should be used as a convenient tool when playing
+    with distributed but don't care about the actual data.
+    """
+    return FakeProcessGroup._create_internal(
+        common_opts.group_rank, common_opts.group_size, backend_opts
+    )
+
+
+dist.Backend.register_backend(
+    "fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu", "xpu"]
+)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..79aff05b3421f37cf63501e5692f84723be73439
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -0,0 +1,611 @@
+# mypy: allow-untyped-defs
+
+import sys
+import threading
+import weakref
+from dataclasses import dataclass
+from functools import partial, reduce
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import (
+    _create_work_from_future,
+    AllgatherOptions,
+    AllreduceOptions,
+    AllToAllOptions,
+    BarrierOptions,
+    BroadcastOptions,
+    ReduceOp,
+    ReduceScatterOptions,
+    ScatterOptions,
+    Store,
+)
+from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
+from torch.futures import Future
+from torch.utils import _pytree as pytree
+
+
+"""
+TODO:
+Lots of missing collectives.
+Collectives validation.
+Make timeout robust by making collectives respect the test deadline.
+Make tests robust by making collectives interruptible.
+We need some synchronization around cleanup to ensure that timedout ranks don't cause spurious failures.
+
+"""
+
+
+def flatten_list(lst):
+    return pytree.tree_leaves(lst)
+
+
+def ret_work(ret):
+    fut = Future()
+    fut.set_result(ret)
+    return _create_work_from_future(fut)
+
+
+def binop_reduce(tensors, op):
+    res = op(torch.stack(tensors), dim=0)
+    if isinstance(res, torch.Tensor):
+        return res
+    # min/max return a namedtuple
+    return res.values
+
+
+def bitwise_reduce(tensors, op):
+    return reduce(op, tensors)
+
+
+_reduce_ops = {
+    ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
+    ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
+    ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
+    ReduceOp.MIN: partial(binop_reduce, op=torch.min),
+    ReduceOp.MAX: partial(binop_reduce, op=torch.max),
+    ReduceOp.BAND: partial(bitwise_reduce, op=torch.bitwise_and),
+    ReduceOp.BOR: partial(bitwise_reduce, op=torch.bitwise_or),
+    ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
+}
+
+
+# Note [Hide collectives mutation from autograd]
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Threaded PG is intended to closely simulate the behavior of regular process
+# groups.  However, our regular PG implementations perform a dispatch through
+# c10d, whereas Threaded PG does not for some reason (some superficial
+# but not very convincing reasons include that Threaded PG is implemented
+# in Python but you can't override Backend in Python, you can only override
+# ProcessGroup in Python), thereby bypassing the dispatch step.  Now we have
+# a problem: c10d's signatures are LIES, they mutate their (output) tensor
+# arguments but their annotations don't have mutations on them so we don't
+# actually update any view metadata if you do differentiation.  This
+# ordinarily "doesn't matter" because distributed collectives aren't
+# differentiable anyway, but it's possible to tickle this in testing if
+# someone tries to touch the grad_fn of a Tensor.  There a few ways to
+# fix this, but the easiest way was to use the .detach() trick to hide
+# the mutations from autograd.
+
+
+class AllToAll:
+    @torch.no_grad()
+    def work(self, data):
+        world_size = len(data)
+        for dest_rank in range(world_size):
+            output_tensor_list, _ = data[dest_rank]
+            for src_rank in range(world_size):
+                _, input_tensor_list = data[src_rank]
+                # See Note [Hide collectives mutation from autograd]
+                output_tensor_list[src_rank].detach().copy_(
+                    input_tensor_list[dest_rank]
+                )
+
+
+class AllToAllBase:
+    @torch.no_grad()
+    def work(self, data):
+        world_size = len(data)
+        for dest_rank in range(world_size):
+            output_buffer, _, output_split_sizes, _ = data[dest_rank]
+
+            output_indexes = self._size_cumsum(
+                output_buffer.size(0), output_split_sizes, world_size
+            )
+
+            for src_rank in range(world_size):
+                _, input_buffer, _, input_split_sizes = data[src_rank]
+                input_indexes = self._size_cumsum(
+                    input_buffer.size(0), input_split_sizes, world_size
+                )
+
+                # See Note [Hide collectives mutation from autograd]
+                output_buffer[
+                    output_indexes[src_rank] : output_indexes[src_rank + 1]
+                ].detach().copy_(
+                    input_buffer[
+                        input_indexes[dest_rank] : input_indexes[dest_rank + 1]
+                    ]
+                )
+
+    def _size_cumsum(
+        self,
+        buf_size: int,
+        sizes: Union[torch.Tensor, list[int], None],
+        world_size: int,
+    ) -> torch.Tensor:
+        if sizes is None or len(sizes) == 0:
+            sizes = torch.full((world_size,), buf_size // world_size, dtype=torch.int64)
+        if not isinstance(sizes, torch.Tensor):
+            sizes = torch.tensor(sizes, dtype=torch.int64)
+        assert sizes.dtype == torch.int64
+        sizes = torch.cumsum(
+            torch.cat(
+                (torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes),
+                dim=0,
+            ),
+            dim=0,
+        )
+        return sizes
+
+
+class AllReduce:
+    def __init__(self, op):
+        if op.op not in _reduce_ops:
+            raise NotImplementedError(
+                f"AllReduce op {op.op} not supported on multithreaded pg for now."
+            )
+        self.op = op.op
+
+    @torch.no_grad()
+    def work(self, data):
+        for i in range(len(data[0])):
+            # use rank0 as the device for sum
+            rank_0_device = data[0][i].device
+            # collect all data to the list and make them
+            # all on rank 0 device
+            tensors = [
+                data[src_rank][i].to(rank_0_device) for src_rank in range(len(data))
+            ]
+
+            # now mimic reduce across all ranks
+            res = _reduce_ops[self.op](tensors)
+
+            # copy all the reduced value to each rank
+            for src_rank in range(len(data)):
+                # See Note [Hide collectives mutation from autograd]
+                data[src_rank][i].detach().copy_(res.to(data[src_rank][i].device))
+
+
+class AllGather:
+    @torch.no_grad()
+    def work(self, data):
+        for src_rank in range(len(data)):
+            in_tensor_list = data[src_rank][1]
+            # Can't handle all_gather with multiple tensors
+            assert len(in_tensor_list) == 1
+            src_tensor = in_tensor_list[0]
+
+            for dest in data:
+                dest_tensor = dest[0][0][src_rank]
+                # See Note [Hide collectives mutation from autograd]
+                dest_tensor.detach().copy_(src_tensor)
+
+
+class Scatter:
+    def __init__(self, src):
+        self.src = src
+
+    @torch.no_grad()
+    def work(self, data):
+        src_in_tensor_list = data[self.src][1]
+        # Can't handle scatter with multiple input tensor list
+        assert len(src_in_tensor_list) == 1
+        src_in_tensors = src_in_tensor_list[0]
+
+        for rank, each_rank_data in enumerate(data):
+            out_tensor_list = each_rank_data[0]
+            # Can't handle scatter with multiple output tensor
+            assert len(out_tensor_list) == 1
+            dest_tensor = out_tensor_list[0]
+            # See Note [Hide collectives mutation from autograd]
+            dest_tensor.detach().copy_(src_in_tensors[rank])
+
+
+class Gather:
+    def __init__(self, dst):
+        self.dst = dst
+
+    @torch.no_grad()
+    def work(self, data):
+        # Can't handle gather with multiple tensor lists
+        assert len(data[self.dst][0]) == 1
+        out_tensor_list = data[self.dst][0][0]
+        for rank, each_rank_data in enumerate(data):
+            src_in_tensor_list = each_rank_data[1]
+            # Can't handle gather with multiple tensor lists
+            assert len(src_in_tensor_list) == 1
+            dest_tensor = out_tensor_list[rank]
+            # See Note [Hide collectives mutation from autograd]
+            dest_tensor.detach().copy_(src_in_tensor_list[0])
+
+
+class ReduceScatter:
+    def __init__(self, op):
+        if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
+            raise NotImplementedError(f"ReduceScatter does not support {op}")
+        self.op = op
+
+    @torch.no_grad()
+    def work(self, data):
+        start_reduction = [False for _ in range(len(data))]
+        for each_rank_data in data:
+            # Can't handle reduce_scatter with multiple scatter list
+            assert len(each_rank_data[1]) == 1
+            to_scatter = each_rank_data[1][0]
+            for i in range(len(to_scatter)):
+                dest_tensor_on_rank_i = data[i][0]
+                # Can't handle reduce_scatter with multiple output tensor
+                assert len(dest_tensor_on_rank_i) == 1
+                dst_tensor_device = dest_tensor_on_rank_i[0].device
+                if not start_reduction[i]:
+                    # See Note [Hide collectives mutation from autograd]
+                    dest_tensor_on_rank_i[0].detach().copy_(
+                        to_scatter[i].to(dst_tensor_device)
+                    )
+                    start_reduction[i] = True
+                else:
+                    # See Note [Hide collectives mutation from autograd]
+                    dest_tensor_on_rank_i[0].detach().add_(
+                        to_scatter[i].to(dst_tensor_device)
+                    )
+        if self.op == dist.ReduceOp.AVG:
+            num_ranks = len(data)
+            for each_rank_data in data:
+                # See Note [Hide collectives mutation from autograd]
+                each_rank_data[0][0].detach().div_(num_ranks)
+
+
+class Broadcast:
+    def __init__(self, src):
+        self.src = src
+
+    @torch.no_grad()
+    def work(self, data):
+        in_tensor_list = flatten_list(data[self.src])
+        for i in range(len(data)):
+            if i == self.src:
+                continue
+            out_tensor_list = flatten_list(data[i])
+            for j in range(len(in_tensor_list)):
+                # See Note [Hide collectives mutation from autograd]
+                out_tensor_list[j].detach().copy_(in_tensor_list[j])
+
+
+class Collective:
+    def __init__(self, world_size, collective, pg):
+        self._world_size = world_size
+        self._collective = collective
+
+        self._start_cond = threading.Condition()
+        self._done_cond = threading.Condition()
+
+        self._data = [None] * world_size
+        self._count = 0
+        self._done = False
+
+        self._pg = pg
+
+    def join(self, rank, data):
+        with self._start_cond:
+            self._data[rank] = data
+            self._count += 1
+
+            # notify rank 0
+            if self._count == self._world_size:
+                if rank > 0:
+                    self._start_cond.notify()
+
+            if rank == 0:
+                self._start_cond.wait_for(
+                    lambda: self._count == self._world_size
+                    or self._pg._terminate.is_set()
+                )
+                # SystemExit is not a subclass of Exception but BaseException
+                # and can be distinguished from normal exception raised from program errors
+                # so that we can hide it from the exception queue
+                if self._pg._terminate.is_set():
+                    sys.exit("Test termination event occurs.")
+
+        with self._done_cond:
+            # wait for rank 0 to finish
+            if rank > 0:
+                self._done_cond.wait_for(
+                    lambda: self._done or self._pg._terminate.is_set()
+                )
+                if self._pg._terminate.is_set():
+                    sys.exit("Test termination event occurs.")
+            else:
+                # copy data around
+                self._collective.work(self._data)
+                self._done = True
+                self._done_cond.notify_all()
+        return ret_work(data)
+
+
+class ProcessLocalGroup(dist.ProcessGroup):
+    _coll_lock = threading.Lock()
+    _cur_coll_on_pgs = {}
+
+    _terminate = threading.Event()
+
+    @classmethod
+    def _start_coll(cls, collective, pg):
+        with cls._coll_lock:
+            # pg_name is unique, we use that to record the mapping between pg and collective
+            if pg.pg_name not in cls._cur_coll_on_pgs:
+                cls._cur_coll_on_pgs[pg.pg_name] = Collective(
+                    pg.size(), collective, cls
+                )
+            return cls._cur_coll_on_pgs[pg.pg_name]
+
+    @classmethod
+    def _end_coll(cls, collective, pg):
+        # This is racily called by all ranks, so only one will work
+        with cls._coll_lock:
+            if (
+                pg.pg_name in cls._cur_coll_on_pgs
+                and cls._cur_coll_on_pgs[pg.pg_name] == collective
+            ):
+                cls._cur_coll_on_pgs.pop(pg.pg_name)
+
+    @classmethod
+    def exception_handle(cls, exc):
+        cls._terminate.set()
+        for coll in cls._cur_coll_on_pgs.values():
+            with coll._start_cond:
+                coll._start_cond.notify()
+            with coll._done_cond:
+                coll._done_cond.notify_all()
+
+    @classmethod
+    def reset(cls):
+        with cls._coll_lock:
+            cls._cur_coll_on_pgs = {}
+            cls._terminate.clear()
+
+    def alltoall_base(
+        self,
+        output_buffer: torch.Tensor,
+        input_buffer: torch.Tensor,
+        output_split_sizes: Optional[list[int]],
+        input_split_sizes: Optional[list[int]],
+        opts=AllToAllOptions(),
+    ) -> torch.Tensor:
+        coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
+        res = coll.join(
+            self._rank,
+            (output_buffer, input_buffer, output_split_sizes, input_split_sizes),
+        )
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()):
+        coll = ProcessLocalGroup._start_coll(AllToAll(), self)
+        res = coll.join(self._rank, (output_tensor_list, input_tensor_list))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def allreduce(self, tensor_list, opts=AllreduceOptions()):
+        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
+        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def barrier(self, opts=BarrierOptions()):
+        return self.allreduce(tensor_list=[torch.ones(1)])
+
+    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
+        coll = ProcessLocalGroup._start_coll(AllGather(), self)
+        res = coll.join(self._rank, (output_tensors, input_tensor))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
+        tensor_list = list(torch.chunk(output_tensor, self._world_size))
+        return self.allgather([tensor_list], [input_tensor], opts)
+
+    def broadcast(self, tensor_list, opts=BroadcastOptions()):
+        coll = ProcessLocalGroup._start_coll(Broadcast(opts.rootRank), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(Scatter(opts.rootRank), self)
+        res = coll.join(self._rank, (output_tensors, input_tensors))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def gather(self, output_tensors, input_tensors, opts=ScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(Gather(opts.rootRank), self)
+        res = coll.join(self._rank, (output_tensors, input_tensors))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(ReduceScatter(opts.reduceOp), self)
+        res = coll.join(self._rank, (output_tensor, scatter_list))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def _reduce_scatter_base(
+        self, output_tensor, input_tensor, opts=ReduceScatterOptions()
+    ):
+        tensor_list = list(torch.chunk(input_tensor, self._world_size))
+        return self.reduce_scatter([output_tensor], [tensor_list], opts)
+
+    def reduce_scatter_tensor_coalesced(
+        self, output_tensors, input_tensors, opts=ReduceScatterOptions()
+    ):
+        works = [
+            self._reduce_scatter_base(output_tensor, input_tensor, opts)
+            for output_tensor, input_tensor in zip(
+                output_tensors, input_tensors, strict=True
+            )
+        ]
+        for work in works[:-1]:
+            work.wait()
+        return works[-1]
+
+    def allgather_into_tensor_coalesced(
+        self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()
+    ):
+        res = None
+        for o_t, i_t in zip(output_tensor_list, input_tensor_list, strict=True):
+            res = self._allgather_base(o_t, i_t)
+        return res
+
+    def __init__(self, rank, world_size):
+        super().__init__(rank, world_size)
+        self._rank = rank
+        self._world_size = world_size
+        world = dist.distributed_c10d._world
+        if isinstance(world, ThreadLocalWorld):
+            world = world._get_world()
+        self._world = weakref.ref(world)
+        self._ctx = torch.autograd.set_multithreading_enabled(False)
+
+    def size(self):
+        return self._world_size
+
+    @property
+    def pg_name(self):
+        """
+        return the global registered name of the current pg in the world
+        """
+        return self._world().pg_names[self]
+
+    @property
+    def group_name(self):
+        return self.pg_name
+
+    def getBackendName(self):
+        return "threaded"
+
+    def __repr__(self):
+        return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}"
+
+
+def _create_threaded_pg(prefix_store, rank, world_size, timeout):
+    pg = ProcessLocalGroup(rank, world_size)
+    # https://github.com/pytorch/pytorch/pull/103033 changed store based barrier to optional
+    # When device mesh involves sub groups while store based barrier is not enabled in c10d,
+    # even though threaded pg actual collectives are assumed to be single threaded,
+    # different threads may be initializing different groups,
+    # leading to race conditions.
+    # For example, if we have a mesh of [[0, 1], [2, 3]], the sub groups
+    # (dim 0 and 1) would be initialized in different threads independently.
+    # In this case we can no longer rely on class or global variables
+    # but have to rely on store based barrier to make sure each group
+    # is ready separately before we can invoke collectives in any of the groups.
+
+    # the prefix store is already per group so we pass an empty name here
+    _store_based_barrier(rank, prefix_store, "", world_size, timeout)
+    return pg
+
+
+dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"])
+
+
+@dataclass
+class WorldData:
+    default_pg: dist.ProcessGroup
+    pg_map: dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
+    pg_names: dict[dist.ProcessGroup, str]
+    pg_group_ranks: dict[dist.ProcessGroup, dict[int, int]]
+    pg_backend_config: dict[dist.ProcessGroup, str]
+    group_count: int
+    tags_to_pg: dict[str, list[dist.ProcessGroup]]
+    pg_to_tag: dict[dist.ProcessGroup, str]
+    pg_coalesce_state: dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]
+
+
+class ThreadLocalWorld:
+    _world = threading.local()
+
+    def _get_world(self) -> WorldData:
+        if not hasattr(ThreadLocalWorld._world, "world"):
+            ThreadLocalWorld._world.world = WorldData(
+                None, {}, {}, {}, {}, 0, {}, {}, {}
+            )
+        return ThreadLocalWorld._world.world
+
+    @property
+    def default_pg(self):
+        return self._get_world().default_pg
+
+    @default_pg.setter
+    def default_pg(self, value):
+        self._get_world().default_pg = value
+
+    @property
+    def pg_map(self):
+        return self._get_world().pg_map
+
+    @property
+    def pg_names(self):
+        return self._get_world().pg_names
+
+    @property
+    def pg_group_ranks(self):
+        return self._get_world().pg_group_ranks
+
+    @property
+    def pg_backend_config(self):
+        return self._get_world().pg_backend_config
+
+    @property
+    def group_count(self) -> int:
+        return self._get_world().group_count
+
+    @group_count.setter
+    def group_count(self, value):
+        self._get_world().group_count = value
+
+    @property
+    def tags_to_pg(self):
+        return self._get_world().tags_to_pg
+
+    @property
+    def pg_to_tag(self):
+        return self._get_world().pg_to_tag
+
+    @property
+    def pg_coalesce_state(self) -> dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]:
+        return self._get_world().pg_coalesce_state
+
+
+_old_pg_world = None
+_ctx_manager = None
+
+
+def _install_threaded_pg():
+    global _old_pg_world
+    global _ctx_manager
+    _old_pg_world = dist.distributed_c10d._world
+    dist.distributed_c10d._world = ThreadLocalWorld()
+    _ctx_manager = torch.autograd.set_multithreading_enabled(False)
+
+    return dist.distributed_c10d._world
+
+
+def _uninstall_threaded_pg():
+    dist.distributed_c10d._world = _old_pg_world
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c0e25bf1f954150b69b668d54c288791b4a6a1b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7bc0c513bdf8a5cbfa3467536819231a77d6c61
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c384d271a677f59e7a389a5f88eb16755c7ee98
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..af136fb8722d17d70767718a0cd327f71d730fda
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -0,0 +1,754 @@
+# mypy: allow-untyped-defs
+
+import enum
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.testing._internal.dist_utils as dist_utils
+from torch import nn, Tensor
+from torch._jit_internal import Future
+from torch.distributed.nn import RemoteModule
+from torch.distributed.nn.api.remote_module import (
+    _REMOTE_MODULE_PICKLED_ATTRIBUTES,
+    _RemoteModule,
+)
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+_PARAM_VAL = torch.nn.Parameter(torch.ones(1))
+
+
+# RPC handler for querying the device on the destination worker.
+def remote_device(module_rref):
+    for param in module_rref.local_value().parameters():
+        return param.device
+
+
+# RPC handler for querying __dict__ on the destination worker.
+def remote_module_attributes(remote_module):
+    return remote_module.__dict__
+
+
+# RPC handler for running forward on the destination worker.
+def remote_forward(remote_module, args):
+    return remote_module.forward(*args)
+
+
+# RPC handler for running forward_async on the destination worker.
+def remote_forward_async(remote_module, args):
+    # Since future cannot be pickled and sent over the RPC layer,
+    # have to wait and behave just like ``forward_sync``.
+    return remote_module.forward_async(*args).wait()
+
+
+# RPC handler for getting training mode on the destination worker.
+def get_remote_training_arg(module_rref):
+    return module_rref.local_value().training
+
+
+class ModuleCreationMode(enum.Enum):
+    MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
+    MODULE_CTOR = "module_ctor"
+
+
+@torch.jit.interface
+class MyModuleInterface:
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+
+@torch.jit.interface
+class RemoteMyModuleInterface:
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+    def forward_async(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> Future[tuple[str, int, Tensor]]:
+        pass
+
+
+class MyModule(nn.Module):
+    def __init__(self, first_arg, first_kwarg=-1):
+        super().__init__()
+        self.param1 = _PARAM_VAL
+
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        return word, number, tensor
+
+
+class BadModule:
+    def __init__(self, first_arg, first_kwarg=-1):
+        pass
+
+
+def create_scripted_module(first_arg, first_kwarg=-1):
+    module = MyModule(first_arg, first_kwarg=first_kwarg)
+    scripted_module = torch.jit.script(module)
+    return scripted_module
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonRemoteModuleTest(RpcAgentTestFixture):
+    @property
+    def world_size(self):  # Override setting in RpcAgentTestFixture
+        return 2
+
+    @staticmethod
+    def _create_remote_module_iter(remote_device, modes=None):
+        if modes is None:
+            modes = ModuleCreationMode.__members__.values()
+
+        args = (1,)
+        kwargs = dict(first_kwarg=2)
+
+        if ModuleCreationMode.MODULE_CTOR in modes:
+            remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
+            yield remote_module
+
+        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
+            remote_module = _RemoteModule(
+                remote_device,
+                create_scripted_module,
+                args,
+                kwargs,
+                _module_interface_cls=MyModuleInterface,
+            )
+            scripted_remote_module = torch.jit.script(remote_module)
+            yield scripted_remote_module
+
+
+class RemoteModuleTest(CommonRemoteModuleTest):
+    @dist_utils.dist_init
+    def test_bad_module(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        remote_device = f"{dst_worker_name}/cpu"
+        args = (1,)
+        kwargs = dict(first_kwarg=2)
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,",
+        ):
+            RemoteModule(remote_device, BadModule, args, kwargs).forward()
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,",
+        ):
+            RemoteModule(remote_device, BadModule, args, kwargs).forward()
+
+    @dist_utils.dist_init
+    def test_forward_async(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2, "3")
+        for remote_module in self._create_remote_module_iter(dst_worker_name):
+            ret_fut = remote_module.forward_async(*args)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_forward_async_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            )
+        )
+
+        @torch.jit.script
+        def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
+            ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
+            ret = ret_fut.wait()
+            return ret
+
+        ret = run_forward_async(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+
+    @dist_utils.dist_init
+    def test_forward_sync(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2, "3")
+        for remote_module in self._create_remote_module_iter(dst_worker_name):
+            ret = remote_module.forward(*args)
+            self.assertEqual(ret, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_forward_sync_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            )
+        )
+
+        @torch.jit.script
+        def run_forward(scripted_remote_module: MyModuleInterface):
+            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
+            return ret
+
+        ret = run_forward(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+
+    @dist_utils.dist_init
+    def test_forward_with_kwargs(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2)
+        kwargs = dict(word="3")
+        # Only test Python nn.Module, because script module methods don't support taking kwargs.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            ret_fut = remote_module.forward_async(*args, **kwargs)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args + ("3",))))
+
+            ret = remote_module.forward(*args, **kwargs)
+            self.assertEqual(ret, tuple(reversed(args + ("3",))))
+
+    @dist_utils.dist_init
+    def test_remote_parameters(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # Only test Python nn.Module, because script module methods don't support ``remote_parameters``.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            param_rrefs = remote_module.remote_parameters()
+            self.assertEqual(len(param_rrefs), 1)
+            self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL))
+
+    @dist_utils.dist_init
+    def test_get_module_rref(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # Only test Python nn.Module, because script module methods don't support ``get_module_rref``.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            rref = remote_module.get_module_rref()
+            self.assertEqual(rref, remote_module.module_rref)
+            for param in rref.to_here().parameters():
+                self.assertTrue(torch.equal(param, _PARAM_VAL))
+
+    @dist_utils.dist_init
+    def test_train_eval(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            remote_module.train()
+            ret1 = rpc.rpc_sync(
+                dst_worker_name,
+                get_remote_training_arg,
+                args=(remote_module.get_module_rref(),),
+            )
+            self.assertEqual(ret1, True)
+
+            remote_module.eval()
+            ret2 = rpc.rpc_sync(
+                dst_worker_name,
+                get_remote_training_arg,
+                args=(remote_module.get_module_rref(),),
+            )
+            self.assertEqual(ret2, False)
+
+    @dist_utils.dist_init
+    def test_unsupported_methods(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``register_buffer`` not supported for RemoteModule"
+            ):
+                remote_module.register_buffer("buffer", torch.ones(5))
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_parameter`` not supported for RemoteModule",
+            ):
+                remote_module.register_parameter(
+                    "param", torch.nn.Parameter(torch.ones(1))
+                )
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``add_module`` not supported for RemoteModule"
+            ):
+                remote_module.add_module("empty", None)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``apply`` not supported for RemoteModule"
+            ):
+                fn = torch.rand((3, 3), requires_grad=False)
+                remote_module.apply(fn)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``cuda`` not supported for RemoteModule"
+            ):
+                remote_module.cuda()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``cpu`` not supported for RemoteModule"
+            ):
+                remote_module.cpu()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``type`` not supported for RemoteModule"
+            ):
+                remote_module.type(torch.FloatTensor)
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``float`` not supported for RemoteModule"
+            ):
+                remote_module.float()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``double`` not supported for RemoteModule"
+            ):
+                remote_module.double()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``bfloat16`` not supported for RemoteModule"
+            ):
+                remote_module.bfloat16()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``to`` not supported for RemoteModule"
+            ):
+                remote_module.to("cpu", dtype=torch.int32)
+
+            def hook(module, grad_input, grad_output):
+                pass
+
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_backward_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_backward_hook(hook)
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_forward_pre_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_forward_pre_hook(hook)
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_forward_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_forward_hook(hook)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``state_dict`` not supported for RemoteModule"
+            ):
+                remote_module.state_dict()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``load_state_dict`` not supported for RemoteModule"
+            ):
+                remote_module.load_state_dict({})
+
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.",
+            ):
+                remote_module.parameters()
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``named_parameters`` not supported for RemoteModule",
+            ):
+                remote_module.named_parameters()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``buffers`` not supported for RemoteModule"
+            ):
+                remote_module.buffers()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_buffers`` not supported for RemoteModule"
+            ):
+                remote_module.named_buffers()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``children`` not supported for RemoteModule"
+            ):
+                remote_module.children()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_children`` not supported for RemoteModule"
+            ):
+                remote_module.named_children()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``modules`` not supported for RemoteModule"
+            ):
+                remote_module.modules()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_modules`` not supported for RemoteModule"
+            ):
+                remote_module.named_modules()
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``requires_grad_`` not supported for RemoteModule"
+            ):
+                remote_module.requires_grad_()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``zero_grad`` not supported for RemoteModule"
+            ):
+                remote_module.zero_grad()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``share_memory`` not supported for RemoteModule"
+            ):
+                remote_module.share_memory()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``extra_repr`` not supported for RemoteModule"
+            ):
+                remote_module.extra_repr()
+
+    @dist_utils.dist_init
+    def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # If a new attribute is added to this RemoteModule after the initialization,
+        # and it will be sent over the wire by RPC,
+        # this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
+        # Note that adding a new attribute out of constructor should rarely happen.
+        # If a new attribute is added to RemoteModule constructor,
+        # there is a sanity check to enforce developers to add this attribute to either
+        # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            new_attr_name = "new_attr"
+            setattr(remote_module, new_attr_name, 1)
+
+            attrs = rpc.rpc_sync(
+                dst_worker_name, remote_module_attributes, (remote_module,)
+            )
+            self.assertNotIn(new_attr_name, attrs)
+
+    @dist_utils.dist_init
+    def test_remote_module_py_pickle_not_supported(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            with TemporaryFileName() as fname:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC",
+                ):
+                    torch.save(remote_module, fname)
+
+    @dist_utils.dist_init
+    def test_remote_module_py_pickle_not_supported_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+        ):
+            with (
+                TemporaryFileName() as fname,
+                self.assertRaisesRegex(
+                    torch.jit.Error, "can only be pickled when using RPC"
+                ),
+            ):
+                torch.save(remote_module, fname)
+
+
+class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest):
+    @property
+    def world_size(self):  # Override setting in CommonRemoteModuleTest
+        return 3
+
+    @dist_utils.dist_init
+    def test_send_remote_module_over_the_wire(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Unpickled attributes include both the inherent attributes of RemoteModule
+        # (not inherited from the superclass) and two installed methods.
+        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
+        expected_unpickled_attrs.append("forward_async")
+        expected_unpickled_attrs.append("forward")
+
+        # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            # Test querying some simple attributes from worker2.
+            attrs = rpc.rpc_sync(
+                dst_worker2_name, remote_module_attributes, (remote_module,)
+            )
+            self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs)
+            self.assertEqual(attrs["on"], "worker1")
+            self.assertEqual(attrs["device"], "cpu")
+            self.assertFalse(attrs["is_device_map_set"])
+            self.assertFalse(attrs["is_scriptable"])
+
+            # Test the installed methods on worker1's can be initiated by worker2 over RPC layer.
+            # NOTE: In practice a remote module should be directly stored on the worker that runs ``forward``` or ``forward_async``,
+            # not have another worker to initiate forward over the RPC layer.
+            args = (torch.ones(1), 2, "3")
+            ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args))
+            self.assertEqual(ret1, tuple(reversed(args)))
+            ret2 = rpc.rpc_sync(
+                dst_worker2_name, remote_forward_async, (remote_module, args)
+            )
+            self.assertEqual(ret2, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_send_remote_module_over_the_wire_script_not_supported(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Unpickled attributes include both the inherent attributes of RemoteModule
+        # (not inherited from the superclass) and two installed methods.
+        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
+        expected_unpickled_attrs.append("forward_async")
+        expected_unpickled_attrs.append("forward")
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Passing a script RemoteModule over RPC is not supported."
+        ):
+            # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
+            for remote_module in self._create_remote_module_iter(
+                dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            ):
+                # Test querying some simple attributes from worker2.
+                rpc.rpc_sync(
+                    dst_worker2_name, remote_module_attributes, (remote_module,)
+                )
+
+    @dist_utils.dist_init
+    def test_create_remote_module_from_module_rref(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            remote_module2 = rpc.rpc_sync(
+                dst_worker2_name,
+                RemoteModule.init_from_module_rref,
+                (dst_worker2_name, remote_module.get_module_rref()),
+            )
+
+            args = (torch.ones(1), 2, "3")
+            ret1 = rpc.rpc_sync(dst_worker1_name, remote_forward, (remote_module, args))
+            ret2 = rpc.rpc_sync(
+                dst_worker2_name, remote_forward, (remote_module2, args)
+            )
+            self.assertEqual(ret1, ret2)
+
+
+class CudaRemoteModuleTest(CommonRemoteModuleTest):
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_valid_device(self):
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = dist_utils.worker_name(dst_rank)
+
+        for remote_module in self._create_remote_module_iter(
+            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
+        # Test rank works as well.
+        for remote_module in self._create_remote_module_iter(
+            f"rank:{dst_rank}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_invalid_devices(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Expected one of .+ device type at start of device string",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/foo",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        if TEST_WITH_ROCM:
+            errorString = (
+                r"HIP error: invalid device ordinal\n"
+                r"HIP kernel errors might be asynchronously reported at some other API call, "
+                r"so the stacktrace below might be incorrect.\n"
+                r"For debugging consider passing AMD_SERIALIZE_KERNEL=3"
+            )
+        else:
+            errorString = r"CUDA error: invalid device ordinal"
+        with self.assertRaisesRegex(RuntimeError, errorString):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cuda:100",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cpu2",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cuda:0/cuda:1",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: /. The valid format is '/'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    "/",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: /cuda:0. The valid format is '/'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    "/cuda:0",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_input_moved_to_cuda_device(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate cuda device.
+        t1 = torch.ones(1)
+        args = (t1, 2)
+        t2 = t1 * 2
+        kwargs = dict(word=t2)
+
+        # Only test Python nn.Module, because script module methods don't support taking kwargs.
+        for remote_module in self._create_remote_module_iter(
+            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            ret_fut = remote_module.forward_async(*args, **kwargs)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args + (t2,))))
+            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+            self.assertEqual(ret[0].device.type, "cpu")
+            self.assertEqual(ret[2].device.type, "cpu")
+
+            ret = remote_module.forward(*args, **kwargs)
+            self.assertEqual(ret, tuple(reversed(args + (t2,))))
+            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+            self.assertEqual(ret[0].device.type, "cpu")
+            self.assertEqual(ret[2].device.type, "cpu")
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_input_moved_to_cuda_device_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                f"{dst_worker_name}/cuda:0",
+                modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE],
+            )
+        )
+
+        @torch.jit.script
+        def run_forward(scripted_remote_module: MyModuleInterface):
+            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
+            return ret
+
+        ret = run_forward(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+        # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+        self.assertEqual(ret[2].device.type, "cpu")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..637692c9f1f72f3ef58c4748f54d2929703271c5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c75ee75ad9b24bb2883d385a1074b768ed51040a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81400d88693281df71242f3820ab9d3e794d6c96
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9e17856adbcba1acd1ed625ec0fc118c1f43fd5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c23a719fbabe4c35dd78b4bb9991f9fd775fb5a0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cc8b1d26c9b76da45d5291a8b267e378cd4e813
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1abadd33309da7c933ea03ec300e67d05d343600
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -0,0 +1,2756 @@
+# mypy: allow-untyped-defs
+
+import random
+import sys
+import threading
+import time
+from datetime import timedelta
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.testing._internal.dist_utils
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.distributed.rpc import RRef
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import (
+    IS_MACOS,
+    skip_but_pass_in_sandcastle_if,
+)
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    initialize_pg,
+    wait_until_node_failure,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# Right now we test up to 3-layer nested rpc calls.
+# rpc_done[1] and ctx_ids[1] represent rpc is done in prev rank, and context id
+# sent from prev rank respectively.
+# rpc_done[2] and ctx_ids[2] represents for prev of prev rank.
+# rpc_done[3] and ctx_ids[3] represents for prev of prev of prev rank.
+# rpc_done[0] and ctx_ids[0] represents for current rank, but mostly not used.
+rpc_done = [False, False, False, False]
+ctx_ids = [-1, -1, -1, -1]
+
+known_context_ids = set()
+
+requires_grad_tensor = torch.ones(3, 3, requires_grad=True)
+
+
+# Send rpc done info and context_id to
+# dst_rank = (self.rank + rank_distance) % self.world_size
+# we don't need a lock here since the GIL is held while executing remote
+# python UDFs, so access is serialized across several workers.
+def _set_rpc_done(ctx_id, rank_distance):
+    global rpc_done
+    global ctx_ids
+    global known_context_ids
+    rpc_done[rank_distance] = True
+    ctx_ids[rank_distance] = ctx_id
+    known_context_ids.add(ctx_id)
+
+
+def _check_rpc_done(rank_distance):
+    while not rpc_done[rank_distance]:
+        time.sleep(0.1)
+
+
+def _torch_ones(sizes, requires_grad=False):
+    return torch.ones(sizes, requires_grad=requires_grad)
+
+
+# This method must be called on the rref owner, and verifies that the grad of
+# rref tensor equals to the given grad.
+def _compare_owner_value(context_id, rref, grad):
+    grads = dist_autograd.get_gradients(context_id)
+    x = grads[rref.local_value()]
+    if x.is_sparse:
+        assert grad.is_sparse
+        x = x.to_dense()
+        grad = grad.to_dense()
+    else:
+        assert not grad.is_sparse
+    return torch.equal(x, grad)
+
+
+def create_tensor():
+    return torch.ones((3, 3), requires_grad=True)
+
+
+def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
+    i = [[0, 1, 1], [2, 0, 2]]
+    v = [3.2, 4.1, 5.3]
+    tensor = torch.sparse_coo_tensor(
+        i, v, (3, 3), requires_grad=requires_grad, dtype=dtype
+    )
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
+
+
+@torch.jit.script
+def create_torchscript_tensor() -> torch.Tensor:
+    return torch.ones((3, 3)).requires_grad_()
+
+
+def my_py_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+def my_scalar_add(a, b):
+    return a + b
+
+
+def my_rref_add(rref_t1, t2):
+    ret = torch.add(rref_t1.local_value(), t2)
+    return ret
+
+
+@torch.jit.script
+def my_script_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+@torch.jit.script
+def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor:
+    t1 = ref_t1.to_here()
+    return torch.add(t1, t2)
+
+
+def my_nested_rref_add(dst, rref_t1, t2):
+    return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))
+
+
+def ret_requires_grad():
+    return requires_grad_tensor
+
+
+def my_py_nested_call(t1, t2, dst, world_size, hops):
+    next_dst = (dst + 1) % world_size
+    if hops > 0:
+        return rpc.rpc_sync(
+            worker_name(next_dst),
+            my_py_nested_call,
+            args=(t1, t2, next_dst, world_size, hops - 1),
+        )
+    else:
+        return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2))
+
+
+# after dist autograd context is cleaned up, it should be cleaned up on other
+# nodes. This helper allows timeout_seconds for those RPCs to be completed, and
+# ensures that all the contexts have been cleaned up in that timeframe.any
+def _all_contexts_cleaned_up(timeout_seconds=10):
+    global known_context_ids
+    start = time.time()
+    context_id_to_raised = set()
+    while (
+        time.time() - start < timeout_seconds
+        and context_id_to_raised != known_context_ids
+    ):
+        for context_id in known_context_ids:
+            try:
+                dist_autograd._retrieve_context(context_id)
+            except RuntimeError:
+                context_id_to_raised.add(context_id)
+    # all contexts have been cleaned up if trying to retrieve any context resulted in a RuntimeError.
+    success = context_id_to_raised == known_context_ids
+    return success
+
+
+# This function creates a dis autograd context, run rpc_sync on the given ps,
+# and then blocks until the ps has verified the grads are correctly accumulated.
+def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
+    with dist_autograd.context() as context_id:
+        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
+        # prevent deleting dist autograd context
+        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
+        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
+
+
+# This function is the same as _run_trainer, except rpc calls torchscript
+# function "my_script_ref_add" instead of python function "my_rref_add"
+def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
+    with dist_autograd.context() as context_id:
+        ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
+        # prevent deleting dist autograd context
+        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
+        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
+
+
+class SimulateBackwardError(Function):
+    _simulate_error = True
+
+    @staticmethod
+    def forward(ctx, input):
+        return input
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, input):
+        if SimulateBackwardError._simulate_error:
+            raise Exception("Simulate error on backward pass")  # noqa: TRY002
+        else:
+            return input
+
+
+class ExecMode(Enum):
+    LOCAL = 1  # Run the operation locally.
+    RPC_SYNC = 2  # Run the operation using rpc_sync
+    REMOTE = 3  # Run the operation using remote.
+    RPC_ASYNC = 4  # Run the operation using rpc_async
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonDistAutogradTest(RpcAgentTestFixture):
+    def _exec_func_with_dst(self, dst, exec_mode, method, *args):
+        if ExecMode.LOCAL == exec_mode:
+            if len(args) == 1 and isinstance(args[0], list):
+                return method(*args[0])
+            return method(*args)
+        elif ExecMode.RPC_SYNC == exec_mode:
+            return rpc.rpc_sync(worker_name(dst), method, args=(args))
+        elif ExecMode.REMOTE == exec_mode:
+            return rpc.remote(worker_name(dst), method, args=(args)).to_here()
+        elif ExecMode.RPC_ASYNC == exec_mode:
+            fut = rpc.rpc_async(worker_name(dst), method, args=(args))
+            return fut.wait()
+        else:
+            raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+    def _exec_func(self, exec_mode, method, *args):
+        return self._exec_func_with_dst(self._next_rank(), exec_mode, method, *args)
+
+    def _next_rank(self):
+        if hasattr(self, "dst_rank"):
+            self.dst_rank = (self.dst_rank + 1) % self.world_size
+            if self.dst_rank == self.rank:
+                return self._next_rank()
+        else:
+            self.dst_rank = (self.rank + 1) % self.world_size
+        return self.dst_rank
+
+    def _check_rpc_done(self, rank_distance):
+        _check_rpc_done(rank_distance)
+
+    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
+        if exec_mode == ExecMode.LOCAL:
+            torch.autograd.backward(tensors)
+            return [arg.grad for arg in args]
+        else:
+            self._verify_backwards_remote(tensors, context_id, local_grads, *args)
+
+    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
+        dist_autograd.backward(context_id, tensors)
+
+        # Verify grads were accumulated appropriately.
+        grads = dist_autograd.get_gradients(context_id)
+        nargs = len(args)
+        ngrads = 0
+        for i in range(nargs):
+            if local_grads[i] is not None:
+                self.assertIn(args[i], grads)
+                self.assertEqual(local_grads[i], grads[args[i]])
+                ngrads += 1
+            else:
+                self.assertNotIn(args[i], grads)
+
+        self.assertEqual(ngrads, len(grads))
+
+    def _test_graph(self, fn, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor()
+                t2 = build_sparse_tensor()
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(worker_name(dst_rank), fn, args=(t1, t2)).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            # Verify graph for current context id.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                next(iter(recv_functions.values())),
+                t1,
+                t2,
+                ret,
+            )
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # Verify graph for previous context id.
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+        # autograd context should be cleaned up by now.
+        with self.assertRaises(RuntimeError):
+            ctx = dist_autograd._retrieve_context(context_id)
+
+        # No autograd context available.
+        with self.assertRaises(RuntimeError):
+            ctx = dist_autograd._current_context()
+
+    # 3-layer nested calls
+    def _test_graph_for_py_nested_call(self, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(t1, t2, dst_rank, self.world_size, 1),
+                )
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(t1, t2, dst_rank, self.world_size, 1),
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            # Barrier to ensure all RPCs are done.
+            dist.barrier()
+
+            for rd in [1, 2, 3]:
+                rpc.rpc_sync(
+                    worker_name((self.rank + rd) % self.world_size),
+                    _set_rpc_done,
+                    args=(context_id, rd),
+                )
+
+            # Barrier to ensure all set_rpc_done have completed.
+            dist.barrier()
+
+            # For self.rank, it has 4 graphs to verify
+            # One is for current context id when this rank send first rpc call.
+            # Second one is for prev context id when this rank make 1st nested
+            # call.
+            # Third one is for prev prev context id when this rank make
+            # 2nd nested call.
+            # Last one is for prev prev prev context id when this rank
+            # execute the torch.add() operator.
+
+            # Verify first graph for current context id.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                next(iter(recv_functions.values())),
+                t1,
+                t2,
+                ret,
+            )
+
+            # Verify second graph for 1st nested call.
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            self._verify_graph_for_nested_rpc_call(ctx)
+
+            # Verify third graph for 2nd nested call.
+            ctx = dist_autograd._retrieve_context(ctx_ids[2])
+            self._verify_graph_for_nested_rpc_call(ctx)
+
+            # verify last graph for rpc call execution.
+            ctx = dist_autograd._retrieve_context(ctx_ids[3])
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+    # Rank0->Rank1->Rank0
+    def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(
+                        t1,
+                        t2,
+                        (self.rank - 1 + self.world_size) % self.world_size,
+                        self.world_size,
+                        0,
+                    ),
+                )
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(
+                        t1,
+                        t2,
+                        (self.rank - 1 + self.world_size) % self.world_size,
+                        self.world_size,
+                        0,
+                    ),
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(
+                worker_name((self.rank + 1) % self.world_size),
+                _set_rpc_done,
+                args=(context_id, 1),
+            )
+
+            # For self.rank, it has 2 graphs to verify.
+            # One is for current context id when this rank send first rpc
+            # call and execute the torch.add() operator.
+            # Another one is for prev context id when this rank make
+            # nested call.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(2, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(2, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                list(recv_functions.values())[1],
+                t1,
+                t2,
+                ret,
+            )
+            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])
+
+            # Verify two pairs of send and recv functions for nested
+            # call
+            self._check_rpc_done(1)
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            self._verify_graph_for_nested_rpc_call(ctx)
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=False)
+                t2 = build_sparse_tensor(requires_grad=False)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=False)
+                t2 = torch.zeros(3, 3, requires_grad=False)
+            if ExecMode.RPC_SYNC == exec_mode:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+            elif ExecMode.REMOTE == exec_mode:
+                rpc.remote(worker_name(dst_rank), torch.add, args=(t1, t2)).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            ctx = dist_autograd._current_context()
+            send_functions = ctx._send_functions()
+            self.assertEqual(len(send_functions), 0)
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(len(recv_functions), 0)
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # NB: RRef.to_here() always passes the autograd context to the
+            # the callee, as the caller does not know whether the return
+            # value would contain a requires_grad tensor or not.
+            #
+            # rpc/remote with udf (_set_rpc_done here) also always passes the
+            # autograd context to the callee due to the same reason.
+            self.assertNotEqual(-1, dist_autograd._retrieve_context(ctx_ids[1]))
+            dist.barrier()
+
+    def _test_rpc_complex_args(self, exec_mode, sparse):
+        with dist_autograd.context():
+            num_tensors = 10
+            tensors = []
+            for i in range(num_tensors):
+                if sparse:
+                    tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
+                else:
+                    tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
+                tensors.append(tensor)
+            dst_rank = self._next_rank()
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), torch.stack, args=(tensors,))
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank), torch.stack, args=(tensors,)
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            self.assertEqual(torch.stack(tensors), ret)
+
+            # Verify appropriate tensors have been attached the autograd graph.
+            next_funcs = next(
+                iter(dist_autograd._current_context()._send_functions().values())
+            ).next_functions
+            for i in range(len(next_funcs)):
+                self.assertEqual(
+                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
+                )
+                self.assertEqual(tensors[i], next_funcs[i][0].variable)
+
+            # Verify that the worker id has been recorded in the context
+            ctx = dist_autograd._current_context()
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(len(worker_ids), 1)
+            self.assertEqual(worker_ids, {dst_rank})
+
+    def context_cleanup_test_helper(self, rpc_args, func, nested=False):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # test that in dist autograd, in the case that tensors communicated over RPC do
+        # NOT require grad, we still cleanup the dist autograd contexts created
+        # on other nodes. This is because the autograd context is still
+        # communicated over RPC even if tensor arguments do not require grad, as
+        #  it is possible that the response could.
+        if nested:
+            dst_rank = (self.rank + 1) % self.world_size
+            nested_dst_rank = (dst_rank + 1) % self.world_size
+            dst_ranks = {dst_rank}
+        else:
+            dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+
+        with dist_autograd.context() as context_id:
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+                if nested:
+                    rpc.rpc_sync(
+                        worker_name(nested_dst_rank),
+                        _set_rpc_done,
+                        args=(context_id, 2),
+                    )
+        # the thread's context id should be cleaned up
+        with self.assertRaises(RuntimeError):
+            dist_autograd._retrieve_context(context_id)
+        # Ensure all peers have finished mutating the
+        # `known_context_ids` set.
+        dist.barrier()
+        # check that all contexts have been cleaned up.
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
+    def _backward_no_grad_on_tensor(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            self.assertIsNone(t1.grad)
+            self.assertIsNone(t2.grad)
+
+            # Now populate .grad with local autograd engine and
+            # verify dist autograd doesn't mess with it.
+            loss_local = torch.add(t1, t2)
+            if sparse:
+                loss_local = torch.sparse.sum(loss_local)
+            else:
+                loss_local = loss_local.sum()
+            loss_local.backward()
+            self.assertIsNotNone(t1.grad)
+            self.assertIsNotNone(t2.grad)
+
+            t1_grad_before = t1.grad
+            t2_grad_before = t2.grad
+            dist_autograd.backward(context_id, [loss])
+            self.assertEqual(t1_grad_before, t1.grad)
+            self.assertEqual(t2_grad_before, t2.grad)
+
+    # The current rank first creates a tensor on the rref_owner, and then passes
+    # the rref with another tensor to the callee to run either my_rref_add or
+    # my_nested_rref_add, depending on whether the callee is the rref owner.
+    # The grad of tensor lives on the current rank, and the grad of the rref
+    # tensor lives on the rref owner.
+    def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
+        local_ret = torch.add(t1, t2)
+        if sparse:
+            local_ret = torch.sparse.sum(local_ret)
+        else:
+            local_ret = local_ret.sum()
+        local_ret.backward()
+        with dist_autograd.context() as context_id:
+            if sparse:
+                rref_t1 = rpc.remote(
+                    rref_owner,
+                    build_sparse_tensor,
+                    args=(
+                        False,
+                        True,
+                    ),
+                )
+            else:
+                rref_t1 = rpc.remote(
+                    rref_owner,
+                    _torch_ones,
+                    args=((3, 3),),
+                    kwargs={"requires_grad": True},
+                )
+            if callee == rref_owner:
+                rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
+            else:
+                rref = rpc.remote(
+                    callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
+                )
+            ret = rref.to_here()
+            if sparse:
+                ret = torch.sparse.sum(ret)
+            else:
+                ret = ret.sum()
+            dist_autograd.backward(context_id, [ret])
+
+            # verify grads on caller
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertIn(t2, grads)
+            self.assertEqual(grads[t2], t2.grad)
+
+            # verify grads on rref owner
+            self.assertTrue(
+                rpc.rpc_sync(
+                    rref_owner,
+                    _compare_owner_value,
+                    args=(context_id, rref_t1, t1.grad),
+                )
+            )
+
+    # In this test, every rank will serve as a parameter server (ps) and a
+    # driver, and then kicks off trainers on the other three ranks. So, we have:
+    # ps = rank0 with trainers = rank1/2/3
+    # ps = rank2 with trainers = rank2/3/0
+    # ps = rank3 with trainers = rank3/0/1
+    # ps = rank4 with trainers = rank0/1/2
+    #
+    # These four test ps-trainer groups run on completely separate autograd
+    # graphs, but they share the same set of underlying RpcAgents.
+    def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
+        if sparse:
+            t1 = build_sparse_tensor(requires_grad=True)
+            t2 = build_sparse_tensor(requires_grad=True)
+        else:
+            t1 = torch.ones((3, 3), requires_grad=True)
+            t2 = torch.zeros((3, 3), requires_grad=True)
+
+        local_ret = torch.add(t1, t2)
+        if sparse:
+            torch.sparse.sum(local_ret).backward()
+        else:
+            local_ret.sum().backward()
+
+        # create rref on self
+        rref_t1 = rpc.remote(worker_name(self.rank), create_ref_fn, args=())
+
+        # kick off forward and backward pass on three other workers (trainers)
+        rank_diffs = [1, 2, 3]
+        futures = [
+            rpc.rpc_async(
+                worker_name((self.rank + rank_diff) % self.world_size),
+                trainer_fn,
+                args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
+            )
+            for rank_diff in rank_diffs
+        ]
+
+        # check if the trainers have done with their backward pass
+        for rank_diff in rank_diffs:
+            self._check_rpc_done(rank_diff)
+
+        # trainers are done and holding the context for verification
+        for rank_diff in rank_diffs:
+            # make sure grads are accumulated for the same tensors and values
+            # are all correct
+            ctx_id = ctx_ids[rank_diff]
+            grads = dist_autograd.get_gradients(ctx_id)
+            local_t1 = rref_t1.to_here()
+            self.assertIn(local_t1, grads)
+            self.assertEqual(grads[local_t1], t1.grad)
+
+        # unblock trainers
+        _set_rpc_done(None, 0)
+
+        # wait until all trainers are done
+        torch.futures.wait_all(futures)
+
+    def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                # Multiple RPCs between different nodes.
+                val = self._exec_func(exec_mode, torch.add, t1, t2)
+                val = self._exec_func(exec_mode, torch.mul, t3, val)
+                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
+                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
+                if sparse:
+                    val = self._exec_func(exec_mode, torch.mul, s1, s2)
+                    val = self._exec_func(exec_mode, torch.mul, val, val)
+                    loss = torch.sparse.sum(val)
+                else:
+                    val = self._exec_func(exec_mode, torch.bmm, s1, s2)
+                    val = self._exec_func(exec_mode, torch.matmul, val, val)
+                    loss = val.sum()
+
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
+                )
+                local_grads = ret if ret else local_grads
+
+    def _backward_different_dtypes(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                loss = self._exec_func(exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(loss)
+                else:
+                    loss = loss.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple_python_udf(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, my_py_add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple_script_call(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [
+            ExecMode.LOCAL,
+            ExecMode.RPC_SYNC,
+            ExecMode.RPC_ASYNC,
+            ExecMode.REMOTE,
+        ]:
+            with dist_autograd.context() as context_id:
+                forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(forward_ret)
+                else:
+                    loss = forward_ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            ret = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._test_nested_backward_accumulate_grads,
+                args=(t1, t2, self._next_rank()),
+            )
+            if sparse:
+                loss = torch.sparse.sum(ret)
+            else:
+                loss = ret.sum()
+            # Run backward twice.
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            dist_autograd.backward(context_id, [loss])
+
+    def _backwards_nested_python_udf(self, t1, t2, sparse):
+        t3 = t1 * t2
+        t4 = t1 + t2
+        res = t3 + t4
+        loss = t1 * t2 * t3 * t4 * res
+        if sparse:
+            loss = torch.sparse.sum(loss)
+        else:
+            loss = loss.sum()
+        torch.autograd.backward([loss])
+
+        # Now run distributed autograd.
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._nested_python_udf,
+                args=(t1, t2, self._next_rank()),
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    def _mixed_requires_grad(self, t1, t2, sparse):
+        for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(
+                    exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2
+                )
+                self.assertEqual(t1 * t2, ret)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                dist_autograd.backward(context_id, [loss])
+                self.assertTrue(t1.requires_grad)
+                self.assertFalse(t2.requires_grad)
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertIn(t1, grads)
+                self.assertNotIn(t2, grads)
+                self.assertEqual(t2, grads[t1])
+
+    def _multiple_backward(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            # Run backward in a loop multiple times.
+            for _ in range(1000):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+    # For current context, this rank sends t1 and t2 tensors to dst_rank,
+    # then get t3 = torch.add(t1, t2) result tensor.
+    # For the current context in this rank, it expects graph like this:
+    #  send function:
+    #              rpcSendBackward
+    #                  /          \
+    #  t1.AccumulateGrad         t2.AccumulateGrad
+    #
+    #  recv function:
+    #
+    #            |
+    #          t3.rpcRecvBackward
+    #
+    def _verify_graph_for_first_rpc_call(
+        self, send_function, recv_function, t1, t2, ret
+    ):
+        # Retrieve the next functions in the graph.
+        next_funcs = send_function.next_functions
+        self.assertEqual(2, len(next_funcs))
+
+        # We should now hit t1 and t2 in the autograd graph.
+        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
+        self.assertEqual(t1, next_funcs[0][0].variable)
+        self.assertEqual(0, next_funcs[0][1])
+        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
+        self.assertEqual(t2, next_funcs[1][0].variable)
+        self.assertEqual(0, next_funcs[1][1])
+
+        # Test recv functions.
+        self.assertEqual(ret.grad_fn, recv_function)
+
+    # Run the same code locally and with dist autograd and verify gradients
+    # are same.
+    def _backward_simple(self, dst, t1, t2, local_grads, sparse):
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func_with_dst(dst, exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    # For a context passed from previous nested chain calls, this rank
+    # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
+    # result tensor t3 back.
+    # For this context in this rank, it expects graph like this:
+    #  send and recv functions:
+    #       rpcSendBackward
+    #           |
+    #          t3.AddBackward0
+    #          /             \
+    # t1.recvRpcBackward    t2.recvRpcBackward
+    def _verify_graph_for_rpc_call_exec(self, send_function):
+        # Verify next function is AddBackward0
+        next_funcs = send_function.next_functions
+        self.assertEqual(1, len(next_funcs))
+        add_backward_fn = next_funcs[0][0]
+        self.assertEqual("AddBackward0", add_backward_fn.name())
+
+        # Verify the next two functions are the same recv backward function.
+        next_funcs = add_backward_fn.next_functions
+        self.assertEqual(2, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
+        )
+        self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
+    # For a context passed from previous nested chain calls, this rank
+    # receives two tensors t1 and t2, forwards t1 and t2 tensors using
+    # nested rpc call to next dst. In return route, receive result tensor t3
+    # from next dst and forwarding t3 back to previous calls.
+    # For this context in this rank, it expects graph like this:
+    #  send and recv functions for receiving and forwarding t1 and t2:
+    #       rpcSendBackward
+    #          /          \
+    # t1.recvRpcBackward    t2.recvRpcBackward
+    #  send and recv functions for receiving and forwarding t3:
+    #       rpcSendBackward
+    #             |
+    #           t3.recvRpcBackward
+    def _verify_graph_for_nested_rpc_call(self, ctx):
+        send_functions = ctx._send_functions()
+        self.assertEqual(2, len(send_functions))
+
+        # For send function when making nest rpc call,
+        # next functions of the send function are two recv functions
+        # for received two tensors from previous call
+        next_funcs = next(iter(send_functions.values())).next_functions
+        self.assertEqual(2, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
+        )
+        self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
+        # For send function when returning response to previous call
+        # next function of the send function is the recv function
+        # for received tensor result returned from nested call
+        next_funcs = list(send_functions.values())[1].next_functions
+        self.assertEqual(1, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+
+
+class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
+    # Sparse tests only work with TensorPipeAgent.
+    @dist_init
+    def test_graph_for_builtin_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_python_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_python_remote_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_py_nested_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_rpc_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_remote_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_context_cleanup_tensor_with_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_tensor_no_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_nested_rpc_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor_sparse(self):
+        self._backward_no_grad_on_tensor(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_sparse(self):
+        self._backward_simple(
+            self._next_rank(),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_self_sparse(self):
+        self._backward_simple(
+            self.rank,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_rref_multi_sparse(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                build_sparse_tensor(requires_grad=True),
+                build_sparse_tensor(requires_grad=True),
+                None,
+                True,
+            )
+
+    @dist_init
+    def test_backward_rref_sparse(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_rref_nested_sparse(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_trainer_ps_sparse(self):
+        self._test_trainer_ps(build_sparse_tensor, _run_trainer, True)
+
+    @dist_init
+    def test_backward_multiple_round_trips_sparse(self):
+        self._backward_multiple_round_trips(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_different_dtypes_sparse(self):
+        self._backward_different_dtypes(
+            build_sparse_tensor(requires_grad=True, dtype=torch.float32),
+            build_sparse_tensor(requires_grad=True, dtype=torch.float64),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf_sparse(self):
+        self._backward_simple_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_script_call_sparse(self):
+        self._backward_simple_script_call(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_nested_backward_accumulate_grads_sparse(self):
+        self._nested_backward_accumulate_grads(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backwards_nested_python_udf_sparse(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_mixed_requires_grad_sparse(self):
+        self._mixed_requires_grad(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            True,
+        )
+
+    @dist_init
+    def test_multiple_backward_sparse(self):
+        self._multiple_backward(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_embedding_bag_with_no_grad_tensors(self):
+        dst = self._next_rank()
+        remote_embedding = rpc.remote(
+            worker_name(dst),
+            torch.nn.EmbeddingBag,
+            args=(16, 16),
+            kwargs={"mode": "sum", "sparse": True},
+        )
+        local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True)
+
+        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
+        # requires_grad = True to record send/recv functions
+        per_sample_weights = torch.rand((8), requires_grad=True)
+        offsets = torch.LongTensor([0, 4])
+
+        local_res = local_embedding(input, offsets, per_sample_weights)
+
+        # Run backward twice.
+        torch.autograd.backward([local_res.sum()], retain_graph=True)
+        torch.autograd.backward([local_res.sum()])
+        local_grad = local_embedding.weight.grad
+
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync(
+                worker_name(dst),
+                DistAutogradTest._call_remote_embedding,
+                args=(remote_embedding, input, offsets, per_sample_weights),
+            )
+
+            # Run backward twice to test accumulation of sparse gradients.
+            dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
+            dist_autograd.backward(context_id, [res.sum()])
+
+            remote_grad = rpc.rpc_sync(
+                worker_name(dst),
+                DistAutogradTest._get_grad,
+                args=(remote_embedding, context_id),
+            )
+
+            self.assertEqual(local_grad, remote_grad)
+
+
+class DistAutogradTest(CommonDistAutogradTest):
+    @dist_init
+    def test_autograd_context(self):
+        # Verify max possible id.
+        max_auto_increment = 281474976710655
+        self.assertEqual(
+            max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id()
+        )
+
+        context_ids = []
+        for _ in range(200):
+            with dist_autograd.context() as context_id:
+                self.assertEqual(
+                    context_id,
+                    dist_autograd._retrieve_context(context_id)._context_id(),
+                )
+                # First 16 bits should be worker_id.
+                self.assertEqual(self.worker_id, context_id >> 48)
+                context_ids.append(context_id)
+
+        for context_id in context_ids:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                f"Could not find autograd context with id: {context_id}",
+            ):
+                dist_autograd._retrieve_context(context_id)
+
+    @dist_init
+    def test_nested_context(self):
+        with (
+            dist_autograd.context(),
+            self.assertRaisesRegex(
+                RuntimeError, "Already have an autograd context id for this thread"
+            ),
+            dist_autograd.context(),
+        ):
+            pass
+
+    @dist_init
+    def test_graph_for_builtin_call(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_python_call(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_python_remote_call(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)
+
+    def _test_grad_only_on_return_value(self, exec_mode):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(worker_name(dst_rank), ret_requires_grad).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            dist_autograd.backward(context_id, [ret.sum()])
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            grads = dist_autograd.get_gradients(ctx_ids[1])
+            self.assertEqual(1, len(grads))
+            self.assertIn(requires_grad_tensor, grads)
+            self.assertEqual(torch.ones_like(ret), grads[requires_grad_tensor])
+            # due to the above get_gradients call, ensure that dist autograd
+            # contexts aren't cleaned up until all workers exit context managers
+            dist.barrier()
+
+    @dist_init
+    def test_grad_only_on_return_value(self):
+        self._test_grad_only_on_return_value(ExecMode.RPC_SYNC)
+
+    @dist_init
+    def test_grad_only_on_return_value_remote(self):
+        self._test_grad_only_on_return_value(ExecMode.REMOTE)
+
+    @dist_init
+    def test_rpc_complex_args(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_remote_complex_args(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_context_cleanup_tensor_with_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_tensor_no_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_no_tensors(self):
+        self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)
+
+    @dist_init
+    def test_context_cleanup_nested_rpc(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
+    def test_worker_ids_recorded(self):
+        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+        with dist_autograd.context() as context_id:
+            # if no tensors require grad, we should still record worker_ids, as
+            # the autograd context ID is still passed to other workers.
+            t1 = torch.ones(3, 3, requires_grad=False)
+            t2 = torch.zeros(3, 3, requires_grad=False)
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+            # all worker_ids in dst_ranks should be recorded.
+            ctx = dist_autograd._current_context()
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(worker_ids, dst_ranks)
+
+            # worker_ids should be recorded when tensors do require grad
+            t1.requires_grad = True
+            t2.requires_grad = True
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+            # all worker_ids in dst_ranks should be recorded.
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(worker_ids, dst_ranks)
+
+    @dist_init
+    def test_dist_autograd_profiling(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(3, 3, requires_grad=True)
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            ).sum()
+            with torch.autograd.profiler.profile() as p:
+                dist_autograd.backward(context_id, [loss])
+
+        function_events = p.function_events
+
+        def get_event(partial_key):
+            return next(event for event in function_events if partial_key in event.name)
+
+        send_event = get_event("SendRpcBackward")
+        recv_event = get_event("RecvRpcBackward")
+        backward_event = get_event("torch::distributed::autograd::backward")
+        # There should be at least 1 send and recv_events each, corresponding to send/recv functions executed.
+        self.assertEqual(send_event.count, 1)
+        self.assertEqual(recv_event.count, 1)
+        # The CPU total for backward event should be great than send and recv, since
+        # applying those functions in the backwards pass is a subset of the entire backward pass.
+        self.assertGreater(backward_event.cpu_time_total, send_event.cpu_time_total)
+        self.assertGreater(backward_event.cpu_time_total, recv_event.cpu_time_total)
+
+    @dist_init
+    def test_error_in_context(self):
+        with dist_autograd.context():
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(6, 6, requires_grad=True)
+
+            with self.assertRaises(RuntimeError):
+                # This should throw an error since matrix sizes don't match.
+                rpc.rpc_sync(
+                    worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
+                )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor(self):
+        self._backward_no_grad_on_tensor(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple(self):
+        self._backward_simple(
+            self._next_rank(),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_self(self):
+        self._backward_simple(
+            self.rank,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_rref(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_rref_multi(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                torch.rand((3, 3), requires_grad=True),
+                torch.rand((3, 3), requires_grad=True),
+                None,
+                False,
+            )
+
+    @dist_init
+    def test_backward_rref_nested(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_trainer_ps(self):
+        self._test_trainer_ps(create_tensor, _run_trainer, False)
+
+    @dist_init
+    def test_trainer_ps_torchscript_functions(self):
+        # TODO, need more investigation
+        # there is rref leak when shutting down, suspect it is because
+        # ref as arg is passed to pybind boundary, and the ref is not garbage
+        # collected by python when calling shutdown()
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = True
+
+        self._test_trainer_ps(
+            create_torchscript_tensor, _run_trainer_torchscript, False
+        )
+
+    @dist_init
+    def test_backward_multiple_round_trips(self):
+        self._backward_multiple_round_trips(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_different_tensor_dims(self):
+        local_grads = None
+        t1 = torch.rand((4, 6), requires_grad=True)
+        t2 = torch.rand((6, 5))
+        t3 = torch.rand((5, 7), requires_grad=True)
+        t4 = torch.rand((7, 9))
+
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4))
+                loss = val.sum()
+
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_backward_unused_tensors(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        t3 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
+                val = self._exec_func(
+                    exec_mode,
+                    torch.matmul,
+                    torch.narrow(s, 0, 0, 1),
+                    torch.narrow(s, 0, 2, 1),
+                )
+
+                loss = val.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t3
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_backward_multiple_output_tensors(self):
+        local_grads = None
+        t = torch.rand((10, 2), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
+                t1 = tensor_list[0]
+                t2 = tensor_list[2]
+                t3 = tensor_list[4]
+
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3))
+
+                loss = val.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t
+                )
+                local_grads = ret if ret else local_grads
+
+    def _run_test_backward_unused_send_function_in_thread(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            # We don't use the result of an RPC function, as a result the
+            # backward pass would hang in the "FAST" mode.
+            rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+
+            val = torch.mul(t1, t2)
+
+            # Run backward, this would hang forever.
+            dist_autograd.backward(context_id, [val.sum()])
+
+    @dist_init
+    def test_backward_unused_send_function(self):
+        # Run the test in a thread which would never finish.
+        t = threading.Thread(
+            target=self._run_test_backward_unused_send_function_in_thread
+        )
+        t.daemon = True
+        t.start()
+        t.join(10)  # Wait for 10s.
+
+        # Verify thread is still alive (indicating backward hasn't completed yet).
+        self.assertTrue(t.is_alive())
+
+    @dist_init
+    def test_backward_autograd_engine_error(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            # Perform some ops before error simulation.
+            tmp = (t1 + t2) * (t1 + t2)
+            t3 = SimulateBackwardError.apply(tmp)
+
+            # Run multiple round trips across different nodes and verify the
+            # original node receives an error thrown on a node deep in the chain.
+            val = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t2, t3))
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.mul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.matmul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.div, args=(val, t2)
+            )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Error on Node [0-9]+: Simulate error on backward pass"
+            ):
+                # Run backwards, and validate we receive an error.
+                dist_autograd.backward(context_id, [val.sum()])
+
+    @dist_init(clean_shutdown=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_backward_node_failure(self):
+        rpc._set_rpc_timeout(5)  # 5 seconds
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+
+            # Wait for all RPCs to be done.
+            dist.barrier()
+
+            # Kill all odd rank nodes.
+            if self.rank % 2 == 0:
+                shutdown_error_regex = self.get_shutdown_error_regex()
+                # Wait for all other nodes to die.
+                for rank in range(self.world_size):
+                    if rank % 2 != 0:
+                        wait_until_node_failure(rank, shutdown_error_regex)
+
+                # Shutdown sequence is not very well defined and as a result
+                # we might see any error given by get_shutdown_error_regex()
+                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
+                    # Run backwards, and validate we receive an error since all
+                    # other nodes are dead.
+                    dist_autograd.backward(context_id, [res.sum()])
+            else:
+                # Exit all other nodes.
+                pass
+
+    @dist_init
+    def test_backward_without_context(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+
+        context_id = 100  # dummy context_id
+        with self.assertRaisesRegex(
+            RuntimeError,
+            f"Could not find autograd context with id: {context_id}",
+        ):
+            res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+            dist_autograd.backward(context_id, [res.sum()])
+
+    @dist_init
+    def test_backward_without_rpc(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_backward_invalid_args(self):
+        with dist_autograd.context() as context_id:
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(context_id, None)
+
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(None, None)
+
+            with self.assertRaisesRegex(
+                RuntimeError, "No tensors provided for gradient computation"
+            ):
+                dist_autograd.backward(context_id, [])
+
+            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
+                t = torch.rand(3, 3)
+                dist_autograd.backward(context_id, [t])
+
+            with self.assertRaisesRegex(
+                RuntimeError, "is not a scalar, all roots need to be scalar"
+            ):
+                t = torch.rand(3, 3, requires_grad=True)
+                dist_autograd.backward(context_id, [t])
+
+            with self.assertRaisesRegex(
+                RuntimeError, "does not have a valid gradient function"
+            ):
+                t = torch.rand(1, requires_grad=True)
+                dist_autograd.backward(context_id, [t])
+
+    @dist_init
+    def test_backward_multiple_roots(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+            with dist_autograd.context() as context_id:
+                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
+                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
+                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
+                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()
+
+                local_grads = self._verify_backwards(
+                    exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
+                )
+
+    @dist_init
+    def test_backward_different_dtypes(self):
+        self._backward_different_dtypes(
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf(self):
+        self._backward_simple_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_script_call(self):
+        self._backward_simple_script_call(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @staticmethod
+    def _complex_python_udf(t1, t2):
+        t3 = torch.nn.functional.linear(t1, t2)
+        t4 = torch.nn.functional.linear(t2, t3)
+        t5 = torch.nn.functional.linear(t3, t4)
+        return torch.linalg.multi_dot([t1, t2, t3, t4, t5])
+
+    @dist_init
+    def test_backward_complex_python_udf(self):
+        # Run the same code locally and with dist autograd and verify gradients
+        # are same.
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(
+                    exec_mode, DistAutogradTest._complex_python_udf, t1, t2
+                )
+                loss = ret.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    @staticmethod
+    def _python_udf_with_backward_error(t1, t2):
+        t3 = t1 + t2
+        t4 = SimulateBackwardError.apply(t3)
+        return torch.linalg.multi_dot([t1, t2, t3, t4])
+
+    @staticmethod
+    def _nested_rpc_call_backward_error(t1, t2, dst):
+        t1 = t1 * t2
+        t2 = t1 + t2
+        res = rpc.rpc_sync(
+            worker_name(dst),
+            DistAutogradTest._python_udf_with_backward_error,
+            args=(t1, t2),
+        )
+        return torch.linalg.multi_dot([t1, t2, res])
+
+    @dist_init
+    def test_backward_python_udf_error(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._nested_rpc_call_backward_error,
+                args=(t1, t2, self._next_rank()),
+            )
+            with self.assertRaisesRegex(
+                RuntimeError, "Simulate error on backward pass"
+            ):
+                dist_autograd.backward(context_id, [loss.sum()])
+
+    _backward_done = False
+
+    @dist_init(clean_shutdown=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_backward_node_failure_python_udf(self):
+        # Set a short timeout to quickly time out failed RPCs.
+        rpc._set_rpc_timeout(5)  # 5 seconds
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            dst = self._next_rank()
+            res = rpc.rpc_sync(
+                worker_name(dst),
+                my_py_nested_call,
+                args=(t1, t2, dst, self.world_size, 1),
+            )
+
+            dist.barrier()
+
+            # Kill rank 2 (last hop of nested rpc) and verify rank 0 receives an error.
+            if self.rank == 2:
+                return
+
+            store = dist.distributed_c10d._get_default_store()
+            if self.rank == 0:
+                # Wait for rank 2 to die.
+                shutdown_error_regex = self.get_shutdown_error_regex()
+                wait_until_node_failure(2, shutdown_error_regex)
+                # Shutdown sequence is not very well defined and as a result
+                # we might see any error given by get_shutdown_error_regex().
+                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
+                    # Run backwards, and validate we receive an error since rank 2 is dead.
+                    dist_autograd.backward(context_id, [res.sum()])
+
+                # Mark rank 0 is done in the store, since the RPC framework on
+                # some nodes might be broken at this point.
+                store.set("test_backward_node_failure_python_udf_rank0_done", "True")
+            else:
+                # Wait for backward to finish on rank 0.
+                store.wait(
+                    ["test_backward_node_failure_python_udf_rank0_done"],
+                    timedelta(seconds=10),
+                )
+
+    @staticmethod
+    def _nested_python_udf(t1, t2, dst):
+        t3 = t1 * t2
+        t4 = t1 + t2
+        res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
+        return t1 * t2 * t3 * t4 * res
+
+    @dist_init
+    def test_backwards_nested_python_udf(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    _test_clean_context_backward_context_id = None
+
+    class MyBackwardFunc(Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        @once_differentiable
+        def backward(ctx, input):
+            assert DistAutogradTest._test_clean_context_backward_context_id is not None
+
+            # Release the context to simulate error (use barrier before releasing
+            # context to ensure all nodes execute the backward function).
+            dist.barrier()
+            dist_autograd._release_context(
+                DistAutogradTest._test_clean_context_backward_context_id
+            )
+
+            # Verify all contexts are cleaned up.
+            assert _all_contexts_cleaned_up()
+
+            return input
+
+    @dist_init
+    def test_clean_context_during_backward(self):
+        """
+        This test simulates the situation where the 'backward' call might throw
+        an exception locally which would lead to the autograd context being
+        cleaned up if we're using the context manager. As a result, the autograd
+        context might be cleaned up while some threads are still using the
+        autograd context.
+
+        It is fine for the 'backward' call to throw an exception in this test,
+        but the process should not crash.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        context = dist_autograd._new_context()
+        context_id = context._context_id()
+        DistAutogradTest._test_clean_context_backward_context_id = context_id
+
+        # Send the context id to all nodes.
+        for i in range(self.world_size):
+            if i != self.rank:
+                rank_distance = (i - self.rank + self.world_size) % self.world_size
+                rpc.rpc_sync(
+                    worker_name(i),
+                    _set_rpc_done,
+                    args=(context_id, rank_distance),
+                )
+
+        dist.barrier()
+
+        # Verify all context ids have been received.
+        self.assertEqual(self.world_size - 1, len(known_context_ids))
+
+        t1 = torch.rand((3, 3), requires_grad=True)
+        for _ in range(100):
+            dst = self._next_rank()
+            t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))
+
+        # Call MyBackwardFunc as the first op of the backward pass to
+        # ensure we release the context early in the backward pass.
+        t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
+        self.assertEqual(100, len(context._send_functions()))
+
+        context_id = 100  # dummy context_id
+        with self.assertRaisesRegex(
+            RuntimeError,
+            f"Could not find autograd context with id: {context_id}",
+        ):
+            dist_autograd.backward(context_id, [t1.sum()])
+
+        # HACK: Killing workers since otherwise the autograd engine gets stuck on
+        # other nodes. The proper fix would be addressing:
+        # https://github.com/pytorch/pytorch/issues/27643, which would inform
+        # other nodes about the failure.
+        # The autograd engine gets stuck on other nodes since they're waiting to
+        # receive gradients from the node that received an error (and as a
+        # result it didn't execute the rest of the graph).
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+        sys.exit(0)
+
+    @classmethod
+    def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
+        embedding = embedding_rref.local_value()
+        return embedding(input, offsets, per_sample_weights)
+
+    @classmethod
+    def _get_grad(cls, embedding_rref, context_id):
+        embedding = embedding_rref.local_value()
+        grad_map = dist_autograd.get_gradients(context_id)
+        return grad_map[embedding.weight]
+
+    @classmethod
+    def _mixed_requires_grad_operaton(cls, t1, t2):
+        if t2.requires_grad:
+            return t1 - t2
+        else:
+            return t1 * t2
+
+    @dist_init
+    def test_mixed_requires_grad(self):
+        self._mixed_requires_grad(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=False),
+            False,
+        )
+
+    class TestDebugInfoFunc(Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        @once_differentiable
+        def backward(ctx, input):
+            debug_info = dist_autograd._get_debug_info()
+            assert debug_info is not None
+            backward_passes = int(debug_info["num_current_backward_passes"])
+
+            # Hard to validate exact numbers because of the distributed nature.
+            # We can't use a barrier() here since that would block the single
+            # CPU thread available for autograd and can cause deadlocks.
+            assert backward_passes >= 1 and backward_passes <= 4
+            return input
+
+    @dist_init
+    def test_debug_info(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            i = 0
+            res = {}
+            res[i] = t1
+            for rank in range(self.world_size):
+                if rank != self.rank:
+                    res[i + 1] = rpc.rpc_sync(
+                        worker_name(rank), torch.add, args=(res[i], t2)
+                    )
+                    i += 1
+
+            # Call custom function in middle of backward pass to ensure all
+            # nodes are still waiting on a backward().
+            res[i + 1] = DistAutogradTest.TestDebugInfoFunc.apply(res[i])
+            i += 1
+
+            for rank in range(self.world_size):
+                if rank != self.rank:
+                    res[i + 1] = rpc.rpc_sync(
+                        worker_name(rank), torch.add, args=(res[i], t2)
+                    )
+                    i += 1
+
+            dist_autograd.backward(context_id, [res[i].sum()])
+
+            debug_info = dist_autograd._get_debug_info()
+            num_autograd_context = int(debug_info["num_autograd_contexts"])
+            # Need at least one context and not more than 4.
+            self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4)
+
+        for rd in range(self.world_size - 1):
+            rpc.rpc_sync(
+                worker_name((self.rank + rd + 1) % self.world_size),
+                _set_rpc_done,
+                args=(context_id, rd + 1),
+            )
+
+        dist.barrier()
+
+        # Validate information
+        debug_info = dist_autograd._get_debug_info()
+        assert debug_info is not None
+        self.assertEqual(0, int(debug_info["num_current_backward_passes"]))
+        # only have `num_current_backward_passes` and `num_autograd contexts`
+        self.assertTrue(len(debug_info) == 2)
+
+        self.assertTrue(_all_contexts_cleaned_up())
+
+        # All contexts should be cleaned up.
+        debug_info = dist_autograd._get_debug_info()
+        self.assertEqual(0, int(debug_info["num_autograd_contexts"]))
+
+    @staticmethod
+    def _workload_thread():
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            t3 = rpc.rpc_sync("worker0", torch.add, args=(t1, t2))
+            t4 = rpc.rpc_sync("worker0", torch.mul, args=(t2, t3))
+            t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
+            t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))
+
+            dist_autograd.backward(context_id, [t6.sum()])
+
+    @dist_init
+    def test_async_dist_autograd(self):
+        """
+        This test ensures async processing for distributed autograd works
+        appropriately. This is achieved by spawning multiple threads and
+        hammering a single node with a lot of backward() calls.
+        """
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        if self.rank != 0:
+            # All other ranks schedule work on rank 0.
+            threads = []
+            for _ in range(20):
+                t = threading.Thread(target=DistAutogradTest._workload_thread)
+                t.start()
+                threads.append(t)
+
+            for thread in threads:
+                thread.join()
+
+        dist.barrier()
+
+    @dist_init
+    def test_backward_accumulate_grads(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            t3 = torch.matmul(t1, t2)
+            # Run backward twice.
+            torch.autograd.backward([t3.sum()], retain_graph=True)
+            torch.autograd.backward([t3.sum()])
+
+            t3 = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
+            )
+            # Run backward twice.
+            dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
+            dist_autograd.backward(context_id, [t3.sum()])
+
+            # Verify the gradients are same for local and remote execution.
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    @staticmethod
+    def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
+        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+
+    @dist_init
+    def test_nested_backward_accumulate_grads(self):
+        self._nested_backward_accumulate_grads(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_multiple_backward(self):
+        self._multiple_backward(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init(clean_shutdown=False)
+    def test_multiple_backward_with_errors(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                f"worker{self._next_rank()}",
+                DistAutogradTest._python_udf_with_backward_error,
+                args=(t1, t2),
+            ).sum()
+
+            try:
+                # Run backward in a loop multiple times.
+                for i in range(100):
+                    if i < 50:
+                        with self.assertRaisesRegex(
+                            RuntimeError, "Simulate error on backward pass"
+                        ):
+                            dist_autograd.backward(
+                                context_id, [loss], retain_graph=True
+                            )
+                    elif i > 50:
+                        # Recovered from error.
+                        dist_autograd.backward(context_id, [loss], retain_graph=True)
+                    else:
+                        dist.barrier()
+                        SimulateBackwardError._simulate_error = False
+                        dist.barrier()
+            finally:
+                # Sync before resetting flag.
+                dist.barrier()
+
+                # Reset the flag.
+                SimulateBackwardError._simulate_error = True
+
+    @dist_init
+    def test_backward_verify_hooks(self):
+        t1 = torch.ones((3, 3), requires_grad=True)
+        # Double the gradient.
+        t1.register_hook(lambda grad: grad * 2)
+        t2 = torch.ones((3, 3), requires_grad=True)
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                loss = ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_no_grad_copy(self):
+        """
+        Similar to test in test_autograd.py.
+        """
+
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp1, inp2):
+                return inp1 + inp2
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad.data_ptr()
+                return grad, grad
+
+        class MyFuncSingleGrad(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFuncSingleGrad.static_grad_ptr = grad.data_ptr()
+                return grad
+
+        class NonContGradFunc(Function):
+            @staticmethod
+            def forward(ctx, inp1):
+                ctx.size = inp1.size()
+                return torch.tensor([1.0])
+
+            @staticmethod
+            def backward(ctx, grad):
+                return torch.ones(1).expand(ctx.size)
+
+        a = torch.randn(5, 6, requires_grad=True)
+        b = torch.randn(5, 6, requires_grad=True)
+        # non-contiguous grad should be copied
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(
+                context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))]
+            )
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
+            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
+
+        # test case that should trigger no copy for a
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(context_id, [MyFuncSingleGrad.apply(a)[1][0]])
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFuncSingleGrad.static_grad_ptr
+            p_a = grads[a].data_ptr()
+            # Verify there was no clone.
+            self.assertTrue(p_a == p_g)
+
+        # Test case that should trigger copy for both of a,b. This is
+        # different in the distributed autograd case since we hold
+        # a reference to all grads in a vector until all accumulation is done.
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(context_id, [MyFunc.apply(a, b)[1][0]])
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a].data_ptr()
+            p_b = grads[b].data_ptr()
+            # check a,b uses different grad buffer
+            self.assertFalse(p_a == p_b)
+            # both should be copied.
+            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
+            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
+
+    @dist_init
+    def test_no_grad_copy_sparse(self):
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad._values().data_ptr()
+                return grad
+
+        class NonContGradFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp1, inp2):
+                return inp1 + inp2
+
+            @staticmethod
+            def backward(ctx, grad):
+                # Create a sparse tensor with non-contiguous indices and values
+                # and return as grad.
+                v = torch.rand(1, 3)
+                i = torch.ones(1, 1, dtype=torch.long)
+                nv = v.expand(8, 3)
+                ni = i.expand(1, 8)
+                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
+                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
+                return ngrad, ngrad
+
+        a = torch.randn(10, 3, requires_grad=True)
+        b = torch.randn(10, 3, requires_grad=True)
+        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        import torch.nn.functional as F
+
+        # test case that should trigger no copy for a.
+        with dist_autograd.context() as context_id:
+            emb_matrix = MyFunc.apply(a)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            # check a uses the same buffer
+            self.assertTrue(p_a == p_g)
+
+            # Run backwards multiple times.
+            for _ in range(10):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+        # non-contiguous indices and value, we should trigger a copy.
+        with dist_autograd.context() as context_id:
+            emb_matrix = NonContGradFunc.apply(a, b)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = NonContGradFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            p_b = grads[b]._values().data_ptr()
+            # check a,b uses different grad buffer
+            self.assertFalse(p_a == p_b)
+            # Verify we cloned both grads.
+            self.assertFalse(p_a == p_g)
+            self.assertFalse(p_b == p_g)
+
+            # Run backwards multiple times to verify accumulation.
+            for _ in range(10):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+    @dist_init
+    def test_grad_copy_sparse_indices_extra_ref(self):
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+            static_grad_indices_ref = None
+            static_grad_values_ref = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad._values().data_ptr()
+                # indices() and values() return views, so holding onto
+                # references of them would not increment refcount of indices
+                # and values inside the sparse tensor.
+                MyFunc.static_grad_indices_ref = grad._indices()
+                MyFunc.static_grad_values_ref = grad._values()
+                return grad
+
+        a = torch.randn(10, 3, requires_grad=True)
+        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        import torch.nn.functional as F
+
+        with dist_autograd.context() as context_id:
+            emb_matrix = MyFunc.apply(a)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            self.assertIsNotNone(MyFunc.static_grad_indices_ref)
+            self.assertIsNotNone(MyFunc.static_grad_values_ref)
+            # grad would be stolen, since static_grad_indices_ref and
+            # static_grad_values_ref are holding onto views and don't bump the
+            # refcount.
+            self.assertTrue(p_g == p_a)
+
+    @dist_init
+    def test_post_hooks(self):
+        self.hook_called_times = 0
+
+        def post_hook_add_one(output_grads, input_grads):
+            self.hook_called_times += 1
+            return output_grads
+
+        def post_hook_add_two(output_grads, input_grads):
+            self.hook_called_times += 2
+            return output_grads
+
+        t = torch.rand(10, 10, requires_grad=True)
+        a = t + t
+
+        # Register post hooks
+        accumulate_grad_0 = a.grad_fn.next_functions[0][0]
+        accumulate_grad_0.register_hook(post_hook_add_one)
+        accumulate_grad_0.register_hook(post_hook_add_two)
+
+        accumulate_grad_1 = a.grad_fn.next_functions[1][0]
+        accumulate_grad_1.register_hook(post_hook_add_two)
+
+        with dist_autograd.context() as context_id:
+            loss = a.sum()
+            dist_autograd.backward(context_id, [loss])
+            self.assertEqual(5, self.hook_called_times)
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(1, len(grads))
+            self.assertTrue(t in grads)
+
+    @staticmethod
+    def _slow_add(t1, t2):
+        time.sleep(1)
+        t3 = t1 + t2
+        t3.requires_grad = True
+        return t3
+
+    @dist_init
+    def test_thread_local_context_id(self):
+        t1 = torch.rand((3, 3))
+        t2 = torch.rand((3, 3))
+
+        t3 = t1 + t2
+        t3.requires_grad = True
+        t3.sum().backward()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))
+
+        with dist_autograd.context() as context_id:
+            loss = rref.to_here().sum()
+            # due to slow add, the continuation of this backward pass will be
+            # invoked by the previous rpc.remote thread which does not have a
+            # valid context_id. So, this can test whether we propagate
+            # thread_local states properly when jumping across threads on the
+            # server side.
+            dist_autograd.backward(context_id, [loss])
+            self.assertTrue(
+                rpc.rpc_sync(
+                    dst, _compare_owner_value, args=(context_id, rref, t3.grad)
+                )
+            )
+
+
+class CudaDistAutogradTest(CommonDistAutogradTest):
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_simple(self):
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        (t1 + t2).sum().backward()
+        with dist_autograd.context() as context_id:
+            t3 = t1 + t2
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_to_cpu_continuation(self):
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True)
+        # Run a few iterations.
+        for _ in range(3):
+            t1.grad = None
+            t2.grad = None
+            # Root is CPU
+            local_grads = None
+            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+                with dist_autograd.context() as context_id:
+                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
+                    t4 = t3.cuda(0) + t1
+                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
+                    t6 = t5.cuda(0) + t4
+                    t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5)
+                    # Autograd graph consists of CPU -> GPU -> CPU execution.
+                    ret = self._verify_backwards(
+                        exec_mode, [t7.sum()], context_id, local_grads, t1, t2
+                    )
+                    local_grads = ret if ret else local_grads
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_to_cpu_continuation_gpu_root(self):
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True)
+        # Run a few iterations.
+        for _ in range(3):
+            t1.grad = None
+            t2.grad = None
+            # Root is CPU
+            local_grads = None
+            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+                with dist_autograd.context() as context_id:
+                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
+                    t4 = t3.cuda(0) + t1
+                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
+                    t6 = t5.cuda(0) + t4
+                    # Autograd graph consists of CPU -> GPU -> CPU execution.
+                    ret = self._verify_backwards(
+                        exec_mode, [t6.sum()], context_id, local_grads, t1, t2
+                    )
+                    local_grads = ret if ret else local_grads
+
+
+class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
+    # Reusing a simplified helper function from DistAutogradTest to ensure
+    # autograd context is successfully cleaned up even when RPCs are failing.
+    def context_cleanup_test_helper(self, rpc_args, func):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # test that in dist autograd, in the case that tensors communicated over RPC do
+        # NOT require grad, we still cleanup the dist autograd contexts created
+        # on other nodes. This is because the autograd context is still
+        # communicated over RPC even if tensor arguments do not require grad, as
+        # it is possible that the response could.
+        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+
+        with dist_autograd.context() as context_id:
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+        # the thread's context id should be cleaned up
+        with self.assertRaises(RuntimeError):
+            dist_autograd._retrieve_context(context_id)
+        # Ensure all peers have finished mutating the
+        # `known_context_ids` set.
+        dist.barrier()
+        # check that all contexts have been cleaned up.
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
+    # no faulty_messages defined so this fails all retryable messages - see
+    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
+    @dist_init
+    def test_context_cleanup_tensor_with_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_verify_backend_options(self):
+        self.assertEqual(
+            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
+        )
+        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
+        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
+
+
+class WrapperModule(nn.Module):
+    def __init__(self, model, device):
+        super().__init__()
+        self.model = model.to(device)
+
+    def forward(self, *args):
+        return self.model(*args)
+
+    def gradients(self, ctx_id):
+        grads = dist_autograd.get_gradients(ctx_id)
+        return [grads[p] for p in self.model.parameters()]
+
+
+class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_backward_pass(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        t1 = torch.rand(10, device=self.rank, requires_grad=True)
+        t2 = torch.rand(10, device=self.rank, requires_grad=True)
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync(dst, torch.add, args=(t1, t2))
+            dist_autograd.backward(context_id, [res.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(torch.ones(10), grads[t1])
+            self.assertEqual(torch.ones(10), grads[t2])
+            self.assertEqual(t1.device, grads[t1].device)
+            self.assertEqual(t2.device, grads[t2].device)
+
+        rpc.shutdown()
+
+    class MyRemoteCompute(torch.nn.Module):
+        def forward(self, input):
+            input = input * 2.0
+            return input
+
+    class MyLocalCompute(torch.nn.Module):
+        def __init__(self, next_stage):
+            super().__init__()
+            self.next_stage = next_stage
+
+        def forward(self, input):
+            return self.next_stage.rpc_sync().forward(input)
+
+    @skip_if_lt_x_gpu(4)
+    def test_dist_autograd_sync_streams(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute)
+        local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute)
+        for _ in range(10):
+            input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
+            # Run local autograd
+            result = input * 2.0
+            r = random.random()
+            loss = result.sum() * r
+            loss.backward()
+
+            # Run distributed autograd
+            with dist_autograd.context() as context_id:
+                result = local_compute(input)
+                loss = result.sum() * r
+                dist_autograd.backward(context_id, [loss])
+
+                # Compare grads.
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertEqual(input.grad, grads[input])
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(4)
+    def test_gradients_synchronizations(self):
+        options = self.rpc_backend_options
+        for peer_rank in range(self.world_size):
+            options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 0:
+            # this is master
+            layers = [nn.Linear(2000, 2000) for _ in range(self.world_size - 1)]
+            local_layers = [l.to(0) for l in layers]
+            remote_layers = [
+                rpc.remote(
+                    worker_name(rank), WrapperModule, args=(layers[rank - 1], rank)
+                )
+                for rank in range(1, self.world_size)
+            ]
+
+            x = torch.randn(5000, 2000).to(0)
+            # local iteration
+            local_model = nn.Sequential(*local_layers)
+            local_model(x).sum().backward()
+
+            # remote iteration
+            with dist_autograd.context() as context_id:
+                for remote_layer in remote_layers:
+                    x = remote_layer.rpc_sync().forward(x)
+
+                dist_autograd.backward(context_id, [x.sum()])
+
+                futs = []
+                for remote_layer in remote_layers:
+                    futs.append(remote_layer.rpc_async().gradients(context_id))
+
+                for i in range(len(futs)):
+                    local_gradients = [p.grad for p in local_layers[i].parameters()]
+                    for g1, g2 in zip(futs[i].wait(), local_gradients, strict=True):
+                        self.assertEqual(g1, g2)
+
+        rpc.shutdown()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d335325f8364241dd14517da5c67c2a6e6a032b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
@@ -0,0 +1,281 @@
+# mypy: allow-untyped-defs
+
+
+import threading
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+from torch import optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.testing._internal.dist_utils import dist_init
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+class MyModule:
+    lock = threading.Lock()
+
+    def __init__(self, requires_grad=True):
+        # cannot directly use torch.manual_seed(0) as all threads share the same
+        # default generator. The race from multiple RPC threads could mess up
+        # the draw order from the default RNG instance, leading to
+        # non-deterministic behavior. Hence, create a dedicated RNG here.
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        self.w = torch.rand((3, 3), requires_grad=requires_grad, generator=g_cpu)
+
+    def forward(self, t1):
+        return torch.mm(self.w, t1)
+
+    def get_w(self):
+        return self.w
+
+
+class FailingOptimizer(optim.Optimizer):
+    def __init__(self, params):
+        super().__init__(params, {})
+
+    def step(self, closure=None):
+        raise ValueError("Error running optimizer.")
+
+
+class OptimizerFailingOnConstructor(optim.Optimizer):
+    def __init__(self, params):
+        super().__init__(params, {})
+        raise ValueError("Error creating optimizer.")
+
+    def step(self, closure=None):
+        raise NotImplementedError
+
+
+def _call_method(method, obj_rref, *args, **kwargs):
+    return method(obj_rref.local_value(), *args, **kwargs)
+
+
+def remote_method(method, obj_rref, *args, **kwargs):
+    """
+    Call rpc.remote on a method in a remote object.
+
+    Args:
+        method: the method (for example, Class.method)
+        obj_rref (RRef): remote reference to the object
+        args: positional arguments to pass to the method
+        kwargs: keyword arguments to pass to the method
+
+    Returns a RRef to the remote method call result.
+    """
+    return rpc.remote(
+        obj_rref.owner(),
+        _call_method,
+        args=[method, obj_rref] + list(args),
+        kwargs=kwargs,
+    )
+
+
+def rpc_async_method(method, obj_rref, *args, **kwargs):
+    """
+    Call rpc.rpc_async on a method in a remote object.
+
+    Args:
+        method: the method (for example, Class.method)
+        obj_rref (RRef): remote reference to the object
+        args: positional arguments to pass to the method
+        kwargs: keyword arguments to pass to the method
+
+    Returns a Future to the method call result.
+    """
+    return rpc.rpc_async(
+        obj_rref.owner(),
+        _call_method,
+        args=[method, obj_rref] + list(args),
+        kwargs=kwargs,
+    )
+
+
+class DistOptimizerTest(RpcAgentTestFixture):
+    @dist_init()
+    def test_dist_optim_exception(self):
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        dist_optim = DistributedOptimizer(
+            FailingOptimizer, [remote_param1, remote_param2]
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu = torch.Generator()
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1).sum()
+
+            dist_autograd.backward(context_id, [loss])
+            with self.assertRaisesRegex(Exception, "Error running optimizer"):
+                dist_optim.step(context_id)
+
+    @dist_init()
+    def test_dist_optim_exception_on_constructor(self):
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        with self.assertRaisesRegex(Exception, "Error creating optimizer."):
+            DistributedOptimizer(
+                OptimizerFailingOnConstructor, [remote_param1, remote_param2]
+            )
+
+    def _test_dist_optim_base(self, optim_cls, *args, **kwargs):
+        # local version
+        module1 = MyModule()
+        module2 = MyModule()
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim_cls(params, *args, **kwargs)
+
+        old_w1 = module1.w.detach().clone()
+        old_w2 = module2.w.detach().clone()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
+
+        dist_optim = DistributedOptimizer(
+            optim_cls, [remote_param1, remote_param2], *args, **kwargs
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
+            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()
+
+            # ensure optimizer changed weights
+            self.assertNotEqual(old_w1, new_w1)
+            self.assertNotEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
+
+    @dist_init()
+    def test_dist_optim(self):
+        self._test_dist_optim_base(optim.Adagrad, lr=0.05)
+        self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
+        self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
+        self._test_dist_optim_base(optim.SGD, lr=0.05)
+        self._test_dist_optim_base(
+            optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
+        )
+        self._test_dist_optim_base(optim.Adadelta, rho=0.95)
+        self._test_dist_optim_base(optim.RMSprop, lr=0.05)
+        self._test_dist_optim_base(optim.Adamax, lr=0.05)
+        self._test_dist_optim_base(optim.Rprop, lr=0.05)
+
+    def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
+        # local version
+        module1 = MyModule()
+        module2 = MyModule(requires_grad=False)
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim_cls(params, *args, **kwargs)
+
+        old_w1 = module1.w.detach().clone()
+        old_w2 = module2.w.detach().clone()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule, args=(False,))
+        remote_param1 = remote_module1.remote().get_w()
+        remote_param2 = remote_module2.remote().get_w()
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
+
+        dist_optim = DistributedOptimizer(
+            optim_cls, [remote_param1, remote_param2], *args, **kwargs
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = remote_module1.rpc_async().forward(t2)
+            output2 = remote_module2.rpc_async().forward(output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = remote_module1.rpc_async().get_w().wait()
+            new_w2 = remote_module2.rpc_async().get_w().wait()
+
+            # ensure optimizer changed weights for w1
+            self.assertNotEqual(old_w1, new_w1)
+
+            # ensure optimizer not changed weights for w2
+            self.assertEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
+
+    @dist_init()
+    def test_dist_optim_none_grads(self):
+        self._test_dist_optim_none_grads(optim.SGD, lr=0.05)
+        self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05)
+        self._test_dist_optim_none_grads(optim.Rprop, lr=0.05)
+        self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c884f37ac5ce85d5751e3412c7f2fe93989964d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..835d3198dc355574cd98a7a9123006407a65bf64
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c7be0cba31902b9286df1470845f1da3ebeb815
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad0b7fbe2207f8533da1eba8c23cda513f2bcf25
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
@@ -0,0 +1,140 @@
+# mypy: allow-untyped-defs
+
+# If you need to modify this file to make this test pass, please also apply same edits accordingly to
+# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py
+# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server
+
+import threading
+from datetime import datetime
+from time import perf_counter
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+from torch import optim
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+batch_size = 20
+in_features = 100
+out_features = 30
+num_batches = 4
+
+
+def timed_log(text):
+    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")
+
+
+class BatchUpdateParameterServer:
+    def __init__(self, batch_update_size):
+        self.model = nn.Linear(in_features, out_features)
+        self.lock = threading.Lock()
+        self.future_model = torch.futures.Future()
+        self.batch_update_size = batch_update_size
+        self.curr_update_size = 0
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
+        for p in self.model.parameters():
+            p.grad = torch.zeros_like(p)
+
+    def get_model(self):
+        return self.model
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def update_and_fetch_model(ps_rref, grads):
+        self = ps_rref.local_value()
+        for p, g in zip(self.model.parameters(), grads, strict=True):
+            if p.grad is None:
+                p.grad = g
+            else:
+                p.grad += g
+        with self.lock:
+            timed_log(
+                f"PS got {self.curr_update_size}/{self.batch_update_size} updates"
+            )
+            self.curr_update_size += 1
+            fut = self.future_model
+
+            if self.curr_update_size >= self.batch_update_size:
+                for p in self.model.parameters():
+                    p.grad /= self.batch_update_size
+                self.curr_update_size = 0
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                fut.set_result(self.model)
+                timed_log("PS updated model")
+                self.future_model = torch.futures.Future()
+
+        return fut
+
+
+class Trainer:
+    def __init__(self, ps_rref):
+        self.ps_rref = ps_rref
+        self.loss_fn = nn.L1Loss()
+
+    def get_next_batch(self):
+        for _ in range(num_batches):
+            inputs = torch.randn(batch_size, in_features)
+            labels = torch.zeros(batch_size, out_features)
+            yield inputs, labels
+
+    def train(self):
+        name = rpc.get_worker_info().name
+        m = self.ps_rref.rpc_sync().get_model()
+        for inputs, labels in self.get_next_batch():
+            timed_log(f"{name} processing one batch")
+            self.loss_fn(m(inputs), labels).backward()
+            timed_log(f"{name} reporting grads")
+            m = rpc.rpc_sync(
+                self.ps_rref.owner(),
+                BatchUpdateParameterServer.update_and_fetch_model,
+                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
+            )
+            timed_log(f"{name} got updated model")
+
+
+def run_trainer(ps_rref):
+    trainer = Trainer(ps_rref)
+    trainer.train()
+
+
+def run_ps(trainers):
+    timed_log("Start training")
+    start = perf_counter()
+    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
+    futs = [
+        rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers
+    ]
+
+    torch.futures.wait_all(futs)
+    stop = perf_counter()
+    timed_log("Finish training")
+    timed_log(f"Time spent training: {stop - start}s")
+
+
+class ParameterServerTest(RpcAgentTestFixture):
+    @dist_init(setup_rpc=False)
+    def test_batch_updating_parameter_server(self):
+        if self.rank != 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        else:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            run_ps([f"{worker_name(r)}" for r in range(1, self.world_size)])
+
+        rpc.shutdown()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..57008aed17dba34aacbc3b8a7a5b62c6dcbb5526
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py
@@ -0,0 +1,265 @@
+# mypy: allow-untyped-defs
+
+# If you need to modify this file to make this test pass, please also apply same edits accordingly to
+# https://github.com/pytorch/examples/blob/master/distributed/rpc/rl/main.py
+# and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html
+
+import numpy as np
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributed.rpc import remote, rpc_async, rpc_sync, RRef
+from torch.distributions import Categorical
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+TOTAL_EPISODE_STEP = 5000
+GAMMA = 0.1
+SEED = 543
+
+
+def _call_method(method, rref, *args, **kwargs):
+    r"""
+    a helper function to call a method on the given RRef
+    """
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_method(method, rref, *args, **kwargs):
+    r"""
+    a helper function to run method on the owner of rref and fetch back the
+    result using RPC
+    """
+    args = [method, rref] + list(args)
+    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
+
+
+class Policy(nn.Module):
+    r"""
+    Borrowing the ``Policy`` class from the Reinforcement Learning example.
+    Copying the code to make these two examples independent.
+    See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.affine1 = nn.Linear(4, 128)
+        self.dropout = nn.Dropout(p=0.6)
+        self.affine2 = nn.Linear(128, 2)
+
+        self.saved_log_probs = []
+        self.rewards = []
+
+    def forward(self, x):
+        x = self.affine1(x)
+        x = self.dropout(x)
+        x = F.relu(x)
+        action_scores = self.affine2(x)
+        return F.softmax(action_scores, dim=1)
+
+
+class DummyEnv:
+    r"""
+    A dummy environment that implements the required subset of the OpenAI gym
+    interface. It exists only to avoid a dependency on gym for running the
+    tests in this file. It is designed to run for a set max number of iterations,
+    returning random states and rewards at each step.
+    """
+
+    def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
+        self.state_dim = state_dim
+        self.num_iters = num_iters
+        self.iter = 0
+        self.reward_threshold = reward_threshold
+
+    def seed(self, manual_seed):
+        torch.manual_seed(manual_seed)
+
+    def reset(self):
+        self.iter = 0
+        return torch.randn(self.state_dim)
+
+    def step(self, action):
+        self.iter += 1
+        state = torch.randn(self.state_dim)
+        reward = torch.rand(1).item() * self.reward_threshold
+        done = self.iter >= self.num_iters
+        info = {}
+        return state, reward, done, info
+
+
+class Observer:
+    r"""
+    An observer has exclusive access to its own environment. Each observer
+    captures the state from its environment, and send the state to the agent to
+    select an action. Then, the observer applies the action to its environment
+    and reports the reward to the agent.
+    """
+
+    def __init__(self) -> None:
+        self.id = rpc.get_worker_info().id
+        self.env = DummyEnv()
+        self.env.seed(SEED)
+
+    def run_episode(self, agent_rref, n_steps):
+        r"""
+        Run one episode of n_steps.
+        Arguments:
+            agent_rref (RRef): an RRef referencing the agent object.
+            n_steps (int): number of steps in this episode
+        """
+        state, _ep_reward = self.env.reset(), 0
+        for _ in range(n_steps):
+            # send the state to the agent to get an action
+            action = _remote_method(Agent.select_action, agent_rref, self.id, state)
+
+            # apply the action to the environment, and get the reward
+            state, reward, done, _ = self.env.step(action)
+
+            # report the reward to the agent for training purpose
+            _remote_method(Agent.report_reward, agent_rref, self.id, reward)
+
+            if done:
+                break
+
+
+class Agent:
+    def __init__(self, world_size):
+        self.ob_rrefs = []
+        self.agent_rref = RRef(self)
+        self.rewards = {}
+        self.saved_log_probs = {}
+        self.policy = Policy()
+        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
+        self.eps = np.finfo(np.float32).eps.item()
+        self.running_reward = 0
+        self.reward_threshold = DummyEnv().reward_threshold
+        for ob_rank in range(1, world_size):
+            ob_info = rpc.get_worker_info(worker_name(ob_rank))
+            self.ob_rrefs.append(remote(ob_info, Observer))
+            self.rewards[ob_info.id] = []
+            self.saved_log_probs[ob_info.id] = []
+
+    def select_action(self, ob_id, state):
+        r"""
+        This function is mostly borrowed from the Reinforcement Learning example.
+        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+        The main difference is that instead of keeping all probs in one list,
+        the agent keeps probs in a dictionary, one key per observer.
+
+        NB: no need to enforce thread-safety here as GIL will serialize
+        executions.
+        """
+        probs = self.policy(state.unsqueeze(0))
+        m = Categorical(probs)
+        action = m.sample()
+        self.saved_log_probs[ob_id].append(m.log_prob(action))
+        return action.item()
+
+    def report_reward(self, ob_id, reward):
+        r"""
+        Observers call this function to report rewards.
+        """
+        self.rewards[ob_id].append(reward)
+
+    def run_episode(self, n_steps=0):
+        r"""
+        Run one episode. The agent will tell each observer to run n_steps.
+        """
+        # make async RPC to kick off an episode on all observers
+        futs = [
+            rpc_async(
+                ob_rref.owner(),
+                _call_method,
+                args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps),
+            )
+            for ob_rref in self.ob_rrefs
+        ]
+
+        # wait until all observers have finished this episode
+        for fut in futs:
+            fut.wait()
+
+    def finish_episode(self):
+        r"""
+        This function is mostly borrowed from the Reinforcement Learning example.
+        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+        The main difference is that it joins all probs and rewards from
+        different observers into one list, and uses the minimum observer rewards
+        as the reward of the current episode.
+        """
+
+        # joins probs and rewards from different observers into lists
+        R, probs, rewards = 0, [], []
+        for ob_id in self.rewards:
+            probs.extend(self.saved_log_probs[ob_id])
+            rewards.extend(self.rewards[ob_id])
+
+        # use the minimum observer reward to calculate the running reward
+        min_reward = min(sum(self.rewards[ob_id]) for ob_id in self.rewards)
+        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward
+
+        # clear saved probs and rewards
+        for ob_id in self.rewards:
+            self.rewards[ob_id] = []
+            self.saved_log_probs[ob_id] = []
+
+        policy_loss, returns = [], []
+        for r in rewards[::-1]:
+            R = r + GAMMA * R
+            returns.insert(0, R)
+        returns = torch.tensor(returns)
+        returns = (returns - returns.mean()) / (returns.std() + self.eps)
+        for log_prob, R in zip(probs, returns, strict=True):
+            policy_loss.append(-log_prob * R)
+        self.optimizer.zero_grad()
+        policy_loss = torch.cat(policy_loss).sum()
+        policy_loss.backward()
+        self.optimizer.step()
+        return min_reward
+
+
+def run_agent(agent, n_steps):
+    while True:
+        agent.run_episode(n_steps=n_steps)
+        agent.finish_episode()
+
+        if agent.running_reward > agent.reward_threshold:
+            print(f"Solved! Running reward is now {agent.running_reward}!")
+            break
+
+
+class ReinforcementLearningRpcTest(RpcAgentTestFixture):
+    @dist_init(setup_rpc=False)
+    def test_rl_rpc(self):
+        if self.rank == 0:
+            # Rank 0 is the agent.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            agent = Agent(self.world_size)
+            run_agent(agent, n_steps=int(TOTAL_EPISODE_STEP / (self.world_size - 1)))
+
+            # Ensure training was run. We don't really care about whether the task was learned,
+            # since the purpose of the test is to check the API calls.
+            self.assertGreater(agent.running_reward, 0.0)
+        else:
+            # Other ranks are observers that passively wait for instructions from the agent.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        rpc.shutdown()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..747155e3e1cbce8f8e8c14756fe3f98bf22a8987
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py
@@ -0,0 +1,337 @@
+# mypy: allow-untyped-defs
+
+import time
+
+import torch
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    wait_until_owners_and_forks_on_rank,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def my_sleep_func(seconds=1):
+    time.sleep(seconds)
+    return torch.mul(torch.tensor(1), torch.tensor(1))
+
+
+@torch.jit.script
+def my_script_func(tensor):
+    return torch.add(tensor, tensor)
+
+
+def add_rref_to_value(rref, value):
+    return rref.to_here() + value
+
+
+class FaultyAgentRpcTest(RpcAgentTestFixture):
+    # no faulty_messages defined so this fails all retryable messages - see
+    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
+    @dist_init(messages_to_delay={})
+    def test_check_failed_messages(self):
+        if self.rank == 0:
+            dst_worker_b = worker_name((self.rank + 1) % self.world_size)
+            dst_worker_c = worker_name((self.rank + 2) % self.world_size)
+
+            # Worker0 sends RPC to Worker1 and creates an RRef there
+            rref = rpc.remote(
+                dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
+            )
+            # Worker0 sends an RPC to Worker2 with the RRef as an arg
+            rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
+            # check if the output is as expected
+            self.assertEqual(
+                rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))
+            )
+        # explicitly delete all User RRefs
+        _delete_all_user_and_unforked_owner_rrefs()
+
+    @dist_init
+    def test_verify_backend_options(self):
+        self.assertEqual(
+            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
+        )
+        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
+        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
+        self.assertEqual(
+            self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC
+        )
+
+    @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
+    def test_custom_faulty_messages(self):
+        self.assertEqual(
+            {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"},
+            set(self.rpc_backend_options.messages_to_fail),
+        )
+
+    @dist_init(faulty_messages=[])
+    def test_no_faulty_messages(self):
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0)
+
+    @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_custom_messages_to_delay(self):
+        self.assertEqual(
+            self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}
+        )
+
+    def _test_remote_message_dropped_pickle(self, dst=None):
+        if self.rank != 0:
+            return
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # Since we fail python_remote_call messages synchronously, the future
+        # corresponding to this remote call will be marked with an error when
+        # this function returns.
+        rref = rpc.remote(dst_worker, my_sleep_func, args=(1,))
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Attempt to fork the RRef should raise an error indicating the rpc.remote timeout.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref._serialize()
+        # Test that using RRef as arg over RPC (which forks) results in the same
+        # error
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1))
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_remote_message_dropped_pickle(self):
+        self._test_remote_message_dropped_pickle()
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_remote_message_dropped_pickle_to_self(self):
+        self._test_remote_message_dropped_pickle(self.rank)
+
+    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
+        if self.rank != 0:
+            return
+
+        # test the case where rpc.remote() message creation is completely dropped.
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # Since we fail python_remote_call messages synchronously, the future
+        # corresponding to this remote call will be marked with an error when
+        # this function returns.
+        rref = rpc.remote(dst_worker, func, args=args)
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+        # Note: during shutdown, logs will indicate "Could not find OwnerRRef..."
+        # on the owning nodes, this is expected because the OwnerRRef was never
+        # successfully created. Therefore, delAllUsers will work as expected.
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_builtin_remote_message_dropped_timeout(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_dropped_timeout(func, args)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_builtin_remote_message_dropped_timeout_to_self(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_dropped_timeout(func, args, dst=0)
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_udf_remote_message_dropped_timeout(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_dropped_timeout(func, args)
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_udf_remote_message_dropped_timeout_to_self(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_dropped_timeout(func, args, dst=0)
+
+    def _test_remote_message_delay_timeout(self, func, args, dst=None):
+        if self.rank != 0:
+            return
+        # Test the case where remote message is eventually processed on the owner,
+        # but the future on the creator times out before the response comes back.
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # 10 ms timeout
+        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
+        # Future corresponding to the remote creation should time out.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref._get_future().wait()
+
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # to_here() should now pick up that rpc.remote() creation has failed.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+
+        # Test the case where rpc.remote() times out, but to_here() has already
+        # started blocking before.
+        # NOTE: we only test this when not sending to self, as to_here() calls
+        # calls localValue(), which does not send an RPC and thus does not have
+        # a timeout. This can be supported by allowing future.wait() to
+        # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
+        if dst_rank != self.rank:
+            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)
+
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                # to_here() should raise timeout error, since it does not know about the
+                # status of rpc.remote().
+                slow_rref.to_here(0.001)
+        # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete
+        # but this can be a noop since it may not exist on the owner yet. Later,
+        # the owner can process the RRef creation and wait for the delete message,
+        # thus leading to a timeout.
+        # Therefore, we wait until we get notification that pending owners have
+        # been confirmed before sending out RRefUserDeletes.
+        if dst_rank != self.rank:
+            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
+    def test_udf_remote_message_delay_timeout(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
+    def test_udf_remote_message_delay_timeout_to_self(self):
+        func = my_sleep_func
+        args = (1,)
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_builtin_delay_timeout(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_builtin_delay_timeout_to_self(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_script_delay_timeout(self):
+        func = my_script_func
+        args = (torch.tensor(1),)
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_script_delay_timeout_to_self(self):
+        func = my_script_func
+        args = (torch.tensor(1),)
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
+    def test_rref_to_here_timeout(self):
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref.to_here(0.01)
+
+        rref.to_here()
+
+    @dist_init(faulty_messages=[])
+    def test_rpc_builtin_timeout(self):
+        next_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(next_rank)
+        expected_error = self.get_timeout_error_regex()
+        # PYTHON_CALL message types which correspond to Python UDF over RPC
+        # by default get a delay (see faulty_rpc_agent_test_fixture)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(
+                dst_worker,
+                torch.add,
+                args=(torch.tensor(1), torch.tensor(1)),
+                timeout=1,
+            )
+
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that the currently set default timeout is large enough such
+        # that RPCs with delays still complete.
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        fut.wait()
+
+        # Ensure timeout if we set a new default and don't override
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if we specify timeout of 0
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
+        )
+        fut.wait()
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_rpc_script_timeout(self):
+        next_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(next_rank)
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
+
+        fut = rpc.rpc_async(
+            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that the currently set default timeout is large enough such
+        # that RPCs with delays still complete.
+        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
+        fut.wait()
+
+        # Ensure timeout if we set a new default and don't override
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if we specify timeout of 0
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(
+            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
+        )
+        fut.wait()
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff7d556d10621e7290c07ecb433b865d7133bb2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py
@@ -0,0 +1,64 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed.rpc as rpc
+import torch.distributed.rpc._testing  # noqa: F401
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# The following message types are currently retried in the RREF protocol and
+# distributed autograd. Thus only these messages should be tested with the
+# Faulty RPC Agent.
+retryable_message_types = [
+    "RREF_FORK_REQUEST",
+    "RREF_CHILD_ACCEPT",
+    "RREF_USER_DELETE",
+    "CLEANUP_AUTOGRAD_CONTEXT_REQ",
+]
+
+# The following messages incur the corresponding delay in seconds while being
+# processed in FaultyTensorPipeAgent's enqueueSend() function.
+default_messages_to_delay = {
+    "PYTHON_CALL": 1.5,  # Python UDF
+    "SCRIPT_CALL": 1.5,  # Script/Builtin
+}
+
+
+class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.messages_to_fail = retryable_message_types
+        self.messages_to_delay = default_messages_to_delay
+
+    @property
+    def rpc_backend(self):
+        return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"]
+
+    @property
+    def rpc_backend_options(self):
+        return rpc.backend_registry.construct_rpc_backend_options(
+            self.rpc_backend,
+            init_method=self.init_method,
+            num_worker_threads=8,
+            num_fail_sends=3,
+            messages_to_fail=self.messages_to_fail,
+            messages_to_delay=self.messages_to_delay,
+        )
+
+    def setup_fault_injection(self, faulty_messages, messages_to_delay):
+        if faulty_messages is not None:
+            self.messages_to_fail = faulty_messages
+        if messages_to_delay is not None:
+            self.messages_to_delay = messages_to_delay
+
+    def get_shutdown_error_regex(self):
+        error_regexes = [
+            "Exception in thread pool task",
+            "Connection reset by peer",
+            "Connection closed by peer",
+        ]
+        return "|".join([f"({error_str})" for error_str in error_regexes])
+
+    def get_timeout_error_regex(self):
+        return "RPC ran for more than"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bed8a9111efb090c86fce9f822c82e46c7187d0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be5a116cd79bf365a44974dc8eaaf9cfbce82049
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d28ce4f821f4783f04bdd40166a70142faf78fa
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17b256741e7f121714b0216981df04a2050463e4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde1fe2355c2968e1b351b288d20c674835b0ca2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py
@@ -0,0 +1,113 @@
+# mypy: allow-untyped-defs
+
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.distributed.rpc import rpc_async
+from torch.testing import FileCheck
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+@torch.jit.script
+def local_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+@torch.jit.script
+def remote_add(t1, t2, dst: str):  # noqa: E999
+    return rpc_async(dst, local_add, (t1, t2)).wait()
+
+
+@torch.jit.script
+def fork_add(t1, t2, dst: str):
+    fut = torch.jit._fork(remote_add, t1, t2, dst)
+    return torch.jit._wait(fut)
+
+
+class JitDistAutogradTest(RpcAgentTestFixture):
+    @dist_init
+    def test_get_gradients(self):
+        @torch.jit.script
+        def dist_get_gradients(context_id: int) -> dict[Tensor, Tensor]:
+            return dist_autograd.get_gradients(context_id)
+
+        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_get_gradients(context_id)
+
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_dist_backward(self):
+        if self.rank != 0:
+            return
+
+        @torch.jit.script
+        def dist_backward_script(context_id: int, loss: torch.Tensor):
+            dist_autograd.backward(context_id, [loss])
+
+        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(3, 3, requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
+            dist_backward_script(context_id, loss)
+
+    @dist_init
+    def test_jit_fork_within_context(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            res = fork_add(t1, t2, dst_worker_name)
+            loss = res.sum()
+            dist_autograd.backward(context_id, [loss])
+
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+
+    @dist_init
+    def test_restore_context_after_swtich_to_jit_thread(self):
+        if self.rank != 0:
+            return
+
+        @torch.jit.script
+        def forward_script(
+            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
+        ) -> tuple[Tensor, Tensor]:
+            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
+            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
+            loss1 = res1.sum()
+
+            # SendRpcBackward is not attached, since DistAutogradContext is lost here.
+            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
+            res2 = res2_fut.wait()
+            loss2 = res2.sum()
+
+            return loss1, loss2
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.ones((2, 3), requires_grad=True)
+            t2 = torch.ones((2, 3), requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
+            dist_autograd.backward(context_id, [loss0, loss1])
+            grad0, grad1 = dist_autograd.get_gradients(context_id)
+            self.assertEqual(grad0, grad1)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..82a5d66e87f38672fe7076075b764a094bb81b4c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -0,0 +1,1384 @@
+# mypy: allow-untyped-defs
+
+import io
+import time
+from typing import Any
+
+import torch
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.autograd.profiler import record_function
+from torch.autograd.profiler_legacy import profile as _profile
+from torch.distributed.rpc import RRef
+from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
+from torch.futures import Future
+from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    get_function_event,
+    initialize_pg,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def rref_isinstance(rref, cls_to_check):
+    return isinstance(rref.local_value(), cls_to_check)
+
+
+def sleep(t):
+    time.sleep(t)
+
+
+def rpc_return_rref(dst):
+    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+@torch.jit.script
+def rref_local_value(rref: RRef[Tensor]) -> Tensor:
+    return rref.local_value()
+
+
+@torch.jit.script
+def list_create() -> list[int]:
+    global_list = [1, 2, 3]
+    return global_list
+
+
+@torch.jit.script
+def rref_list_mutate(rref: RRef[list[int]]) -> None:
+    rref.local_value().append(4)
+    rref.to_here().append(5)
+    rref.to_here(5.0).append(6)
+
+
+def return_value(value: int) -> int:
+    return value
+
+
+class RRefAPITest:
+    @dist_init
+    def test_rref_is_owner(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        rref_var = rpc_return_rref(dst_worker_name)
+
+        @torch.jit.script
+        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
+            return rref_var.is_owner()
+
+        res = rref_tensor_is_owner(rref_var)
+        self.assertEqual(res, False)
+
+    @dist_init
+    def test_rref_local_value(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc_return_rref(dst_worker_name)
+
+        with self.assertRaisesRegex(
+            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
+        ):
+            rref_local_value(rref)
+
+        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
+        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))
+
+    @dist_init
+    def test_local_rref_local_value(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name(self.rank)
+        rref = rpc.remote(dst_worker_name, return_value, (5,), {})
+
+        ret = rref_local_value(rref)
+        self.assertEqual(ret, 5)
+
+    def _create_rref(self):
+        owner_rank = (self.rank + 2) % self.world_size
+        return rpc.remote(
+            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
+        )
+
+    @dist_init
+    def test_user_rrefs_confirmed(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
+        )
+        self.assertEqual(ret, True)
+
+    @dist_init
+    def test_user_rrefs_confirmed_remote(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret_rref = rpc.remote(
+            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
+        )
+        self.assertEqual(ret_rref.to_here(), True)
+
+    @dist_init
+    def test_rref_list_mutate(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        list_rref = rpc.remote(dst, list_create)
+
+        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
+        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])
+
+
+@torch.jit.script
+def no_arg():
+    return 0
+
+
+@torch.jit.script
+def one_arg(value):
+    return value + 1
+
+
+@torch.jit.script
+def script_add_ones(x):
+    return torch.add(x, torch.ones(1))
+
+
+@torch.jit.script
+def script_add_ones_with_record_function(x, block: str):
+    with record_function(block):
+        return torch.add(x, torch.ones(1))
+
+
+@torch.jit.script
+def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
+    t: Tensor = torch.ones(1)
+    with record_function(block):
+        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
+        # Extra operator call to avoid de-duplication of the next async call
+        # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
+        zero = torch.zeros_like(t)
+        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
+        res = fut1.wait() + fut2.wait() + zero
+    return res
+
+
+@torch.jit.script
+def script_fork_wait_udf(tensor):
+    fut = torch.jit._fork(script_add_ones, tensor)
+    x = torch.jit._wait(fut)
+    return x
+
+
+@torch.jit.script
+def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_var.to_here()
+
+
+@torch.jit.script
+def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
+    return rref_var
+
+
+@torch.jit.script
+def script_raise_func(value):
+    if value.numel() == 2:
+        raise ValueError("Expected error")
+    return value + 1
+
+
+@torch.jit.script
+def script_fork_wait_throw(invalue):
+    fut = torch.jit._fork(script_raise_func, invalue)
+    value = torch.jit._wait(fut)
+    return value
+
+
+@torch.jit.script
+def call_rpc_with_profiling(
+    record: torch.classes.profiler._RecordFunction, dst_worker_name: str
+) -> Tensor:
+    # Call rpc_async from within ScriptFunction and ensure that we can attach
+    # profiling callbacks. Note that handle here is a Tensor representation of
+    # RecordFunction.
+    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
+    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def call_rpc_torchscript_with_record_function(
+    dst_worker_name: str, block: str
+) -> Tensor:
+    fut = rpc.rpc_async(
+        dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block)
+    )
+    return fut.wait()
+
+
+@torch.jit.script
+def call_fork_with_profiling(record: torch.classes.profiler._RecordFunction) -> Tensor:
+    # Call fork from within ScriptFunction and ensure that we can attach profiling
+    # callbacks to the resulting future. Note that handle here is a Tensor
+    # representation of RecordFunction.
+    fut = torch.jit._fork(one_arg, torch.tensor(1))
+    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+    ret = fut.wait()
+    return ret
+
+
+class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
+    def __init__(self, dst_worker):
+        super().__init__()
+        self.rrefs = []
+        for _ in range(4):
+            self.rrefs.append(rpc_return_rref(dst_worker))
+
+    @torch.jit.script_method
+    def forward(self) -> Tensor:
+        res_tensor = torch.ones(2, 2)
+        for rref in self.rrefs:
+            res_tensor += rref.to_here()
+
+        return res_tensor
+
+
+@torch.jit.ignore
+def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
+    return rref_var
+
+
+@torch.jit.script
+def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_python_annotation(rref_var).to_here()
+
+
+class RRefTypingTest:
+    @dist_init
+    def test_rref_as_arg_and_return(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        local_ret = one_arg(torch.ones(2, 2))
+
+        # create rref on current rank
+        rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))
+
+        # pass rref to another user in rpc call
+        ret = rpc.rpc_sync(worker_name(dst_rank), rref_to_here, args=(rref,))
+        self.assertEqual(ret, local_ret)
+
+        # return rref in rpc call
+        rref1 = rpc.rpc_sync(worker_name(dst_rank), return_rref, args=(rref,))
+        self.assertEqual(rref1.to_here(), local_ret)
+
+        # pass rref to another user in remote call
+        rref2 = rpc.remote(worker_name(dst_rank), rref_to_here, args=(rref,))
+        self.assertEqual(rref2.to_here(), local_ret)
+
+        # return rref in remote call
+        rref3 = rpc.remote(worker_name(dst_rank), return_rref, args=(rref,))
+        self.assertEqual(rref3.to_here().to_here(), local_ret)
+
+    @dist_init
+    def test_my_script_module_with_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        module_with_rrefs = MyScriptModuleWithRRefs(worker_name(dst_rank))
+        res = module_with_rrefs()
+        self.assertEqual(res, torch.ones(2, 2) * 9)
+
+    @dist_init
+    def test_rref_python_annotation(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_var = rpc_return_rref(worker_name(dst_rank))
+
+        res = rref_script_annotation(rref_var)
+        self.assertEqual(res, torch.ones(2, 2) + 1)
+
+
+class FutureTypingTest:
+    @dist_init
+    def test_future_passed_between_python_and_jit(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs)
+        expected_res = torch.tensor([10, 10])
+
+        @torch.jit.script
+        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
+            return fut.wait()
+
+        self.assertEqual(future_wait_in_script(ret_fut), expected_res)
+
+        @torch.jit.script
+        def future_return_to_python(
+            dst_rank: int, inputs: tuple[Tensor, Tensor]
+        ) -> Future[Tensor]:
+            return rpc.rpc_async(f"worker{dst_rank}", two_args_two_kwargs, inputs)
+
+        fut_res = future_return_to_python(dst_rank, inputs)
+        self.assertEqual(fut_res.wait(), expected_res)
+
+    @dist_init
+    def test_future_python_annotation(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        input_0 = torch.ones(2, 2)
+        input_1 = 1
+        expected_res = torch.add(input_0, input_1)
+
+        @torch.jit.ignore
+        def python_return_future() -> Future[Tensor]:
+            fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
+            return fut
+
+        @torch.jit.script
+        def script_use_future() -> Tensor:
+            fut = python_return_future()
+            return fut.wait()
+
+        res = script_use_future()
+        self.assertEqual(res, expected_res)
+
+
+@torch.jit.script
+class MyScriptClass:
+    def __init__(self, a: int):
+        self.a = a
+
+    def get_value(self) -> int:
+        return self.a
+
+
+@torch.jit.interface
+class MyModuleInterface(torch.nn.Module):
+    def forward(self) -> Tensor:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+
+class MyScriptModule(torch.jit.ScriptModule):
+    def __init__(self, rank):
+        super().__init__()
+        self.a = torch.ones(rank)
+
+    @torch.jit.script_method
+    def forward(self) -> Tensor:
+        return self.a
+
+    @torch.jit.script_method
+    def custom_func(self) -> Tensor:
+        return self.a
+
+
+def owner_create_rref_my_script_class(a):
+    return rpc.RRef(MyScriptClass(a))
+
+
+def owner_create_rref_my_script_module(a):
+    return rpc.RRef(MyScriptModule(a), type_hint=MyModuleInterface)
+
+
+@torch.jit.script
+def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int:
+    return rref.to_here().get_value()
+
+
+@torch.jit.script
+def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor:
+    return rref.to_here().forward()
+
+
+class LocalRRefTest:
+    @dist_init
+    def test_create_local_script_class_rref_in_py(self):
+        if self.rank != 0:
+            return
+
+        # Create a local RRef.
+        rref_script_class = rpc.RRef(MyScriptClass(self.rank))
+        ret = rref_script_class.to_here().get_value()
+        self.assertEqual(ret, self.rank)
+
+    @dist_init
+    def test_create_local_script_module_rref_in_py(self):
+        if self.rank != 0:
+            return
+
+        # Create a local RRef.
+        rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
+        ret = rref_script_module.to_here().forward()
+        self.assertEqual(ret, torch.ones(self.rank))
+
+        # Create a local RRef without type hint.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            (
+                "The RRef being created contains a ScriptModule, "
+                "must provide its ModuleInterface type hint."
+            ),
+        ):
+            rref_script_module = rpc.RRef(MyScriptModule(self.rank))
+
+    @dist_init
+    def test_return_local_script_class_rref_in_py_and_use_in_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Create a local RRef remotely in Python.
+        rref = rpc.rpc_sync(
+            dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,)
+        )
+
+        def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
+            args = (rref,)
+            kwargs: dict[str, Any] = {}
+            fut = rpc.rpc_async(
+                rref.owner(), script_rref_get_value_my_script_class, args, kwargs
+            )
+            ret = fut.wait()
+            return ret
+
+        # Use RRef in local Python RPC and remote Script run.
+        ret = use_rref_on_owner(rref)
+        self.assertEqual(ret, self.rank)
+
+        # Use RRef in local Script RPC and remote Script run.
+        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
+        ret = use_rref_on_owner_script(rref)
+        self.assertEqual(ret, self.rank)
+
+    @dist_init
+    def test_return_local_script_module_rref_in_py_and_use_in_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Create a local RRef remotely in Python.
+        rref = rpc.rpc_sync(
+            dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,)
+        )
+
+        def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
+            args = (rref,)
+            kwargs: dict[str, Any] = {}
+            fut = rpc.rpc_async(
+                rref.owner_name(),
+                script_rref_run_forward_my_script_module,
+                args,
+                kwargs,
+            )
+            ret = fut.wait()
+            return ret
+
+        # Use RRef in local Python RPC and remote Script run.
+        ret = use_rref_on_owner(rref)
+        self.assertEqual(ret, torch.ones(self.rank))
+
+        # Use RRef in local Script RPC and remote Script run.
+        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
+        ret = use_rref_on_owner_script(rref)
+        self.assertEqual(ret, torch.ones(self.rank))
+
+
+def python_function():
+    return 0
+
+
+@torch.jit.script
+def two_args_two_kwargs(
+    first_arg,
+    second_arg,
+    first_kwarg=torch.tensor([3, 3]),
+    second_kwarg=torch.tensor([4, 4]),
+):
+    return first_arg + second_arg + first_kwarg + second_kwarg
+
+
+@torch.jit.script
+def assorted_types_args_kwargs(
+    tensor_arg: Tensor,  # noqa: E999
+    str_arg: str,
+    int_arg: int,
+    tensor_kwarg: Tensor = torch.tensor([2, 2]),
+    str_kwarg: str = "str_kwarg",
+    int_kwarg: int = 2,
+):
+    return tensor_arg + tensor_kwarg, str_arg + str_kwarg, int_arg + int_kwarg
+
+
+@torch.jit.script
+def raise_script():
+    raise RuntimeError("Expected error")
+
+
+@torch.jit.script
+def script_rpc_async_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def script_rpc_sync_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return res
+
+
+@torch.jit.script
+def script_rpc_remote_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return rref_res.to_here()
+
+
+class JitRpcOpTest:
+    # Call functions remotely from Script.
+    @dist_init
+    def test_all_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {}
+
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([10, 10]))
+
+    @dist_init
+    def test_some_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {"first_kwarg": torch.tensor([2, 2])}
+
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([9, 9]))
+
+    @dist_init
+    def test_no_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([8, 8]))
+
+    @dist_init
+    def test_args_and_kwargs_contain_different_types(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_with_assorted_types(
+            dst_worker_name: str,
+        ):
+            args = (torch.tensor([1, 1]), "str_arg", 1)
+            # Must annotate the value type as `Any`, because JIT type inference
+            # does not support multiple types when defining a Dict.
+            # The error JIT gives is,
+            # "Dict values must contain only a single type, "
+            # "expected: Tensor but found str instead."
+            kwargs: dict[str, Any] = {
+                "tensor_kwarg": torch.tensor([3, 3]),
+                "str_kwarg": "_str_kwarg",
+                "int_kwarg": 3,
+            }
+            fut = rpc.rpc_async(
+                dst_worker_name, assorted_types_args_kwargs, args, kwargs
+            )
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_with_assorted_types(dst_worker_name)
+        self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))
+
+    @dist_init
+    def test_kwargs_not_passed(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_without_kwargs_passed(
+            dst_worker_name: str,
+        ):
+            args = ()
+            fut = rpc.rpc_async(dst_worker_name, no_arg, args)
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name)
+        self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_args_kwargs_are_neither_passed(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_without_args_kwargs_passed(
+            dst_worker_name: str,
+        ):
+            fut = rpc.rpc_async(dst_worker_name, no_arg)
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
+        self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_less_than_needed_args_are_specified(self):
+        if self.rank != 0:
+            return
+
+        # Notice, args matching happens during scripting.
+        with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):
+
+            @torch.jit.script
+            def script_rpc_async_call_with_less_args(
+                dst_worker_name: str,  # noqa: E999
+            ):
+                args = (torch.tensor([1, 1]),)
+                kwargs = {}
+                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+                ret = fut.wait()
+                return ret
+
+    @dist_init
+    def test_more_than_needed_args_are_specified(self):
+        if self.rank != 0:
+            return
+
+        # Notice, args matching happens during scripting.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Expected at most 4 arguments but found 5 positional arguments",
+        ):
+
+            @torch.jit.script
+            def script_rpc_async_call_with_more_args(
+                dst_worker_name: str,
+            ):
+                args = (
+                    torch.tensor([1, 1]),
+                    torch.tensor([2, 2]),
+                    torch.tensor([3, 3]),
+                    torch.tensor([4, 4]),
+                    torch.tensor([5, 5]),
+                )
+                kwargs = {}
+                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+                ret = fut.wait()
+                return ret
+
+    @dist_init
+    def test_unexepected_kwarg_is_specified(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Notice, kwargs matching happens during execution.
+        @torch.jit.script
+        def script_rpc_async_call_with_unexpected_kwarg(
+            dst_worker_name: str,  # noqa: E999
+        ):
+            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+            kwargs = {"third_kwarg": torch.tensor([1, 1])}
+            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Unknown keyword argument 'third_kwarg'"
+        ):
+            ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
+            self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_call_python_function_remotely_from_script_not_supported(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(
+            RuntimeError, "attempted to get undefined function"
+        ):
+            ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)
+            self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_call_script_function_that_raises_remotely_from_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Notice, TorchScript always translates(emits) Python `raise` statement,
+        # as the exception message string, "Exception",
+        # no matter what exception type and exception message are in the statement,
+        @torch.jit.script
+        def rpc_async_call_remote_raising_torchscript_in_torchscript(
+            dst_worker_name: str,
+        ):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
+            ret = rpc_async_call_remote_raising_torchscript_in_torchscript(
+                dst_worker_name
+            )
+            self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_call_script_function_that_not_exists_remotely_from_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def nonexisting_script():
+            return 0
+
+        @torch.jit.script
+        def rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
+            dst_worker_name: str,
+        ):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(
+            RuntimeError, "attempted to get undefined function nonexisting_script"
+        ):
+            ret = rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
+                dst_worker_name
+            )
+            self.assertEqual(ret, 0)
+
+
+@torch.jit.ignore
+def my_script_module_init(rank: int) -> MyModuleInterface:
+    return MyScriptModule(rank)
+
+
+@torch.jit.script
+def construct_my_script_module(rank: int) -> MyModuleInterface:
+    return my_script_module_init(rank)
+
+
+@torch.jit.script
+def run_ref_script_module(
+    ref_script_module: RRef[MyModuleInterface], t: Tensor
+) -> Tensor:
+    module = ref_script_module.to_here()
+    return module.forward() + t
+
+
+@torch.jit.script
+def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool:
+    return rref.confirmed_by_owner()
+
+
+@torch.jit.script
+def save_rref(rref_var: RRef[Tensor], fname: str) -> None:
+    torch.save(rref_var, fname)
+
+
+@torch.jit.script
+def script_add(x: Tensor, y: Tensor) -> Tensor:
+    return x + y
+
+
+@rpc.functions.async_execution
+@torch.jit.script
+def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
+    return rpc.rpc_async(to, script_add, (x, y))
+
+
+@rpc.functions.async_execution
+@torch.jit.script
+def async_wrong_type() -> Tensor:
+    return torch.zeros(2)
+
+
+def load_script_module_with_pickled_rref(pickled_script_module):
+    f = io.BytesIO(pickled_script_module)
+    m = torch.jit.load(f)
+    return m()
+
+
+class JitRpcTest(
+    RRefAPITest,
+    RRefTypingTest,
+    LocalRRefTest,
+    JitRpcOpTest,
+    FutureTypingTest,
+    RpcAgentTestFixture,
+):
+    @dist_init
+    def test_torchscript_function(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        local_ret = one_arg(torch.ones(2, 2))
+        ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
+        self.assertEqual(ret, local_ret)
+        rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
+        self.assertEqual(rref.to_here(), local_ret)
+        # create rref to itself
+        local_rref = rpc.remote(
+            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
+        )
+        self.assertEqual(local_rref.to_here(), local_ret)
+
+    @dist_init
+    def test_torchscript_function_exception(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
+            rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20))
+
+        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
+            rpc.remote(dst_worker_name, one_arg, args=(10, 20))
+
+    @dist_init
+    def test_torchscript_functions_not_supported(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        my_local_script_module = MyScriptModule(self.rank)
+
+        # It is not thread safe to instantiate MyScriptModule in multiple threads,
+        # wait for local MyScriptModule instantiation to finish,
+        # otherwise it could instantiate MyScriptModule in parallel with
+        # server thread in the below
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        # rpc_sync still accepts script class and run it in
+        # the same code path as python call.
+        rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))
+
+        # rpc_sync does not accept script module method.
+        # Python 3.5 and Python 3.6 throw different error message, the only
+        # common word can be greped is "pickle".
+        with self.assertRaisesRegex(TypeError, "pickle"):
+            rpc.rpc_async(dst_worker_name, my_local_script_module.forward, args=())
+
+    @dist_init
+    def test_remote_script_module(self):
+        # TODO, need more investigation
+        # there is rref leak when shutting down, suspect it is because
+        # ref as arg is passed to pybind boundary, and the ref is not garbage
+        # collected by python when calling shutdown()
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = True
+
+        local_ret = torch.ones(self.rank) + torch.ones(self.rank)
+
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        remote_ref = rpc.remote(
+            worker_name(dst_rank), construct_my_script_module, args=(self.rank,)
+        )
+
+        # pass rref arg to owner
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            run_ref_script_module,
+            args=(remote_ref, torch.ones(self.rank)),
+        )
+        self.assertEqual(ret, local_ret)
+
+        # pass rref arg to self/user
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "is an RRef to a ScriptModule. It can't be sent through RPC from owner,",
+        ):
+            ret = rpc.rpc_sync(
+                worker_name(self.rank),
+                run_ref_script_module,
+                args=(remote_ref, torch.ones(self.rank)),
+            )
+
+    @dist_init
+    def test_create_script_module_on_remote(self):
+        dst_name = worker_name((self.rank + 1) % self.world_size)
+        # Construct on remote end with rpc_sync
+        created_script_module = rpc.rpc_sync(
+            dst_name, MyScriptModule, args=(self.rank,)
+        )
+        # Forward should output a ones tensor of self.rank.
+        self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule))
+        rank_ones_tensor = created_script_module()
+        self.assertEqual(torch.ones(self.rank), rank_ones_tensor)
+
+        # Construct ScriptModule with rpc.remote.
+        remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,))
+        # Verify it is an instance of ScriptModule on remote end.
+        remote_end_is_script = rpc.rpc_sync(
+            remote_script_module.owner(),
+            rref_isinstance,
+            args=(remote_script_module, torch.jit.ScriptModule),
+        )
+        self.assertTrue(remote_end_is_script)
+        # Run forward pass remotely.
+        remote_forward_output = remote_script_module.rpc_sync().forward()
+        self.assertEqual(remote_forward_output, torch.ones(self.rank))
+        # Run function defined on ScriptModule remotely.
+        remote_func_output = remote_script_module.rpc_sync().custom_func()
+        self.assertEqual(remote_func_output, torch.ones(self.rank))
+        # Ensure we can transfer ScriptModule RRef to this rank and run
+        # forward pass.
+        local_script_module = remote_script_module.to_here()
+        self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
+        rank_ones_tensor = local_script_module()
+        self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
+        local_script_func_output = local_script_module.custom_func()
+        self.assertEqual(local_script_func_output, torch.ones(self.rank))
+
+    @dist_init
+    def test_load_script_module_with_pickled_rref(self):
+        dst_name = worker_name((self.rank + 1) % self.world_size)
+        m1 = MyScriptModuleWithRRefs(dst_name)
+        m2 = MyScriptModuleWithRRefs(dst_name)
+
+        f = io.BytesIO()
+
+        rpc._enable_jit_rref_pickle()
+        torch.jit.save(m1, f)
+        rpc._disable_jit_rref_pickle()
+
+        out1 = rpc.rpc_sync(
+            dst_name, load_script_module_with_pickled_rref, args=(f.getvalue(),)
+        )
+        out2 = m2()
+        self.assertEqual(out1, out2)
+
+    @dist_init
+    def test_rref_jit_pickle_not_supported(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_var = rpc_return_rref(worker_name(dst_rank))
+        with (
+            TemporaryFileName() as fname,
+            self.assertRaisesRegex(
+                RuntimeError, "RRef jit pickling is only allowed inside RPC calls"
+            ),
+        ):
+            save_rref(rref_var, fname)
+
+    @dist_init
+    def test_remote_script_throw(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            script_raise_func,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            rref.to_here()
+
+    @dist_init
+    def test_remote_script_udf(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2) * 2)
+
+    @dist_init
+    def test_async_script_udf(self):
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+        self.assertEqual(future.wait(), torch.ones(2) * 2)
+
+    @dist_init
+    def test_callback_simple(self):
+        def callback(fut):
+            return fut.wait() + 1
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        ).then(callback)
+        self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)
+
+    @dist_init
+    def test_callback_chain(self):
+        n = self.rank + 1
+
+        def callback(fut):
+            return fut.wait() + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size), one_arg, args=(torch.ones(n, n),)
+        )
+
+        num_cbs = 20
+        for _ in range(num_cbs):
+            fut = fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
+
+    @dist_init
+    def test_add_done_callback(self):
+        callback_called = None
+
+        def callback(fut):
+            nonlocal callback_called
+            callback_called = fut.wait() * 2
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+
+        future.add_done_callback(callback)
+        future_then = future.then(lambda _: True)
+
+        self.assertEqual(future.wait(), torch.ones(2) * 2)
+
+        # We have no guarantee that the add_done_callback fn will execute before the test finishes.
+        # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback
+        future_then.wait()
+        self.assertEqual(callback_called, torch.ones(2) * 4)
+
+    @dist_init
+    def test_async_script_throw(self):
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_throw,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            future.wait()
+
+    @dist_init
+    def test_callback_with_exception(self):
+        def callback(fut):
+            with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+                fut.wait()
+            raise RuntimeError("Another expected error")
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_throw,
+            args=(torch.ones(2),),
+        ).then(callback)
+
+        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
+            future.wait()
+
+    @dist_init
+    def test_call_rpc_with_profiling(self):
+        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
+        # future from within a script function that calls rpc_async
+        if self.rank == 0:
+            with _profile() as prof:
+                prof_key = _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC,
+                    torch._jit_internal._qualified_name(one_arg),
+                    "worker0",
+                    "worker1",
+                )
+                with torch.autograd.profiler.record_function(prof_key) as rf:
+                    call_rpc_with_profiling(rf.record, "worker1")
+            # TODO: Can't get a reliable time for this profiling event since
+            # it's hard to estimate the execution time on the remote end for non-UDFs.
+            # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
+            # After that, this test should be modified to validate the function time.
+            events = prof.function_events
+            function_event = get_function_event(events, prof_key)
+            self.assertTrue(
+                torch._jit_internal._qualified_name(one_arg) in function_event.name
+            )
+
+    @dist_init
+    def test_rpc_async_jit_profiled(self):
+        # Tests that rpc_async calls made from within a TorchScript function are
+        # profiled.
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+            kwargs = {}
+            with _profile() as prof:
+                script_rpc_async_call(dst_worker_name, args, kwargs)
+
+            # Ensure rpc_async call is profiled
+            function_events = prof.function_events
+            qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
+            rpc_async_jit_event = [
+                event
+                for event in function_events
+                if qual_name in event.name and event.node_id == self.rank
+            ]
+            self.assertEqual(len(rpc_async_jit_event), 1)
+            rpc_async_jit_event = rpc_async_jit_event[0]
+            profiled_name = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC_JIT,
+                qual_name,
+                worker_name(self.rank),
+                dst_worker_name,
+            )
+            self.assertEqual(profiled_name, rpc_async_jit_event.name)
+            remote_events = [event for event in function_events if event.is_remote]
+            # All remote events should have taken place on dst_rank
+            remote_event_node_ids = {
+                remote_event.node_id for remote_event in remote_events
+            }
+            self.assertEqual(remote_event_node_ids, {dst_rank})
+            # script_rpc_async_call invokes add operator
+            # so we should see this as a remote event.
+            remote_add = next(
+                remote_event
+                for remote_event in remote_events
+                if "aten::add" in remote_event.name
+            )
+            remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
+            self.assertEqual(remote_add.name, remote_add_profiled_name)
+
+    @dist_init
+    def test_record_function_on_caller_rpc_async(self):
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            block_scope = "foo"
+            with _profile() as prof:
+                # Runs 2 rpc_async calls within JIT under record_function.
+                record_function_on_caller_rpc_async(dst_worker_name, block_scope)
+
+            # Ensure record_function event is profiled.
+            function_events = prof.function_events
+            record_function_scope_event = [
+                event for event in function_events if event.name == block_scope
+            ]
+            self.assertEqual(1, len(record_function_scope_event))
+            record_function_scope_event = record_function_scope_event[0]
+            # Ensure RPC future is profiled.
+            expected_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC_JIT,
+                torch._jit_internal._qualified_name(script_add_ones),
+                worker_name(self.rank),
+                dst_worker_name,
+            )
+            jit_rpc_events = [
+                event for event in function_events if event.name == expected_key
+            ]
+            self.assertEqual(2, len(jit_rpc_events))
+            # Validate that the record_function scope time is greater than both
+            # of the individual RPC async call times. The reason it is not necessarily
+            # greater than the sum is because the two can execute in parallel.
+            for jit_rpc_event in jit_rpc_events:
+                self.assertTrue(
+                    record_function_scope_event.cpu_time_total
+                    > jit_rpc_event.cpu_time_total
+                )
+
+    @dist_init
+    def test_rpc_torchscript_record_function(self):
+        # tests that torchscript functions can be profiled using with
+        # record_function(...) over RPC.
+        REMOTE_OP_STR = "#remote_op: "
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            block_scope = "foo"
+            with _profile() as prof:
+                call_rpc_torchscript_with_record_function(dst_worker_name, block_scope)
+
+            # Need to call below to populate CPU children.
+            prof.key_averages()
+            function_events = prof.function_events
+            expected_key = (
+                _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC_JIT,
+                    torch._jit_internal._qualified_name(
+                        script_add_ones_with_record_function
+                    ),
+                    worker_name(self.rank),
+                    dst_worker_name,
+                )
+                + REMOTE_OP_STR
+                + block_scope
+            )
+            remote_record_function_event = next(
+                evt for evt in function_events if evt.name == expected_key
+            )
+            self.assertTrue(block_scope in remote_record_function_event.name)
+            remote_children = remote_record_function_event.cpu_children
+            self.assertTrue("aten::add" in child.name for child in remote_children)
+
+    def test_record_function_jit_end_callbacks_with_fork(self):
+        # Ensures that we can call rf._call_end_callbacks_on_future on a jit
+        # future in python eager mode with torch.jit.fork
+        sleep_interval = 1
+        with _profile() as prof:
+            with torch.autograd.profiler.record_function("foo") as rf:
+                fut = torch.jit._fork(sleep, sleep_interval)
+                rf._call_end_callbacks_on_future(fut)
+            fut.wait()
+
+        function_events = prof.function_events
+        sleep_event = get_function_event(function_events, "foo")
+        self.assertEqual(sleep_event.name, "foo")
+        # Validate that callbacks were fired at the right time by checking the
+        # profiling event cpu time
+        self.assertGreaterAlmostEqual(sleep_event.cpu_time * 1e-6, sleep_interval)
+
+    def test_call_fork_in_jit_with_profiling(self):
+        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
+        # future from within a script function with torch.jit.fork
+        with _profile() as prof, torch.autograd.profiler.record_function("foo") as rf:
+            call_fork_with_profiling(rf.record)
+
+        events = prof.function_events
+        function_event = get_function_event(events, "foo")
+        self.assertEqual(function_event.name, "foo")
+
+    @dist_init
+    def test_async_function_simple(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(
+            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_async_function_wrong_return_type(self):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Async functions must return an IValue of Future type, but got Tensor",
+        ):
+            rpc.rpc_sync(
+                worker_name((self.rank + 1) % self.world_size), async_wrong_type
+            )
+
+    @dist_init
+    def test_async_function_wrong_decorator_order(self):
+        # @torch.jit.script complains about undefined value rpc. Error is shown
+        # below. The reason for not checking error string is to avoid making
+        # JIT error handling code depend on RPC tests, as we don't have any
+        # restrictions on the error message here.
+        #
+        # RuntimeError:
+        # undefined value rpc:
+        # def async_wrong_decorator_order(to, x, y):
+        #    # type: (str, Tensor, Tensor) -> Future[Tensor]
+        #    return rpc.rpc_async(to, script_add, (x, y))
+        #           ~~~ <--- HERE
+        with self.assertRaises(RuntimeError):
+
+            @torch.jit.script
+            @rpc.functions.async_execution
+            def async_wrong_decorator_order(
+                to: str, x: Tensor, y: Tensor
+            ) -> Future[Tensor]:
+                return rpc.rpc_async(to, script_add, (x, y))
+
+    @dist_init
+    def test_async_function_remote(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        rref = rpc.remote(
+            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_async_function_remote_multi(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        num = 20
+        rrefs = [
+            rpc.remote(
+                dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
+            )
+            for i in range(num)
+        ]
+
+        for i in range(num):
+            self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i)
+
+    @dist_init
+    def test_async_function_wrong_return_type_remote(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size), async_wrong_type
+        )
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Async functions must return an IValue of Future type, but got Tensor",
+        ):
+            rref.to_here()
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bedaad32d0e904a9a7523f31eced9cef96e832d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py
@@ -0,0 +1,219 @@
+# mypy: allow-untyped-defs
+
+
+import torch
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.distributed.rpc import RRef
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+@torch.jit.script
+def two_args_two_kwargs(
+    first_arg,
+    second_arg,
+    first_kwarg=torch.tensor([3, 3]),
+    second_kwarg=torch.tensor([4, 4]),
+):
+    return first_arg + second_arg + first_kwarg + second_kwarg
+
+
+@torch.jit.script
+def script_rpc_async_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def rpc_async_call_with_timeout(
+    dst_worker_name: str,
+    args: tuple[Tensor, Tensor],
+    kwargs: dict[str, Tensor],
+    timeout: float,
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def rpc_async_call_with_timeout_future_ret(
+    dst_worker_name: str,
+    args: tuple[Tensor, Tensor],
+    kwargs: dict[str, Tensor],
+    timeout: float,
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
+    return fut
+
+
+@torch.jit.script
+def rpc_async_call_future_ret(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return fut
+
+
+@torch.jit.script
+def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_var.to_here()
+
+
+@torch.jit.script
+def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
+    return rref_var.to_here(timeout)
+
+
+@torch.jit.script
+def rpc_async_with_rref_arg(dst_worker_name: str, args: tuple[RRef[Tensor]]) -> Tensor:
+    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
+    ret = fut.wait()
+    return ret
+
+
+class JitFaultyAgentRpcTest(RpcAgentTestFixture):
+    """
+    Run tests for rpc_async in JIT under the faulty agent test fixture to test
+    arbitrary timeouts.
+    """
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_timeout_in_torchscript_function(self):
+        # Call rpc_async + fut.wait() in torchscript function and ensure that
+        # timeout is raised.
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        expected_error = self.get_timeout_error_regex()
+        # Ensure that we get a timeout if we override the default timeout and
+        # the RPC takes longer to execute.
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)
+
+        # Ensure that we timeout if we don't specify a timeout but the default
+        # is less than the RPC takes to execute.
+        rpc._set_rpc_timeout(0.001)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            script_rpc_async_call(dst_worker_name, args, kwargs)
+
+        # Ensure that we run to completion if zero timeout is specified.
+        ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
+        self.assertEqual(ret, torch.tensor([8, 8]))
+        # reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_timeout_in_python(self):
+        # Ensures timeouts are raised if we call rpc_async from within a
+        # torchscript function, but wait on the future in python.
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        expected_error = self.get_timeout_error_regex()
+
+        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure timeout if we don't specify but the default is less than the
+        # RPC takes to execute.
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if zero timeout is specified
+        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
+        result = fut.wait()
+        self.assertEqual(result, torch.tensor([8, 8]))
+        # reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_remote_timeout_to_here_in_jit(self):
+        # Test that calling to_here() in JIT will raise timeout error if
+        # rpc.remote failed.
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call to_here() within a ScriptFunction and ensure it raises
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref_to_here(rref)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
+    def test_rref_to_here_timeout_in_jit(self):
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref_to_here_with_timeout(rref, 0.01)
+
+        rref_to_here_with_timeout(rref, 100)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_rref_timeout_pickle_in_jit(self):
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call RPC with RRef arg in JIT, which will go through JIT pickling and
+        # ensure error is raised.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc_async_with_rref_arg(dst_worker, (rref,))
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_rref_timeout_pickle_script_func(self):
+        # Similar to above test, but calls python rpc with script function.
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call RPC with script function that takes RRef, ensure timeout during pickling
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a684b73d2f315a00465371fad3050a795251ddb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
@@ -0,0 +1,63 @@
+# mypy: allow-untyped-defs
+
+import os
+from abc import ABC, abstractmethod
+
+import torch.testing._internal.dist_utils
+
+
+class RpcAgentTestFixture(ABC):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @property
+    def init_method(self):
+        use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+        if use_tcp_init == "1":
+            master_addr = os.environ["MASTER_ADDR"]
+            master_port = os.environ["MASTER_PORT"]
+            return f"tcp://{master_addr}:{master_port}"
+        else:
+            return self.file_init_method
+
+    @property
+    def file_init_method(self):
+        return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format(
+            file_name=self.file_name
+        )
+
+    @property
+    @abstractmethod
+    def rpc_backend(self):
+        pass
+
+    @property
+    @abstractmethod
+    def rpc_backend_options(self):
+        pass
+
+    def setup_fault_injection(self, faulty_messages, messages_to_delay):  # noqa: B027
+        """Method used by dist_init to prepare the faulty agent.
+
+        Does nothing for other agents.
+        """
+
+    # Shutdown sequence is not well defined, so we may see any of the following
+    # errors when running tests that simulate errors via a shutdown on the
+    # remote end.
+    @abstractmethod
+    def get_shutdown_error_regex(self):
+        """
+        Return various error message we may see from RPC agents while running
+        tests that check for failures. This function is used to match against
+        possible errors to ensure failures were raised properly.
+        """
+
+    @abstractmethod
+    def get_timeout_error_regex(self):
+        """
+        Returns a partial string indicating the error we should receive when an
+        RPC has timed out. Useful for use with assertRaisesRegex() to ensure we
+        have the right errors during timeout.
+        """
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50aadc058cbdd2d5e08b4df711572828b2f2ee9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -0,0 +1,6312 @@
+# mypy: allow-untyped-defs
+
+import concurrent.futures
+import contextlib
+import json
+import operator
+import os
+import sys
+import threading
+import time
+from collections import namedtuple
+from functools import partial
+from threading import Event, Lock
+from unittest import mock
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+from torch.autograd.profiler_legacy import profile as _profile
+from torch.distributed.rpc import (
+    _get_debug_info,
+    _rref_context_get_debug_info,
+    RRef,
+    WorkerInfo,
+)
+from torch.distributed.rpc.api import _thread_local_var, _use_rpc_pickler, _wait_all
+from torch.distributed.rpc.internal import (
+    _build_rpc_profiling_key,
+    _internal_rpc_pickler,
+    PythonUDF,
+    RPCExecMode,
+)
+from torch.futures import Future
+from torch.testing._internal.common_distributed import (
+    captured_output,
+    skip_if_lt_x_gpu,
+    tp_transports,
+)
+from torch.testing._internal.common_utils import (
+    get_cycles_per_ms,
+    IS_MACOS,
+    load_tests,
+    skip_but_pass_in_sandcastle_if,
+    TemporaryFileName,
+)
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    get_function_event,
+    initialize_pg,
+    wait_until_node_failure,
+    wait_until_owners_and_forks_on_rank,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def foo_add():
+    return torch.add(torch.ones(1), torch.ones(1))
+
+
+def udf_with_torch_ops(device=-1, use_record_function=False):
+    device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
+    record_function_ctx = (
+        torch.autograd.profiler.record_function("##forward##")
+        if use_record_function
+        else contextlib.nullcontext()
+    )
+    with device_ctx, record_function_ctx:
+        t1, t2 = torch.ones(1), torch.ones(1)
+        t = torch.add(t1, t2)
+        t = torch.mul(t, t)
+        t = t.relu()
+        t = t.sigmoid()
+
+
+# Events (operator invocations) that are expected to be ran as part of the above
+# function.
+EXPECTED_REMOTE_EVENTS = [
+    "aten::ones",
+    "aten::ones",
+    "aten::add",
+    "aten::mul",
+    "aten::relu",
+    "aten::clamp_min",
+    "aten::sigmoid",
+]
+
+# Remote operations are prefixed with the following string for RPC profiling.
+REMOTE_OP_STR = "#remote_op: "
+
+
+VALUE_FUTURE = concurrent.futures.Future()
+DONE_FUTURE = concurrent.futures.Future()
+
+FIFTY_MIL_CYCLES = 50000000
+
+_rpc_barrier_count = 0
+
+
+def _increment_count():
+    global _rpc_barrier_count
+    _rpc_barrier_count += 1
+
+
+def _reset_count():
+    global _rpc_barrier_count
+    _rpc_barrier_count = 0
+
+
+class StubRpcAgent:
+    def __init__(self, world_size):
+        self.world_size = world_size
+
+    def get_worker_infos(self):
+        return {
+            WorkerInfo(name=worker_name(rank), id=rank)
+            for rank in range(self.world_size)
+        }
+
+
+def _stub_construct_rpc_backend_options_handler(**kwargs):
+    return mock.Mock()  # RpcBackendOptions.
+
+
+def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options):
+    return StubRpcAgent(world_size=world_size)
+
+
+def set_value(value):
+    VALUE_FUTURE.set_result(value)
+
+
+def wait_for_value_future():
+    return VALUE_FUTURE.result()
+
+
+def set_and_check_done(value):
+    VALUE_FUTURE.set_result(value)
+    return DONE_FUTURE.result()
+
+
+# it is used to test python user defined function over rpc
+# classes and functions are used to test python user defined class and
+# methods over rpc
+TensorClass = namedtuple("TensorClass", ["tensors"])
+
+
+class MyPickleClass:
+    def __init__(self) -> None:
+        self.t = None
+
+    def __getstate__(self):
+        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
+            PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
+        )
+        return (pickled_python_udf, tensors)
+
+    def __setstate__(self, obj):
+        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
+        result = python_udf.func(python_udf.args[0], python_udf.args[1])
+        self.t = result
+
+    def set(self, val):
+        self.t = val
+
+
+class SlowPickleClass:
+    def __init__(self, t):
+        self.t = t
+
+    def __getstate__(self):
+        time.sleep(self.t)
+        return (self.t,)
+
+    def __setstate__(self, obj):
+        self.t = obj[0]
+        time.sleep(self.t)
+
+
+class MyClass:
+    def __init__(self, a, delay=False):
+        self.a = a
+        # delay initialization to simulate errors if specified
+        if delay:
+            time.sleep(2)
+
+    def my_instance_method(self, b):
+        return self.a + b
+
+    @classmethod
+    def my_class_method(cls, d, e):
+        return d + e
+
+    @staticmethod
+    def my_static_method(f):
+        return f > 10
+
+    def increment_value(self, increment):
+        self.a += increment
+
+    def get_value(self):
+        return self.a
+
+    def my_slow_method(self, my_tensor_arg):
+        time.sleep(5)
+        return torch.add(self.a, my_tensor_arg)
+
+
+def _call_method_on_rref(method, rref, *args, **kwargs):
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def get_rref_list(values):
+    return [RRef(MyClass(a)) for a in values]
+
+
+def add_rref_to_value(rref, value):
+    return rref.to_here() + value
+
+
+def run_nested_pickle(pickle_cls_instance, tensor):
+    return pickle_cls_instance.t + tensor
+
+
+def build_sparse_tensor(coalesce=False):
+    i = [[0, 1, 1], [2, 0, 2]]
+    v = [3, 4, 5]
+    tensor = torch.sparse_coo_tensor(i, v, (2, 3))
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
+
+
+def build_complex_tensors():
+    a = torch.ones(3, 3)
+    b = [a, a]
+    c = [b, b]
+    d = [a, b]
+    e = {a: d}
+    return [a, b, c, d, e]
+
+
+def non_cont_test(t_view, t_cont):
+    if t_view.is_contiguous():
+        raise Exception("t_view is contiguous!")  # noqa: TRY002
+    if not t_cont.is_contiguous():
+        raise Exception("t_cont is not contiguous!")  # noqa: TRY002
+    if not torch.equal(t_view, t_cont):
+        raise Exception("t_view is not equal to t_cont!")  # noqa: TRY002
+    return t_view
+
+
+def my_function(a, b, c):
+    return a + b + c
+
+
+def my_tensor_function(a, b):
+    return a + b
+
+
+def my_container_sum(a):
+    result = a[0]
+    for tensor in a[1:]:
+        result += tensor
+    return result
+
+
+def my_sleep_func(seconds=1):
+    time.sleep(seconds)
+    return torch.mul(torch.tensor(1), torch.tensor(1))
+
+
+def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
+    res = list_input[0]
+    for t in list_input:
+        res += t
+    for v in dict_input.values():
+        res += v
+    complex_tensors = tensor_class_input.tensors
+    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
+
+
+def my_rref_function(rref_a, rref_b):
+    return rref_a.to_here() + rref_b.to_here()
+
+
+def delayed_add(a, b, seconds=0.05):
+    time.sleep(seconds)
+    return a + b
+
+
+def identity(a):
+    return a
+
+
+def no_result():
+    print("do nothing")
+
+
+def raise_or_inc(value):
+    if value.numel() == 2:
+        raise ValueError("Expected error")
+    return value + 1
+
+
+def nested_rpc(dst):
+    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+def nested_rpc_sparse(dst):
+    return rpc.rpc_sync(
+        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+
+
+def multi_layer_nested_async_rpc(dst, world_size, ttl):
+    # this method returns immediately without blocking the callee, but will
+    # generate additional requests.
+    if ttl > 0:
+        current_dst = worker_name(dst)
+        next_dst = (dst + 1) % world_size
+        rpc.rpc_async(
+            current_dst,
+            multi_layer_nested_async_rpc,
+            args=(next_dst, world_size, ttl - 1),
+        )
+        return 0
+
+
+def nested_rref(dst):
+    return (
+        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
+        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
+    )
+
+
+def nested_rref_sparse(dst):
+    return (
+        rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())),
+        rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())),
+    )
+
+
+def nested_remote(dst):
+    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
+    return rref.to_here()
+
+
+def nested_remote_sparse(dst):
+    rref = rpc.remote(
+        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+    return rref.to_here()
+
+
+def rref_forward_chain(dst, world_size, rref, ttl):
+    if ttl > 0:
+        current_dst = worker_name(dst)
+        next_dst = (dst + 1) % world_size
+        ret_rref = rpc.remote(
+            current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
+        )
+        return [ret_rref]
+    else:
+        return rref.to_here()
+
+
+def rpc_return_rref(dst):
+    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+def light_rpc():
+    return 0
+
+
+def heavy_rpc(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor /= i + 1
+    return 0
+
+
+def heavy_rpc_sparse(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor = tensor / (i + 1)
+    return 0
+
+
+@torch.jit.script
+def heavy_rpc_torchscript(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor /= i + 1
+    return 0
+
+
+@torch.jit.script
+def my_script_func(tensor):
+    return torch.add(tensor, tensor)
+
+
+expected_err = "Expected error"
+
+
+# Note that it needs to inherit from Exception, not BaseException. See comment
+# in rpc/internal.py
+class CustomException(Exception):
+    def __init__(self, bool, msg):
+        self.bool = bool
+        super().__init__(msg)
+
+
+def raise_func():
+    raise ValueError(expected_err)
+
+
+def custom_raise_func():
+    raise CustomException(True, "foo")
+
+
+@torch.jit.script
+def raise_func_script(expected_err: str) -> torch.Tensor:
+    raise ValueError(expected_err)
+
+
+expected_err_escape = (
+    "\nFirst line of error \n next line of error \n last line of error"
+)
+
+
+def raise_func_escape():
+    raise ValueError(expected_err_escape)
+
+
+global_rref = None
+
+
+def set_global_rref(rref):
+    global global_rref
+    global_rref = rref
+
+
+def clear_global_rref():
+    global global_rref
+    global_rref = None
+
+
+def check_rref_confirmed(rref):
+    return rref.confirmed_by_owner()
+
+
+def get_rref_debug_info():
+    return _rref_context_get_debug_info()
+
+
+def add_use_future_cb(to, x, y, z):
+    out = concurrent.futures.Future()
+
+    def callback(fut):
+        out.set_result(fut.wait() + z)
+
+    fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    fut.then(callback)
+    return out.result()
+
+
+def get_events_from_profile(profile_rref):
+    return profile_rref.local_value().process_global_function_events
+
+
+def add_use_future_set_result(to, x, y, z):
+    out = torch.futures.Future()
+    fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    fut.then(lambda fut: out.set_result(fut.wait() + z))
+    return out.wait()
+
+
+def add_use_future_nested_cb(to, x, y, z):
+    out = torch.futures.Future()
+
+    def callback(fut1):
+        fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z))
+        fut2.then(lambda fut2: out.set_result(fut2.wait()))
+
+    fut1 = rpc.rpc_async(to, torch.add, args=(x, y))
+    fut1.then(callback)
+    return out.wait()
+
+
+def fail_on_fut(fut):
+    pass
+
+
+@rpc.functions.async_execution
+def async_raise_func():
+    raise RuntimeError("Expected error")
+
+
+@rpc.functions.async_execution
+def async_wrong_type():
+    return torch.zeros(2, 2)
+
+
+@rpc.functions.async_execution
+def async_add(to, x, y):
+    return rpc.rpc_async(to, torch.add, args=(x, y))
+
+
+def slow_add(x, y, device="cpu"):
+    time.sleep(1)
+    x = x.to(device)
+    y = y.to(device)
+    return torch.add(x, y).cpu()
+
+
+@rpc.functions.async_execution
+def slow_async_add(to, x, y, device="cpu"):
+    return rpc.rpc_async(to, slow_add, args=(x, y, device))
+
+
+@rpc.functions.async_execution
+def async_add_with_future_ctor(to, x, y, z):
+    fut = torch.futures.Future()
+    rpc.rpc_async(to, torch.add, args=(x, y)).then(
+        lambda fut1: fut.set_result(fut1.wait() + z)
+    )
+    return fut
+
+
+@rpc.functions.async_execution
+def async_add_chained(to, x, y, z):
+    return rpc.rpc_async(to, torch.add, args=(x, y)).then(lambda fut: fut.wait() + z)
+
+
+@rpc.functions.async_execution
+def async_add_chained_multi(to, x, num, step):
+    fut = rpc.rpc_async(to, torch.add, args=(x, 0))
+    for _ in range(num):
+        fut = fut.then(lambda fut: fut.wait() + step)
+    return fut
+
+
+@rpc.functions.async_execution
+def async_add_nested(to, x, y, z):
+    return rpc.rpc_async(to, async_add, args=(to, x, y)).then(
+        lambda fut: fut.wait() + z
+    )
+
+
+@rpc.functions.async_execution
+def async_add_multi_fanout(to, x, num, step):
+    futs = []
+    for i in range(num):
+        if i == 0:
+            futs.append(rpc.rpc_async(to, torch.add, args=(x, step)))
+        else:
+            futs.append(rpc.rpc_async(to, torch.add, args=(0, step)))
+
+    # TODO: use torch.futures.collect_all
+    lock = Lock()
+    state = {"cnt": 0, "ret": torch.zeros_like(x)}
+    ret_future = torch.futures.Future()
+
+    def inc_and_set(fut):
+        with lock:
+            state["cnt"] += 1
+            state["ret"] += fut.wait()
+            if state["cnt"] >= len(futs):
+                ret_future.set_result(state["ret"])
+
+    for fut in futs:
+        fut.then(inc_and_set)
+
+    return ret_future
+
+
+@rpc.functions.async_execution
+def async_cuda_sleep_and_set_to_one(t):
+    device = t.device
+    original_stream = torch.cuda.current_stream(device)
+    new_stream = torch.cuda.Stream(device)
+    new_stream.wait_stream(original_stream)
+    with torch.cuda.stream(new_stream):
+        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+        t.fill_(1)
+        fut = Future(devices=[device])
+        fut.set_result(t)
+        return fut
+
+
+@rpc.functions.async_execution
+def async_cuda_nested_add(to, x, y, z):
+    def cb(fut):
+        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+        return fut.value() + z
+
+    return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb)
+
+
+# A custom Python class that contains a tensor, needed to see if we correctly
+# use the Python pickler to extract tensors from non-IValue-convertible types.
+class TensorWrapper:
+    __slots__ = ("tensor", "lock", "event", "thread")
+
+    def __init__(self, t):
+        self.tensor = t
+        # Add one non-picklable field, to ensure it's ignored/skipped.
+        self.lock = Lock()
+        self.event = torch.cuda.Event(enable_timing=True)
+        self.thread = threading.Thread()
+        self.thread.start()
+
+    def increase(self, v):
+        with self.lock:
+            self.tensor += v
+
+    def sum(self):
+        with self.lock:
+            self.event.record()
+            return self.tensor.sum()
+
+
+class AsyncExecutionClass:
+    @staticmethod
+    @rpc.functions.async_execution
+    def static_async_add(to, x, y, z):
+        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: fut.wait() + z
+        )
+
+    @classmethod
+    @rpc.functions.async_execution
+    def class_async_add(cls, to, x, y, z):
+        ret_fut = torch.futures.Future()
+        rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: ret_fut.set_result(fut.wait() + z)
+        )
+        return ret_fut
+
+    @rpc.functions.async_execution
+    def bound_async_add(self, to, x, y, z):
+        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: fut.wait() + z
+        )
+
+
+def return_future():
+    return torch.futures.Future()
+
+
+class FooBackendOptions(rpc.RpcBackendOptions):
+    def __init__(self, init_method):
+        # Must call the __init__ of the superclass (and do so directly,
+        # without using super()) because... pybind.
+        rpc.RpcBackendOptions.__init__(self)
+        self.init_method = init_method
+
+
+# load_tests from common_utils is used to automatically filter tests for
+# sharding on sandcastle. This line silences flake warnings
+load_tests = load_tests  # noqa: PLW0127
+
+
+class MyEmbeddingBagModel(torch.nn.Module):
+    def __init__(self, sparse):
+        super().__init__()
+        self.eb = torch.nn.EmbeddingBag(10, 10, sparse=sparse)
+
+    def forward(self, x):
+        return self.eb(x)
+
+
+class MyParameterServer:
+    def __init__(self, trainers):
+        self.lock = Lock()
+        self.trainers = trainers
+        self.iteration = 0
+        self.updates = 0
+        self.futures = []
+        self.total = None
+        self.gradient = None
+
+    @staticmethod
+    def get_gradient(rref):
+        return rref.local_value().gradient
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def average(rref, riteration, tensor):
+        self = rref.local_value()
+        fut = torch.futures.Future()
+        with self.lock:
+            if riteration > self.iteration:
+                self.iteration = riteration
+                self.updates = 0
+                self.futures.clear()
+            self.futures.append(fut)
+            if self.total is None:
+                self.total = tensor
+            else:
+                self.total += tensor
+            self.updates += 1
+            if self.trainers == self.updates:
+                self.gradient = self.total / float(self.trainers)
+                for fut in self.futures:
+                    result = self.total / float(self.trainers)
+                    fut.set_result(result)
+        return fut
+
+
+class MyConvNetForMNIST(nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(1, 16, 3, 1),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, 3, 1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Flatten(1),
+            nn.Linear(4608, 128),
+            nn.ReLU(),
+            nn.Linear(128, 10),
+        ).to(device)
+        self.device = device
+
+    def forward(self, x, is_rref=False):
+        x = x.to_here() if is_rref else x
+        with torch.cuda.stream(torch.cuda.current_stream(self.device)):
+            # intentionally adding delay to current CUDA stream
+            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+            return self.net(x)
+
+    def __getstate__(self):
+        # return an empty dict to avoid inspecting the model contents on the
+        # owner
+        return {}
+
+
+class RpcTestCommon:
+    def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None):
+        if mode == RPCExecMode.SYNC:
+            return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs)
+        elif mode == RPCExecMode.ASYNC:
+            return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait()
+        elif mode == RPCExecMode.REMOTE:
+            return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here()
+
+    def _self_py_udf_remote(self, worker_info, x, y, z):
+        rref = rpc.remote(worker_info, my_function, args=(x, y, z))
+        self.assertEqual(rref.to_here(), x + y + z)
+
+    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
+        self_worker_info = rpc.get_worker_info()
+        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
+        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
+        self.assertEqual(ret, x + y + z + x + y)
+        self.assertEqual(fut.wait(), x + y + z + x)
+
+    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
+        self_worker_info = rpc.get_worker_info()
+        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
+        self.assertEqual(ret_rref.to_here(), x + y + z + x)
+
+    def _world_size_one(self, a, b):
+        if self.rank == 0:
+            rpc.init_rpc(
+                name="me",
+                backend=self.rpc_backend,
+                rank=0,
+                world_size=1,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+            def _rpc_sync(x, y):
+                expect = x * 2
+                result = rpc.rpc_sync("me", my_tensor_function, args=(x, y))
+                self.assertEqual(expect, result)
+
+            def _rpc_async(x, y):
+                expect = x * 2
+                result = rpc.rpc_async("me", my_tensor_function, args=(x, y)).wait()
+                self.assertEqual(expect, result)
+
+            def _remote(x, y):
+                expect = x * 2
+                result = rpc.remote("me", my_tensor_function, args=(x, y)).to_here()
+                self.assertEqual(expect, result)
+
+            _rpc_sync(a, b)
+            _rpc_async(a, b)
+            _remote(a, b)
+
+            rpc.shutdown()
+
+    def _multi_rpc(self, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+        for i in range(20):
+            n = i + self.rank + 1
+            if sparse:
+                x = build_sparse_tensor() * n
+                y = build_sparse_tensor() * n
+            else:
+                x = torch.ones(2, 2)
+                y = torch.ones(2, 2)
+            ret = rpc.rpc_sync(
+                worker_name(dst_rank),
+                torch.add,
+                args=(x, y),
+            )
+            self.assertEqual(ret, x * 2)
+
+    def _run_uneven_workload(self, f, x, num_repeat=30):
+        # worker0 drives and waits for worker1 and worker2
+        # throughout the test.
+        if self.rank == 0:
+            self.assertTrue(self.world_size >= 3)
+
+            # Phase 1: Only worker1 has workload.
+            dst = "worker1"
+            futs = []
+            for _ in range(num_repeat):
+                fut = rpc.rpc_async(dst, f, args=(x,))
+                futs.append(fut)
+
+            for fut in torch.futures.collect_all(futs).wait():
+                self.assertEqual(fut.wait(), 0)
+
+            # Phase 2: Only worker2 has workload.
+            # If join is not correctly implemented,
+            # worker2 should be closed by now.
+            dst = "worker2"
+            futs = []
+            for _ in range(num_repeat):
+                fut = rpc.rpc_async(dst, f, args=(x,))
+                futs.append(fut)
+
+            for val in torch.futures.wait_all(futs):
+                self.assertEqual(val, 0)
+
+    def _wait_all_workers(self, f, x):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        self._run_uneven_workload(f, x)
+
+        # worker0 calls this at the end after waiting for RPC responses.
+        # worker1/2 calls this immediately and has some works after it.
+        # worker3 calls this immediately and has no more work.
+        rpc.api._wait_all_workers()
+
+        # Wait before proceeding to shutdown to ensure worker0 RPCs make
+        # it through to other workers.
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+
+    def _wait_all_workers_twice(self, f, x):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        self._run_uneven_workload(f, x)
+
+        # worker0 calls this at the end after waiting for RPC responses.
+        # worker1/2 calls this immediately and has some works after it.
+        # worker3 calls this immediately and has no more work.
+        rpc.api._wait_all_workers()
+        rpc.api._wait_all_workers()
+
+        # Wait before proceeding to shutdown to ensure worker0 RPCs make
+        # it through to other workers.
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+
+    def _nested_rpc(self, f, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            f,
+            args=(worker_name(self.rank),),
+        )
+        self.assertEqual(ret, expected)
+
+    def _stress_test_rpc(self, f, repeat=1000, args=()):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        futs = []
+        tik = time.time()
+        for _ in range(repeat):
+            fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
+            futs.append(fut)
+
+        for val in torch.futures.wait_all(futs):
+            self.assertEqual(val, 0)
+        tok = time.time()
+        print(
+            f"Rank {self.rank} finished testing {repeat} times in {tok - tik} seconds."
+        )
+
+    def _builtin_remote_ret(self, x, y, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            torch.add,
+            args=(x, y),
+        )
+        self.assertEqual(rref.to_here(), expected)
+
+    def _builtin_remote_self(self, x, y, expected):
+        rref = rpc.remote(
+            worker_name(self.rank),
+            torch.add,
+            args=(x, y),
+        )
+        self.assertEqual(rref.local_value(), expected)
+
+    def _test_multi_remote_call(
+        self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}
+    ):
+        m = 10
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rrefs = []
+        expected = []
+        for i in range(m):
+            n = n + i
+            rrefs.append(
+                rpc.remote(
+                    worker_name(dst_rank),
+                    fn,
+                    args=args_fn(n, sparse),
+                    kwargs=kwargs_fn(n, sparse),
+                )
+            )
+            expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))
+
+        for i in range(m):
+            self.assertEqual(rrefs[i].to_here(), expected[i])
+
+    def _py_rref_args(self, a, b, x, y, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(worker_name(dst_rank), torch.add, args=(a, b))
+        rref_b = rpc.remote(worker_name(dst_rank), torch.add, args=(x, y))
+        rref_c = rpc.remote(
+            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), expected)
+
+    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
+        n = self.rank + 1
+        owner_rank = n % self.world_size
+        user_rank = (n + 1) % self.world_size
+        rref_a = rpc.remote(worker_name(owner_rank), my_function, args=(a, b, c))
+        rref_b = rpc.remote(worker_name(owner_rank), my_function, args=(x, y, z))
+        rref_c = rpc.remote(
+            worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), expected)
+
+    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(worker_name(dst_rank), my_function, args=(a, b, c))
+        rref_b = rpc.remote(worker_name(dst_rank), my_function, args=(x, y, z))
+
+        c = rpc.rpc_sync(worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b))
+        self.assertEqual(c, expected)
+
+    def _nested_remote(self, f, expected):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+
+        rref = rpc.remote(
+            worker_name(dst_rank1),
+            f,
+            args=(worker_name(dst_rank2),),
+        )
+        self.assertEqual(rref.to_here(), expected)
+
+    def _nested_rref(self, f, expected1, expected2):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        rref_of_rrefs = rpc.remote(
+            worker_name(dst_rank1),
+            f,
+            args=(worker_name(dst_rank2),),
+        )
+
+        # Say C has 2 OwnerRRefs.
+        # B has 2 UserRRefs to those 2 OwnerRRefs, respectively.
+        # This call is effectively A asking B to share its 2 UserRRefs.
+        rrefs = rref_of_rrefs.to_here()
+
+        self.assertEqual(len(rrefs), 2)
+        self.assertEqual(rrefs[0].to_here(), expected1)
+        self.assertEqual(rrefs[1].to_here(), expected2)
+
+    def _nested_rref_stress(self, f, expected1, expected2):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        all_rrefs = [
+            rpc.remote(
+                worker_name(dst_rank1),
+                f,
+                args=(worker_name(dst_rank2),),
+            )
+            for _ in range(20)
+        ]
+
+        for i in range(20):
+            rref_of_rrefs = all_rrefs[i]
+            rrefs = rref_of_rrefs.to_here()
+            self.assertEqual(len(rrefs), 2)
+            self.assertEqual(rrefs[0].to_here(), expected1)
+            self.assertEqual(rrefs[1].to_here(), expected2)
+
+    def _trainer_func(self, rref, sparse):
+        m = MyEmbeddingBagModel(sparse=sparse)
+        loss_fn = nn.MSELoss()
+        for i in range(10):
+            outputs = m(torch.rand(10, 10).long())
+            loss_fn(outputs, torch.rand(10, 10)).backward()
+            gradient = next(iter(m.parameters())).grad
+            fut = rref.rpc_async().average(rref, i, gradient)
+            gradient = fut.wait()
+            if gradient.is_sparse:
+                gradient = gradient.to_dense().double()
+            ps_gradient = rref.rpc_sync().get_gradient(rref)
+            if ps_gradient.is_sparse:
+                ps_gradient = ps_gradient.to_dense().double()
+            self.assertTrue(torch.equal(gradient, ps_gradient))
+
+    def _my_parameter_server(self, sparse):
+        ps_rref = RRef(MyParameterServer(self.world_size - 1))
+        futures = [
+            rpc.rpc_async(
+                worker_name((self.rank + index) % self.world_size),
+                self._trainer_func,
+                args=(ps_rref, sparse),
+            )
+            for index in range(1, self.world_size)
+        ]
+        torch.futures.wait_all(futures)
+
+    def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor):
+        # We check proper CUDA stream synchronization by adding to the tensor
+        # in one stream to get the expected value, and reading it from another stream.
+        future = Future(devices=["cuda:0"])
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                if sparse_tensor:
+                    tensor = build_sparse_tensor().to("cuda:0")
+                    add_tensor = build_sparse_tensor().to("cuda:0")
+                    expected_tensor = (tensor + add_tensor).coalesce()
+                else:
+                    tensor = torch.zeros((100,), device="cuda:0")
+                    add_tensor = torch.ones((100,), device="cuda:0")
+                    expected_tensor = tensor + add_tensor
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor += add_tensor
+                if sparse_tensor:
+                    tensor = tensor.coalesce()
+                future.set_result(wrapper(tensor))
+            with torch.cuda.stream(another_stream):
+                tensor = unwrapper(future.wait())
+                if sparse_tensor:
+                    self.assertTrue(
+                        torch.eq(tensor.indices(), expected_tensor.indices())
+                        .all()
+                        .item()
+                    )
+                    self.assertTrue(
+                        torch.eq(tensor.values(), expected_tensor.values()).all().item()
+                    )
+                    self.assertEqual(tensor.size(), expected_tensor.size())
+                else:
+                    self.assertTrue(torch.eq(tensor, expected_tensor).all().item())
+
+
+class RpcTest(RpcAgentTestFixture, RpcTestCommon):
+    @dist_init
+    def test_worker_id(self):
+        n = self.rank + 1
+        peer_rank = n % self.world_size
+        self_worker_info = rpc.get_worker_info()
+        peer_worker_info = rpc.get_worker_info(worker_name(peer_rank))
+
+        self.assertEqual(self_worker_info.name, worker_name(self.rank))
+        self.assertEqual(peer_worker_info.name, worker_name(peer_rank))
+
+        with self.assertRaisesRegex(RuntimeError, "could not find destination"):
+            rpc.get_worker_info("WorkerUnknown")
+
+    @dist_init
+    def test_get_worker_infos(self):
+        worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos()
+
+        worker_names = {worker_info.name for worker_info in worker_infos}
+        expected_worker_names = {worker_name(rank) for rank in range(self.world_size)}
+        self.assertEqual(worker_names, expected_worker_names)
+
+        worker_ids = {worker_info.id for worker_info in worker_infos}
+        expected_worker_ids = set(range(self.world_size))
+        self.assertEqual(worker_ids, expected_worker_ids)
+
+    @dist_init
+    def test_self_add(self):
+        self_worker_info = rpc.get_worker_info()
+        fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
+        ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
+        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_send_to_rank(self):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        # Test dense tensor
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            ret = self._run_func_in_mode(
+                dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+            )
+            self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test invalid ranks
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(RuntimeError):
+                self._run_func_in_mode(
+                    self.world_size + 1,
+                    torch.add,
+                    exec_mode,
+                    args=(torch.ones(2, 2), 1),
+                )
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(RuntimeError):
+                self._run_func_in_mode(
+                    -1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                )
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(ValueError):
+                self._run_func_in_mode(
+                    dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                )
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(ValueError):
+                self._run_func_in_mode(
+                    dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                )
+
+    @dist_init
+    def test_self_py_udf_remote(self):
+        self._self_py_udf_remote(rpc.get_worker_info(), torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_rpc_arg(dst, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg(self):
+        self._self_remote_rref_as_rpc_arg(rpc.get_worker_info(), torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_remote_arg(dst, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg(self):
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(), torch.ones(2, 2), 1, 3
+        )
+
+    @dist_init
+    def test_rref_proxy_non_exist(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+        msg = "has no attribute 'non_exist'"
+        with self.assertRaisesRegex(AttributeError, msg):
+            rref.rpc_sync().non_exist()
+
+        with self.assertRaisesRegex(AttributeError, msg):
+            rref.rpc_async().non_exist().wait()
+
+        with self.assertRaisesRegex(AttributeError, msg):
+            rref.remote().non_exist()
+
+    def _test_rref_proxy_tensor(self, dst):
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+
+        expected = torch.ones(2, 2) + 1 + 3
+        self.assertEqual(expected.size(), rref.rpc_sync().size())
+        self.assertEqual(expected + 1, rref.rpc_async().add(1).wait())
+        self.assertEqual(expected.view(1, 4), rref.remote().view(1, 4).to_here())
+
+    @dist_init
+    def test_rref_proxy_tensor(self):
+        self._test_rref_proxy_tensor(worker_name((self.rank + 1) % self.world_size))
+
+    @dist_init
+    def test_rref_proxy_tensor_self(self):
+        self._test_rref_proxy_tensor(rpc.get_worker_info())
+
+    @dist_init
+    def test_rref_proxy_reuse(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            my_function,
+            args=(torch.ones(2, 2), 1, 3),
+        )
+        expected = torch.ones(2, 2) + 1 + 3
+
+        proxy_rpc_sync = rref.rpc_sync()
+        proxy_rpc_async = rref.rpc_async()
+        proxy_remote = rref.remote()
+
+        self.assertEqual(expected.size(), proxy_rpc_sync.size())
+        self.assertEqual(expected + 1, proxy_rpc_sync.add(1))
+        self.assertEqual(expected.view(1, 4), proxy_rpc_sync.view(1, 4))
+
+        self.assertEqual(expected.size(), proxy_rpc_async.size().wait())
+        self.assertEqual(expected + 3, proxy_rpc_async.add(3).wait())
+        self.assertEqual(expected.view(4, 1), proxy_rpc_async.view(4, 1).wait())
+
+        self.assertEqual(expected.size(), proxy_remote.size().to_here())
+        self.assertEqual(expected + 5, proxy_remote.add(5).to_here())
+        self.assertEqual(expected.view(-1), proxy_remote.view(-1).to_here())
+
+    def _test_rref_proxy_class(self, dst):
+        rref = rpc.remote(dst, MyClass, args=(7,))
+        expected = MyClass(7)
+        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
+        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
+        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())
+
+        expected.increment_value(3)
+        self.assertEqual(None, rref.rpc_sync().increment_value(1))
+        self.assertEqual(None, rref.rpc_async().increment_value(1).wait())
+        self.assertEqual(None, rref.remote().increment_value(1).to_here())
+
+        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
+        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
+        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())
+
+        self.assertEqual(
+            expected.my_instance_method(2), rref.rpc_sync().my_instance_method(2)
+        )
+        self.assertEqual(
+            expected.my_instance_method(3),
+            rref.rpc_async().my_instance_method(3).wait(),
+        )
+        self.assertEqual(
+            expected.my_instance_method(4),
+            rref.remote().my_instance_method(4).to_here(),
+        )
+
+        self.assertEqual(
+            expected.my_static_method(9), rref.rpc_sync().my_static_method(9)
+        )
+        self.assertEqual(
+            expected.my_static_method(10), rref.rpc_async().my_static_method(10).wait()
+        )
+        self.assertEqual(
+            expected.my_static_method(11), rref.remote().my_static_method(11).to_here()
+        )
+
+        self.assertEqual(
+            expected.my_class_method(2, torch.zeros(2, 2)),
+            rref.rpc_sync().my_class_method(2, torch.zeros(2, 2)),
+        )
+        self.assertEqual(
+            expected.my_class_method(2, torch.ones(3, 3)),
+            rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait(),
+        )
+        self.assertEqual(
+            expected.my_class_method(2, torch.ones(4, 4)),
+            rref.remote().my_class_method(2, torch.ones(4, 4)).to_here(),
+        )
+
+    @dist_init
+    def test_rref_proxy_class(self):
+        self._test_rref_proxy_class(worker_name((self.rank + 1) % self.world_size))
+
+    @dist_init
+    def test_rref_proxy_class_self(self):
+        self._test_rref_proxy_class(rpc.get_worker_info())
+
+    @mock.patch.object(torch.distributed.autograd, "_init")
+    @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
+    @dist_init(setup_rpc=False)
+    def test_register_rpc_backend_and_set_and_start_rpc_backend(
+        self, mock_rpc_agent, mock_dist_autograd_init
+    ):
+        backend_name = "stub_backend"
+
+        backend = rpc.backend_registry.register_backend(
+            backend_name,
+            _stub_construct_rpc_backend_options_handler,
+            _stub_init_rpc_backend_handler,
+        )
+
+        with self.assertRaisesRegex(
+            RuntimeError, "^RPC backend .+: already registered$"
+        ):
+            backend = rpc.backend_registry.register_backend(
+                backend_name,
+                _stub_construct_rpc_backend_options_handler,
+                _stub_init_rpc_backend_handler,
+            )
+
+        rpc.init_rpc(
+            name="worker1",
+            backend=backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+    @dist_init(setup_rpc=False)
+    def test_duplicate_name(self):
+        with self.assertRaisesRegex(RuntimeError, "is not unique"):
+            store, _, _ = next(
+                torch.distributed.rendezvous(
+                    self.init_method, rank=self.rank, world_size=self.world_size
+                )
+            )
+            rpc._init_rpc_backend(
+                backend=self.rpc_backend,
+                store=store,
+                name="duplicate_name",
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_duplicate_name_2(self):
+        with self.assertRaisesRegex(RuntimeError, "is not unique"):
+            rpc.init_rpc(
+                name=worker_name(self.rank % (self.world_size - 1)),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_reinit(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # TODO: with TCP init, rank 0 raises Address already in use because
+        # rank 0 is the start daemon and the store is created before checking if
+        # RPC is already initialized in init_rpc.
+        if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0:
+            expected_reinit_err = "Address already in use"
+        else:
+            expected_reinit_err = "is already initialized"
+
+        with self.assertRaisesRegex(RuntimeError, expected_reinit_err):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    def test_pg_init_no_rpc_init(self):
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.file_init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        class MyModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.lin = torch.nn.Linear(3, 4)
+
+            def forward(self, x):
+                return self.lin(x)
+
+        model = MyModel()
+        model.train()
+        model = torch.nn.parallel.DistributedDataParallel(model)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Current RPC agent is not set! Did you initialize the RPC framework",
+        ):
+            [RRef(param) for param in model.parameters()]
+
+    def test_world_size_one(self):
+        self._world_size_one(torch.ones(2, 2), torch.ones(2, 2))
+
+    @dist_init(setup_rpc=False)
+    def test_invalid_names(self):
+        worker_id = 0
+        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
+            WorkerInfo("abc*", worker_id)
+
+        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
+            WorkerInfo(" ", worker_id)
+
+        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
+            WorkerInfo("", worker_id)
+
+        # If the number in the message does not match, it is likely that the
+        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
+        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
+            WorkerInfo("".join(["a" for i in range(500)]), worker_id)
+
+    # Test that WorkerInfo can be pickled and sent in RPC call
+    @dist_init
+    def test_worker_info_pickle(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        worker_info = rpc.api.get_worker_info()
+        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
+        self.assertEqual(ret, worker_info)
+
+    @dist_init
+    def test_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+
+    @staticmethod
+    def return_callee_id():
+        return rpc.get_worker_info().id
+
+    @dist_init
+    def test_int_callee(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
+        self.assertEqual(ret, dst_rank)
+
+    @dist_init
+    def test_add_with_id(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        workder_info = rpc.get_worker_info(worker_name(dst_rank))
+
+        ret = rpc.rpc_sync(
+            workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_scalar_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n))
+        self.assertEqual(ret, (torch.ones(n, n) + n))
+
+    @dist_init
+    def test_async_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        fut = rpc.rpc_async(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_nonzero(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        x = torch.ones(self.world_size, self.world_size)
+        x[self.rank][self.rank] = 0
+        ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
+        self.assertEqual(ret, x.nonzero())
+
+    @dist_init
+    def test_multi_rpc(self):
+        self._multi_rpc(False)
+
+    @dist_init
+    def test_future_wait_twice(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [rpc.rpc_async(dst, raise_func) for _ in range(20)]
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+        for fut in futs:
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut.wait()
+
+    @dist_init(setup_rpc=False)
+    def test_wait_all_workers_timeout(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        og_func = rpc.api._wait_all_workers
+
+        def wait_all_workers_sleep(timeout):
+            rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
+
+        rpc.api._wait_all_workers = wait_all_workers_sleep
+
+        try:
+            with self.assertRaisesRegex(RuntimeError, ""):
+                rpc.shutdown(graceful=True, timeout=0.01)
+        finally:
+            rpc.api._wait_all_workers = og_func
+        dist.barrier()
+
+    def test_wait_all_workers_dense(self):
+        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))
+
+    def test_wait_all_workers_twice_dense(self):
+        self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100))
+
+    @dist_init
+    def test_all_gather(self):
+        info = rpc.get_worker_info()
+        results = rpc.api._all_gather(info.id)
+        expected = {}
+        for info in rpc._get_current_rpc_agent().get_worker_infos():
+            expected[info.name] = info.id
+
+        self.assertEqual(expected, results)
+
+    @dist_init
+    def test_all_gather_timeout(self):
+        rpc._set_rpc_timeout(0.1)
+
+        if self.rank == 0:
+            with self.assertRaisesRegex(
+                RuntimeError, "timed out in _all_gather after 0\\.10 seconds"
+            ):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+        else:
+            expected_error = self.get_timeout_error_regex()
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+
+    def _test_barrier_helper(self, info, names, multi_threaded=False):
+        names = sorted(names)
+        leader = names[0]
+        rpc.rpc_sync(leader, _reset_count)
+        if not multi_threaded and info.name == leader:
+            self.assertEqual(_rpc_barrier_count, 0)
+        rpc.api._barrier(names)
+        rpc.rpc_sync(leader, _increment_count)
+        rpc.api._barrier(names)
+        if not multi_threaded and info.name == leader:
+            self.assertEqual(_rpc_barrier_count, len(names))
+
+    @dist_init
+    def test_rpc_barrier_all(self):
+        # Test rpc barrier when called with full list of workers
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        names = [worker.name for worker in all_worker_info]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_subset(self):
+        # Test rpc barrier when processes are called with different subsets of the full list
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        if info.id % 2:
+            names = [worker.name for worker in all_worker_info if worker.id % 2]
+        else:
+            names = [worker.name for worker in all_worker_info if not worker.id % 2]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_partial_subset(self):
+        # Test rpc barrier when some processes are not involved in the barrier
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        if info.id % 2:
+            names = [worker.name for worker in all_worker_info if worker.id % 2]
+        else:
+            names = [f"worker{info.id}"]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_multithreaded(self):
+        # This tests validates the implementation of barrier when multiple threads call into it
+        # We only need to check that it does not hang in this case
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        names = [worker.name for worker in all_worker_info]
+        threads = []
+        for _ in range(3):
+            th = threading.Thread(
+                target=self._test_barrier_helper, args=(info, names, True)
+            )
+            threads.append(th)
+            th.start()
+        for th in threads:
+            th.join()
+
+    @dist_init
+    def test_graceful_shutdown_with_uneven_workload(self):
+        """Test graceful termination."""
+        self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))
+
+    @dist_init(setup_rpc=False)
+    def test_shutdown_followed_by_rpc(self):
+        # Initialize RPC.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+        rpc.shutdown()
+
+        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
+            rpc.rpc_sync(
+                worker_name(dst_rank),
+                torch.add,
+                args=(torch.ones(n, n), torch.ones(n, n)),
+            )
+
+    @dist_init
+    def test_expected_src(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        expected_src_rank = (self.rank - 1) % self.world_size
+        rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,))
+        value = VALUE_FUTURE.result()
+        self.assertEqual(value, expected_src_rank)
+
+    @dist_init
+    def test_py_built_in(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2))
+        self.assertEqual(ret, min(n, n + 1, n + 2))
+
+    @dist_init
+    def test_py_user_defined(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            my_function,
+            kwargs={"a": n, "b": n + 1, "c": n + 2},
+        )
+        self.assertEqual(ret, my_function(n, n + 1, n + 2))
+
+    def test_build_rpc_profiling_key(self):
+        # Tests that the name that shows up as an Event in profiling RPCs has all
+        # the necessary information.
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            rpc_profiling_key = _build_rpc_profiling_key(
+                exec_mode, "foo", "worker0", "worker1"
+            )
+            self.assertIn(exec_mode.value, rpc_profiling_key)
+            self.assertIn("foo", rpc_profiling_key)
+            self.assertIn("worker0", rpc_profiling_key)
+            self.assertIn("worker1", rpc_profiling_key)
+
+    def check_profiling_info(
+        self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode
+    ):
+        self.assertTrue(self_worker_name in rpc_event.name)
+        self.assertTrue(dst_worker_name in rpc_event.name)
+        if isinstance(func, torch.jit.ScriptFunction):
+            self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name)
+        else:
+            self.assertTrue(func.__name__ in rpc_event.name)
+        self.assertTrue(rpc_exec_mode.value in rpc_event.name)
+        self.assertEqual(rpc_event.count, 1)
+
+    @dist_init
+    def test_profiler_rpc_record_shapes(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        t1, t2 = torch.ones(100), torch.ones(100)
+        with _profile(record_shapes=True) as prof:
+            rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2))
+
+        function_events = prof.function_events
+        remote_events = [event for event in function_events if event.is_remote]
+        remote_add_event = next(
+            event for event in remote_events if "aten::add" in event.name
+        )
+        remote_add_input_shapes = remote_add_event.input_shapes
+        # Run profiler on equivalent local op and validate shapes are the same.
+        with _profile(record_shapes=True) as prof:
+            torch.add(t1, t2)
+
+        local_function_events = prof.function_events
+        local_add_event = next(
+            event for event in local_function_events if "aten::add" in event.name
+        )
+        local_add_input_shapes = local_add_event.input_shapes
+        self.assertEqual(remote_add_input_shapes, local_add_input_shapes)
+
+    @dist_init
+    def test_profiler_rpc_memory(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        with _profile(profile_memory=True) as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        function_events = p.function_events
+        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
+        # if cpu_memory_usage was not propagated over the wire, this set would
+        # only contain 0 (indicates no memory being profiled)
+        self.assertNotEqual({0}, event_cpu_mem_usages)
+        # No memory profiled if profile_memory=False
+        with _profile(profile_memory=False) as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        function_events = p.function_events
+        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
+        self.assertEqual({0}, event_cpu_mem_usages)
+
+    @dist_init
+    def test_profiler_export_trace(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        with _profile() as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        with TemporaryFileName() as fname:
+            path = fname
+            p.export_chrome_trace(path)
+            with open(path) as f:
+                trace = json.load(f)
+                event_names = [event["name"] for event in trace]
+                for expected_event_name in EXPECTED_REMOTE_EVENTS + [
+                    RPCExecMode.ASYNC.value
+                ]:
+                    event_exists = any(
+                        expected_event_name in event_name for event_name in event_names
+                    )
+                    self.assertTrue(event_exists)
+
+    @dist_init
+    def test_profiler_rpc_key_names(self):
+        # tests that remote events are properly prefixed with the RPC profiling key.
+        if self.rank != 1:
+            return
+
+        # Spawn multiple threads that send RPCs to ensure keys are correctly
+        # prefixed when there are multiple RPCs being created/in flight at the
+        # same time.
+        dst_ranks = [rank for rank in range(self.world_size) if rank != self.rank]
+
+        def rpc_with_profiling(dst_worker):
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+                fut.wait()
+
+            events = prof.function_events
+            remote_event_names = {
+                event.name: event for event in events if event.is_remote
+            }
+            rpc_profiling_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                udf_with_torch_ops.__qualname__,
+                worker_name(self.rank),
+                dst_worker,
+            )
+
+            remote_event_name_set = set(EXPECTED_REMOTE_EVENTS)
+            for name, event in remote_event_names.items():
+                # Ensure that we have the expected key as part of the remote
+                # event.
+                self.assertTrue(name.startswith(rpc_profiling_key))
+                self.assertTrue(event.is_remote)
+                self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id)
+                # Ensure that the remote event name also contains the operator.
+                operator_name_substr = name[len(rpc_profiling_key) :]
+                # Note: we don't assert that every remote event needs to be
+                # in the above set, the set is just a representative set of
+                # what we expect to see. The profiler can change and add more
+                # events, but we should always expect to see this representative
+                # set.
+                matching_event = {
+                    remote_event_name
+                    for remote_event_name in remote_event_name_set
+                    if remote_event_name in operator_name_substr
+                }
+                remote_event_name_set -= matching_event
+
+            # The set should be empty, otherwise its contained elements did
+            # not show up in the remote profiler output.
+            self.assertTrue(
+                remote_event_name_set == set(),
+                f"Expected {remote_event_name_set} to be included in remote profiler output.",
+            )
+
+        for dst in dst_ranks:
+            dst_worker = worker_name(dst)
+            num_parallel_rpcs = 2
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=num_parallel_rpcs
+            ) as executor:
+                futs = [
+                    executor.submit(rpc_with_profiling, dst_worker)
+                    for _ in range(num_parallel_rpcs)
+                ]
+                # Wait for workers to finish test
+                for fut in futs:
+                    fut.result()
+
+    def _run_test_profiler_remote_events_profiled(self):
+        # Tests that we can successfully invoke the profiler on a remote node,
+        # and collect the remote events back in the local profiler.
+        if self.rank != 1:
+            return
+
+        dst_ranks = [rank for rank in range(self.world_size) if rank != self.rank]
+        for dst in dst_ranks:
+            dst_worker = worker_name(dst)
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+                fut.wait()
+
+            events = prof.function_events
+
+            rpc_event = get_function_event(events, RPCExecMode.ASYNC.value)
+            self.check_profiling_info(
+                worker_name(self.rank),
+                dst_worker,
+                udf_with_torch_ops,
+                rpc_event,
+                RPCExecMode.ASYNC,
+            )
+
+            remote_events = {event.name: event for event in events if event.is_remote}
+            rpc_profiling_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                udf_with_torch_ops.__qualname__,
+                worker_name(self.rank),
+                worker_name(dst),
+            )
+
+            for expected_remote_event_name in EXPECTED_REMOTE_EVENTS:
+                expected_key = (
+                    rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name
+                )
+                self.assertTrue(expected_key in remote_events)
+                remote_event = remote_events[expected_key]
+                # Remote event should have a node ID corresponding to the worker
+                # it ran on.
+                self.assertEqual(remote_event.node_id, dst)
+
+            # Validate order remote events show up in profiling output.
+            def convert_remote_to_local(event_name):
+                remote_op_key = rpc_profiling_key + REMOTE_OP_STR
+                return event_name[event_name.find(remote_op_key) + len(remote_op_key) :]
+
+            remote_events_list = [
+                convert_remote_to_local(event.name)
+                for event in events
+                if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS
+            ]
+            self.assertEqual(
+                set(remote_events_list),
+                set(EXPECTED_REMOTE_EVENTS),
+                f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}",
+            )
+
+    @dist_init
+    def test_profiler_remote_events_profiled(self):
+        self._run_test_profiler_remote_events_profiled()
+
+    @dist_init
+    def test_profiler_remote_events_profiled_single_threaded(self):
+        self._run_test_profiler_remote_events_profiled()
+
+    def run_profiling_workload(self, dst):
+        fut = rpc.rpc_async(
+            worker_name(dst),
+            torch.mul,
+            args=(
+                torch.tensor(1.0, requires_grad=True),
+                torch.tensor(1.0, requires_grad=True),
+            ),
+        )
+        fut.wait()
+
+    def _run_rpc_profiling_async_function(self, device="cpu"):
+        if self.rank != 1:
+            return
+
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+        x = torch.ones(2)
+        y = torch.ones(2)
+        with _profile() as prof:
+            ret = rpc.rpc_async(
+                dst1, slow_async_add, args=(dst2, x, y, device), timeout=20
+            )
+            ret.wait()
+
+        function_events = prof.function_events
+        # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be
+        # recorded.
+        key_prefix = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1
+        )
+
+        nested_rpc_key_prefix = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2
+        )
+        expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix
+        remote_events = [event for event in function_events if event.is_remote]
+        rpc_remote_event = [
+            event for event in remote_events if event.name == expected_key
+        ]
+        self.assertEqual(1, len(rpc_remote_event))
+        rpc_remote_event = rpc_remote_event[0]
+        self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size)
+        # slow_async_add's RPC does an add on dst2, which should be reflected as well.
+        remote_add_key = (
+            expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add)
+        )
+        remote_add_event = [
+            event for event in remote_events if event.name == remote_add_key
+        ]
+        self.assertEqual(1, len(remote_add_event))
+        remote_add_event = remote_add_event[0]
+        # Validate that node_id is dst2.
+        self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size)
+
+    @dist_init
+    def test_rpc_profiling_async_function(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        self._run_rpc_profiling_async_function()
+        if torch.cuda.is_available():
+            dist.barrier()
+            self._run_rpc_profiling_async_function(device="cuda:0")
+
+    @dist_init
+    def test_rpc_profiling_async_function_single_threaded(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        self._run_rpc_profiling_async_function()
+        if torch.cuda.is_available():
+            dist.barrier()
+            self._run_rpc_profiling_async_function(device="cuda:0")
+
+    @dist_init
+    def test_rpc_profiling_remote_record_function(self):
+        # test that functions run over RPC with record_function show the expected
+        # profiled block.
+        if self.rank != 1:
+            return
+        dst_ranks = [i for i in range(self.world_size) if i != self.rank]
+        for dst_rank in dst_ranks:
+            dst_worker = worker_name(dst_rank)
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True))
+                fut.wait()
+
+            function_events = prof.function_events
+            record_function_remote_event = [
+                evt for evt in function_events if "##forward##" in evt.name
+            ]
+            self.assertEqual(1, len(record_function_remote_event))
+            record_function_remote_event = record_function_remote_event[0]
+            self.assertEqual(record_function_remote_event.node_id, dst_rank)
+            # cpu_children only returns direct children, so here we get all
+            # children recursively.
+
+            def get_cpu_children(event):
+                if not event.cpu_children:
+                    return []
+                cpu_children = event.cpu_children
+                for e in event.cpu_children:
+                    cpu_children.extend(get_cpu_children(e))
+                return cpu_children
+
+            remote_children = get_cpu_children(record_function_remote_event)
+            # Get local children and verify parity.
+            with _profile() as prof:
+                udf_with_torch_ops(-1, True)
+
+            local_function_events = prof.function_events
+            local_record_function_event = next(
+                evt for evt in local_function_events if "##forward##" in evt.name
+            )
+            local_children = get_cpu_children(local_record_function_event)
+            local_children_names = [evt.name for evt in local_children]
+
+            REMOTE_OP_STR = "#remote_op: "
+
+            def convert_remote_to_local(event_name):
+                remote_op_key = REMOTE_OP_STR
+                return event_name[event_name.find(remote_op_key) + len(remote_op_key) :]
+
+            for evt in remote_children:
+                local_name = convert_remote_to_local(evt.name)
+                self.assertTrue(local_name in local_children_names)
+
+    def validate_profiling_workload(self, dst, prof):
+        def convert_remote_to_local(event_name):
+            return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]
+
+        events = prof.function_events
+        remote_events = {
+            convert_remote_to_local(event.name): event
+            for event in events
+            if event.is_remote
+        }
+        self.assertTrue("aten::mul" in remote_events)
+        remote_mul_event = remote_events["aten::mul"]
+        self.assertEqual(remote_mul_event.node_id, dst)
+        self.check_profiling_info(
+            worker_name(self.rank),
+            worker_name(dst),
+            torch.mul,
+            remote_mul_event,
+            RPCExecMode.ASYNC,
+        )
+
+    def _run_test_profiler_with_autograd_context(self):
+        dst = (self.rank + 1) % self.world_size
+        if self.rank == 1:
+            # Cases where we can double wrap messages with profiling information and autograd info.
+            with dist_autograd.context(), _profile() as prof:
+                self.run_profiling_workload(dst)
+
+            self.validate_profiling_workload(dst, prof)
+
+            # Ensure that flipped order of ctx managers results in events being
+            # recorded as expected.
+            with _profile() as prof, dist_autograd.context():
+                self.run_profiling_workload(dst)
+
+            self.validate_profiling_workload(dst, prof)
+
+    @dist_init
+    def test_profiler_with_autograd_context_single_threaded(self):
+        self._run_test_profiler_with_autograd_context()
+
+    @dist_init
+    def test_profiler_with_autograd_context(self):
+        self._run_test_profiler_with_autograd_context()
+
+    def _profiler_test_with_rpc(
+        self,
+        rpc_exec_mode,
+        func,
+        args,
+        use_record_function=False,
+        dst=None,
+        kineto_profile=False,
+    ):
+        dst = dst if dst is not None else (self.rank + 1) % self.world_size
+
+        # only run profiler on rank 1.
+        p = _profile if not kineto_profile else torch.profiler.profile  # kineto
+        if self.rank == 1:
+            with p() as prof:
+                record_function_ctx_mgr = (
+                    contextlib.nullcontext()
+                    if not use_record_function
+                    else torch.autograd.profiler.record_function("foo")
+                )
+                with record_function_ctx_mgr:
+                    if rpc_exec_mode == RPCExecMode.SYNC:
+                        rpc.rpc_sync(worker_name(dst), func, args=args)
+                    elif rpc_exec_mode == RPCExecMode.ASYNC:
+                        fut = rpc.rpc_async(worker_name(dst), func, args=args)
+                        if kineto_profile:
+                            # Ensure multiple async RPCs don't cause issues.
+                            # Would have raised
+                            # "RuntimeError: Cannot call
+                            # RemoteProfilerManager::setCurrentKey when current
+                            # key is already set." error if RPC profiling was
+                            # not disabled properly for kineto.
+                            fut2 = rpc.rpc_async(worker_name(dst), func, args=args)
+                            fut2.wait()
+                        fut.wait()
+                    else:
+                        self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
+                        rref = rpc.remote(worker_name(dst), func, args=args)
+                        rref.to_here()
+                        # To avoid flakiness, wait for the RRef to be profiled. This
+                        # means that we received the acknowledgement of successful
+                        # creation on the owner and ran the callbacks responsible
+                        # for recording the profiling event.
+                        rref._get_profiling_future().wait()
+
+            events = prof.function_events if not kineto_profile else prof.events()
+            if kineto_profile:
+                # RPC profiling is disabled so there should be no rpc related
+                # events.
+                with self.assertRaises(IndexError):
+                    get_function_event(events, rpc_exec_mode.value)
+
+                return
+
+            rpc_event = get_function_event(events, rpc_exec_mode.value)
+            # verify Node ID for this rpc event.
+            self.assertEqual(rpc_event.node_id, self.rank)
+            # Ensure recording of remote events.
+            remote_events = {event for event in events if event.node_id == dst} - {
+                rpc_event
+            }
+            self.assertGreaterEqual(len(remote_events), 1)
+            for remote_event in remote_events:
+                self.assertEqual(remote_event.node_id, dst)
+
+            if use_record_function:
+                scope_event = get_function_event(events, "foo")
+                # Since RPC call is within the scope, its CPU interval should be
+                # contained within foo's interval.
+                self.assertLessEqual(
+                    scope_event.time_range.start, rpc_event.time_range.start
+                )
+                self.assertGreaterEqual(
+                    scope_event.time_range.end, rpc_event.time_range.end
+                )
+            # the sender, dest worker, function run, and type of RPC should all
+            # be recorded.
+            self_worker_name = worker_name(self.rank)
+            dst_worker_name = worker_name(dst)
+            self.check_profiling_info(
+                self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode
+            )
+            if use_record_function:
+                # verify order by ensuring that the outer context comes
+                # before the rpc event.
+                foo_event_ix = next(
+                    i for i, event in enumerate(events) if "foo" in event.name
+                )
+                rpc_event_idx = next(
+                    i
+                    for i, event in enumerate(events)
+                    if rpc_exec_mode.value in event.name
+                )
+                self.assertLess(foo_event_ix, rpc_event_idx)
+
+    def _run_test_profiler_with_sync_rpc_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, my_sleep_func, args=(1,), use_record_function=True
+        )
+
+    @dist_init
+    def test_profiler_with_sync_rpc_udf(self):
+        self._run_test_profiler_with_sync_rpc_udf()
+
+    @dist_init
+    def test_profiler_with_sync_rpc_udf_single_threaded(self):
+        self._run_test_profiler_with_sync_rpc_udf()
+
+    def _run_test_profiler_with_sync_rpc_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_sync_rpc_builtin(self):
+        self._run_test_profiler_with_sync_rpc_builtin()
+
+    @dist_init
+    def test_profiler_with_sync_rpc_builtin_single_threaded(self):
+        self._run_test_profiler_with_sync_rpc_builtin()
+
+    def _run_test_profiler_with_async_rpc_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_sleep_func, args=(1,), use_record_function=True
+        )
+        # Test to ensure that kineto profiler enabled in RPC does not enable
+        # RPC profiling (it is unsupported) and does not result in issues.
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True
+        )
+
+    @dist_init
+    def test_profiler_with_async_rpc_udf(self):
+        self._run_test_profiler_with_async_rpc_udf()
+
+    @dist_init
+    def test_profiler_with_async_rpc_udf_single_threaded(self):
+        self._run_test_profiler_with_async_rpc_udf()
+
+    def _run_test_profiler_with_async_rpc_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_async_rpc_builtin(self):
+        self._run_test_profiler_with_async_rpc_builtin()
+
+    @dist_init
+    def test_profiler_with_async_rpc_builtin_single_threaded(self):
+        self._run_test_profiler_with_async_rpc_builtin()
+
+    def _run_test_profiler_with_remote_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank
+        )
+
+    @dist_init
+    def test_profiler_with_remote_udf(self):
+        self._run_test_profiler_with_remote_udf()
+
+    @dist_init
+    def test_profiler_with_remote_udf_single_threaded(self):
+        self._run_test_profiler_with_remote_udf()
+
+    def _run_test_profiler_with_remote_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            dst=self.rank,
+        )
+
+    @dist_init
+    def test_profiler_with_remote_builtin(self):
+        self._run_test_profiler_with_remote_builtin()
+
+    @dist_init
+    def test_profiler_with_remote_builtin_single_threaded(self):
+        self._run_test_profiler_with_remote_builtin()
+
+    def _run_test_profiler_with_script_async_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_script_async_rpc(self):
+        self._run_test_profiler_with_script_async_rpc()
+
+    @dist_init
+    def test_profiler_with_script_async_rpc_single_threaded(self):
+        self._run_test_profiler_with_script_async_rpc()
+
+    def _run_test_profiler_with_script_sync_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_script_sync_rpc(self):
+        self._run_test_profiler_with_script_sync_rpc()
+
+    @dist_init
+    def test_profiler_with_script_sync_rpc_single_threaded(self):
+        self._run_test_profiler_with_script_sync_rpc()
+
+    def _run_test_profiler_with_script_remote_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank
+        )
+
+    @dist_init
+    def test_profiler_with_script_remote_rpc(self):
+        self._run_test_profiler_with_script_remote_rpc()
+
+    @dist_init
+    def test_profiler_with_script_remote_rpc_single_threaded(self):
+        self._run_test_profiler_with_script_remote_rpc()
+
+    def _assert_top_level_events(
+        self, process_global_events, expected_top_level_event_names
+    ):
+        top_level_event_names = []
+        for thread_local_events in process_global_events:
+            # Get top-level events from all events happened on a thread.
+            last_end_time = 0
+            for event in thread_local_events:
+                event_name = event.name
+                time_range = event.time_range
+                if time_range.start > last_end_time:
+                    top_level_event_names.append(event_name)
+                    last_end_time = time_range.end
+        top_level_event_names = sorted(top_level_event_names)
+        expected_top_level_event_names = sorted(expected_top_level_event_names)
+        self.assertEqual(
+            top_level_event_names,
+            expected_top_level_event_names,
+            f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}",
+        )
+
+    @dist_init
+    def test_server_process_global_profiler(self):
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = worker_name(dst_rank)
+
+        x = torch.tensor(1)
+        y = torch.tensor(2)
+
+        outer_profile_rref = rpc.remote(
+            dst_worker_name, rpc._server_process_global_profile
+        )
+        outer_profile_rref.rpc_sync().__enter__()
+        rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
+        inner_profile_rref = rpc.remote(
+            dst_worker_name, rpc._server_process_global_profile
+        )
+        inner_profile_rref.rpc_sync().__enter__()
+        rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
+        inner_profile_rref.rpc_sync().__exit__(None, None, None)
+        outer_profile_rref.rpc_sync().__exit__(None, None, None)
+
+        inner_events = rpc.rpc_sync(
+            dst_worker_name, get_events_from_profile, (inner_profile_rref,)
+        )
+        expected_inner_events = ["aten::sub"]
+        expected_outer_events = expected_inner_events + ["aten::add"]
+
+        self._assert_top_level_events(inner_events, expected_inner_events)
+        outer_events = rpc.rpc_sync(
+            dst_worker_name, get_events_from_profile, (outer_profile_rref,)
+        )
+        self._assert_top_level_events(outer_events, expected_outer_events)
+
+        inner_profile_rref.rpc_sync().key_averages()
+        outer_profile_rref.rpc_sync().key_averages()
+
+    @dist_init
+    def test_async_record_function_double_end_callbacks(self):
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            # Validate that calling the function twice results in an error.
+            with _profile():
+                with torch.autograd.profiler.record_function("foo") as rf:
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    rf._call_end_callbacks_on_future(fut)
+                    with self.assertRaisesRegex(
+                        RuntimeError, "can only be called once."
+                    ):
+                        rf._call_end_callbacks_on_future(fut)
+                fut.wait()
+
+    @dist_init
+    def test_async_record_function_legacy(self):
+        # Test the legacy _record_function ops work
+        # Note: These exist for backward compatibility with TorchScript
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            with _profile():
+                try:
+                    handle = torch.ops.profiler._record_function_enter("foo", None)
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
+                finally:
+                    torch.ops.profiler._record_function_exit(handle)
+
+                fut.wait()
+
+    @dist_init
+    def test_async_record_function_cbs_jit_call(self):
+        if self.rank == 1:
+            with _profile() as pf:
+                key = _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC,
+                    torch._jit_internal._qualified_name(my_script_func),
+                    "worker1",
+                    "worker0",
+                )
+                with torch.autograd.profiler.record_function(key) as rf:
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_script_func, args=(torch.tensor(1),)
+                    )
+                    # Intentionally calling record_function internals
+                    fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(
+                        rf.record, fut
+                    )
+                result = fut.wait()
+                # Validate that the profiling future returns the same value as the RPC
+                # future.
+                expected = torch.add(torch.tensor(1), torch.tensor(1))
+                self.assertEqual(result, expected)
+            events = pf.function_events
+            rpc_event = get_function_event(
+                events, torch._jit_internal._qualified_name(my_script_func)
+            )
+            self.assertTrue(
+                torch._jit_internal._qualified_name(my_script_func) in rpc_event.name
+            )
+
+    @dist_init
+    def test_py_class_constructor(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,))
+        self.assertEqual(ret.a, n)
+
+    @dist_init
+    def test_py_class_instance_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,)
+        )
+        self.assertEqual(ret, MyClass(2).my_instance_method(n))
+
+    @dist_init
+    def test_py_class_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1)
+        )
+        self.assertEqual(ret, MyClass.my_class_method(n, n + 1))
+
+    @dist_init
+    def test_py_class_static_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,)
+        )
+        self.assertEqual(ret, MyClass.my_static_method(n + 10))
+
+    @dist_init
+    def test_py_multi_async_call(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        dst_worker_info = rpc.get_worker_info(worker_name(dst_rank))
+        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
+        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
+        self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
+        self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
+
+    @dist_init
+    def test_py_no_return_result(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), no_result)
+        self.assertEqual(ret, no_result())
+
+    @dist_init
+    def test_py_tensors(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            my_tensor_function,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
+
+    @dist_init
+    def test_py_tensors_multi_async_call(self):
+        futs = []
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        for i in range(100):
+            fut = rpc.rpc_async(
+                worker_name(dst_rank),
+                my_tensor_function,
+                args=(torch.ones(i, i), torch.ones(i, i)),
+            )
+            futs.append(fut)
+
+        for j, val in enumerate(torch.futures.wait_all(futs)):
+            self.assertEqual(
+                val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
+            )
+
+    @dist_init
+    def test_py_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [torch.ones(n, n), torch.ones(n, n)]
+        b = TensorClass(build_complex_tensors())
+        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c)
+        )
+        self.assertEqual(ret, my_complex_tensor_function(a, b, c))
+
+    @dist_init
+    def test_py_nested_pickle(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            run_nested_pickle,
+            args=(MyPickleClass(), torch.ones(2, 2)),
+        )
+
+        m = MyPickleClass()
+        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
+        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
+
+    @dist_init
+    def test_py_function_exception(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        with self.assertRaises(TypeError):
+            rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,))
+
+    @dist_init
+    def test_py_raise_in_user_func(self):
+        with captured_output() as (_, err):
+            # This barrier prevents a race condition where the main thread has
+            # not entered the context manager when the remote function runs.
+            initialize_pg(self.file_init_method, self.rank, self.world_size)
+            dist.barrier()
+            n = self.rank + 1
+            dst_rank = n % self.world_size
+            fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
+            with self.assertRaisesRegex(ValueError, expected_err):
+                fut.wait()
+            # This barrier prevents a race condition where the main thread exits
+            # context manager before the remote function has ran.
+            dist.barrier()
+
+        # Validate that trainers log errors when running functions.
+        stderr_lines = err.getvalue()
+        self.assertTrue(expected_err in stderr_lines)
+
+    @dist_init
+    def test_py_raise_in_user_func_escaped_str(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
+        try:
+            fut.wait()
+        except ValueError as e:
+            msg = str(e)
+            # Ensure newlines are unescaped to provide a better repr of error.
+            self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape"))
+        else:
+            self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
+
+    @dist_init
+    def test_nested_rpc(self):
+        self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_stress_light_rpc(self):
+        self._stress_test_rpc(light_rpc)
+
+    @dist_init
+    def test_stress_heavy_rpc(self):
+        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
+
+    @dist_init
+    def test_stress_heavy_rpc_torchscript(self):
+        self._stress_test_rpc(
+            heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),)
+        )
+
+    @dist_init
+    def test_builtin_remote_ret(self):
+        self._builtin_remote_ret(
+            torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self(self):
+        self._builtin_remote_self(
+            torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2
+        )
+
+    @staticmethod
+    def _multi_args_fn(n, sparse=False):
+        if sparse:
+            return (build_sparse_tensor(), build_sparse_tensor())
+        else:
+            return (torch.ones(n, n), torch.ones(n, n))
+
+    @dist_init
+    def test_multi_builtin_remote_ret(self):
+        self._test_multi_remote_call(torch.add, False, args_fn=RpcTest._multi_args_fn)
+
+    @dist_init
+    def test_py_udf_remote(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            my_function,
+            kwargs={"a": n, "b": n + 1, "c": n + 2},
+        )
+        self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
+
+    @staticmethod
+    def _multi_kwargs_fn(n, sparse=False):
+        if sparse:
+            return {
+                "a": build_sparse_tensor(),
+                "b": build_sparse_tensor(),
+                "c": build_sparse_tensor(),
+            }
+        else:
+            return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}
+
+    @dist_init
+    def test_multi_py_udf_remote(self):
+        self._test_multi_remote_call(
+            my_function, False, kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    @dist_init
+    def test_py_rref_args(self):
+        self._py_rref_args(
+            torch.ones(2, 2), 1, torch.ones(2, 2), 2, torch.ones(2, 2) * 2 + 3
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share(self):
+        self._py_rref_args_user_share(
+            torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_py_rpc_rref_args(self):
+        self._py_rpc_rref_args(
+            torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_nested_remote(self):
+        self._nested_remote(nested_remote, torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_nested_rref(self):
+        self._nested_rref(nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+
+    @dist_init
+    def test_nested_rref_stress(self):
+        self._nested_rref_stress(
+            nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_multi_layer_nested_async_rpc(self):
+        # This test will exit right away, but there will be a chain of async
+        # RPCs. The termination algorithm should detect those messages properly.
+        # Otherwise, some peer could exit early, leaving others to timeout
+        # errors or connection closed errors.
+        ttl = 20
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
+
+    @dist_init
+    def test_remote_with_exception(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        # check ref to other workers
+        rref = rpc.remote(worker_name(dst_rank), raise_func)
+        with self.assertRaises(ValueError):
+            rref.to_here()
+        # check ref to itself
+        rref = rpc.remote(worker_name(self.rank), no_result, args=(10,))
+        with self.assertRaises(TypeError):
+            rref.to_here()
+
+    @dist_init
+    def test_rpc_return_rref(self):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        rref = rpc.rpc_sync(
+            worker_name(dst_rank1),
+            rpc_return_rref,
+            args=(worker_name(dst_rank2),),
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_rref_forward_chain(self):
+        ttl = 8
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        rref = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1))
+
+        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)
+
+        for _ in range(ttl):
+            self.assertEqual(len(ret_rref), 1)
+            ret_rref = ret_rref[0].to_here()
+
+        ret = ret_rref
+        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
+
+    @dist_init
+    def test_local_rref_no_fork(self):
+        local_rref = RRef(35)
+        self.assertEqual(local_rref.local_value(), 35)
+
+    @dist_init
+    def test_local_value_not_on_owner(self):
+        # ensure that an error message is thrown if a user tries to call
+        # local_value() on a non-owning node.
+        next_rank = (self.rank + 1) % self.world_size
+        rref = rpc.remote(
+            worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        with self.assertRaisesRegex(
+            RuntimeError,
+            (
+                rf"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), "
+                rf"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), "
+                r"can't call localValue\(\) on user "
+                rf"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). "
+                rf"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)"
+            ),
+        ):
+            rref.local_value()
+
+    @dist_init
+    def test_return_local_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        rref_list = rpc.rpc_sync(
+            worker_name(dst_rank), get_rref_list, args=([1, 2, 3],)
+        )
+
+        for rref in rref_list:
+            rpc.rpc_sync(
+                rref.owner(),
+                _call_method_on_rref,
+                args=(MyClass.increment_value, rref, 10),
+            )
+
+        rets = [
+            rpc.rpc_sync(
+                rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
+            )
+            for rref in rref_list
+        ]
+
+        self.assertEqual(rets, [11, 12, 13])
+
+    @dist_init
+    def _test_rref_type(self, blocking):
+        def launched_rpc(events):
+            expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
+            return any(e.name.startswith(expected_name) for e in events)
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))
+
+        with _profile() as p:
+            t = rref._get_type(blocking=blocking)
+            if not blocking:
+                t = t.wait()
+
+        self.assertTrue(launched_rpc(p.function_events))
+        expected_type = type(torch.ones(2))
+        self.assertEqual(t, expected_type)
+
+        futs = []
+
+        def verify(fut):
+            self.assertEqual(fut.value(), expected_type)
+
+        with _profile() as p:
+            for _ in range(10):
+                t = rref._get_type(blocking=blocking)
+                if not blocking:
+                    futs.append(t)
+                    t.add_done_callback(verify)
+                    t = t.wait()
+                self.assertEqual(t, expected_type)
+
+        if not blocking:
+            # Note that cached calls with blocking=False all return the same
+            # cached original future.
+            first_fut = futs[0]
+            for f in futs[1:]:
+                self.assertTrue(f is first_fut)
+        # Ensure we never launch another RPC, other than for the very
+        # first call.
+        self.assertFalse(launched_rpc(p.function_events))
+        self.assertEqual(t, type(torch.ones(2)))
+
+        rref = rpc.remote(dst, MyClass, args=(0,))
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, MyClass)
+
+    def test_rref_type_blocking(self):
+        self._test_rref_type(blocking=True)
+
+    def test_rref_type_non_blocking(self):
+        self._test_rref_type(blocking=False)
+
+    @dist_init
+    def _test_rref_type_with_error(self, blocking):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        # 10 ms timeout
+        rref = rpc.remote(dst, raise_func)
+        # Blocking: error raised inline
+        if blocking:
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                rref._get_type(blocking=blocking)
+        else:
+            # Non-blocking: Immediately return future, block on wait
+            fut = rref._get_type(blocking=blocking)
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut.wait()
+
+    def test_rref_type_with_error_blocking(self):
+        self._test_rref_type_with_error(blocking=True)
+
+    def test_rref_type_with_error_non_blocking(self):
+        self._test_rref_type_with_error(blocking=False)
+
+    @dist_init
+    def _test_rref_type_owner(self, blocking):
+        rref = RRef(torch.ones(2) + 1)
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, type(torch.ones(2)))
+
+        rref = RRef(MyClass(0))
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, MyClass)
+
+    def test_rref_type_owner_blocking(self):
+        self._test_rref_type_owner(blocking=True)
+
+    def test_rref_type_owner_non_blocking(self):
+        self._test_rref_type_owner(blocking=False)
+
+    @staticmethod
+    def _slow_add(x, y):
+        time.sleep(1)
+        return x + y
+
+    @dist_init
+    def test_rref_type_slow_init(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1))
+        self.assertEqual(rref._get_type(), type(torch.ones(2)))
+
+    @dist_init
+    def test_owner_equality(self):
+        a = RRef(40)
+        b = RRef(50)
+
+        other_rank = (self.rank + 1) % self.world_size
+        other_a = rpc.remote(
+            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
+        )
+        other_b = rpc.remote(
+            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
+        )
+        other_a.to_here()  # to ensure clean termination
+        other_b.to_here()
+
+        self.assertNotEqual(a.owner(), 23)
+        self.assertEqual(other_a.owner(), other_b.owner())
+        self.assertNotEqual(a.owner(), other_a.owner())
+        self.assertEqual(other_a.owner(), other_a.owner())
+        self.assertEqual(other_a.owner(), other_b.owner())
+        self.assertEqual(a.owner(), a.owner())
+        self.assertEqual(a.owner(), b.owner())
+        self.assertEqual(a.owner(), rpc.get_worker_info())
+        x = {}
+        x[a.owner()] = a
+        x[other_a.owner()] = other_a
+        self.assertEqual(x[a.owner()], a)
+        self.assertEqual(x[b.owner()], a)
+        self.assertEqual(x[other_a.owner()], other_a)
+        self.assertEqual(x[other_b.owner()], other_a)
+        self.assertEqual(len(x), 2)
+
+    @dist_init
+    def test_pass_local_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        dst_worker = worker_name(dst_rank)
+
+        rref = RRef(40)
+        self.assertEqual(
+            rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90
+        )
+        self.assertEqual(
+            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90
+        )
+        self.assertEqual(
+            rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90
+        )
+
+    @dist_init
+    def test_remote_same_worker(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(
+            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
+        )
+        rref_b = rpc.remote(
+            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
+        )
+        rref_c = rpc.remote(
+            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+
+    @dist_init(setup_rpc=True)
+    def test_call_method_on_rref(self):
+        """
+        Tests that it is possible to call an instance method on a remote object
+        by using rref.owner() as destination of the call.
+        """
+        vals = [10, 2, 5, 7]
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst_rank)
+
+        # creates a remote object
+        rref = rpc.remote(dst_worker, MyClass, args=(vals[0],))
+
+        # modifies state of the remote object
+        rpc.rpc_sync(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[1]),
+        )
+        rpc.rpc_async(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[2]),
+        ).wait()
+        rpc.remote(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[3]),
+        ).to_here()
+
+        # queries state of the remote object
+        result = rpc.rpc_sync(
+            dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
+        )
+
+        self.assertEqual(result, sum(vals))
+
+    # Notice `rpc.api.shutdown()` accesses
+    # `_delete_all_user_and_unforked_owner_rrefs` through
+    # `torch.distributed.rpc.api`, so patching
+    # `torch.distributed.rpc._delete_all_user_and_unforked_owner_rrefs` will
+    # not help.
+    @mock.patch.object(
+        torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs"
+    )
+    def _test_rref_leak(
+        self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak
+    ):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        rref = rpc.remote(  # noqa: F841
+            worker_name((self.rank + 1) % self.world_size),
+            torch.add,
+            args=(torch.ones(2, 2), 1),
+        )
+
+        import torch.distributed.rpc.api as api
+
+        if ignore_leak:
+            api._ignore_rref_leak = True
+            rpc.shutdown(graceful=True)
+        else:
+            api._ignore_rref_leak = False
+            with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
+                rpc.shutdown(graceful=True)
+
+    @dist_init(setup_rpc=False)
+    def test_rref_leak(self):
+        self._test_rref_leak(ignore_leak=False)
+
+    @dist_init(setup_rpc=False)
+    def test_ignore_rref_leak(self):
+        self._test_rref_leak(ignore_leak=True)
+
+    @dist_init
+    def test_rref_str(self):
+        rref1 = RRef(self.rank)
+        id_class = "GloballyUniqueId"
+        self.assertEqual(
+            f"OwnerRRef({id_class}(created_on={self.rank}, local_id=0))",
+            rref1.__str__(),
+        )
+
+        dst_rank = (self.rank + 1) % self.world_size
+        rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        self.assertEqual(
+            rref2.__str__(),
+            f"UserRRef(RRefId = {id_class}(created_on={self.rank}, local_id=1), "
+            f"ForkId = {id_class}(created_on={self.rank}, local_id=2))",
+        )
+
+    @dist_init
+    def test_rref_get_future(self):
+        # Tests that we can obtain the future corresponding to the creation of
+        # the RRef on remote end
+        if self.rank == 0:
+            # Builtin
+            rref = rpc.remote(worker_name(1), torch.add, args=(1, 1))
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+            # UDF
+            rref = rpc.remote(worker_name(1), foo_add, args=())
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+            # Script
+            rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1),))
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+    @dist_init
+    def test_rref_context_debug_info(self):
+        # This test checks local states that are modified by remote workers.
+        # This means that we would need barrier before and after every check.
+        # The barrier before the check makes sure that all previous states are
+        # cleared globally, the barrier after ensures that no following states
+        # change gets into the current check.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # Check 1: local RRef does not update owners_ map or add a pending user.
+        #################################################
+
+        rref1 = RRef(self.rank)
+
+        # don't need a barrier here as local RRef is handled by this thread
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertIn("num_pending_users", info)
+        # RRef on local value is not added to context until shared across RPC
+        self.assertEqual(0, int(info["num_owner_rrefs"]))
+        self.assertEqual(0, int(info["num_pending_users"]))
+        # barrier after the check 1
+        dist.barrier()
+
+        # Check 2: Sharing RRef as an arg should update owners_ map
+        ###########################################################
+
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,))
+
+        # barrier before check 2
+        wait_until_pending_futures_and_users_flushed()
+        dist.barrier()
+
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertEqual(1, int(info["num_owner_rrefs"]))
+        # no pending users since the fork is finished
+        self.assertEqual(0, int(info["num_pending_users"]))
+        # barrier after check 2
+        dist.barrier()
+
+        # clear states for check 2
+        rpc.rpc_sync(worker_name(dst_rank), clear_global_rref)
+
+        # Wait for owner rref to be cleared.
+        while int(info["num_owner_rrefs"]) != 0:
+            info = _rref_context_get_debug_info()
+            time.sleep(0.1)
+        dist.barrier()
+
+        # Check 3: rpc.remote call should update owners_ map
+        ####################################################
+        rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        rref3 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        rref2.to_here()
+        rref3.to_here()
+
+        # barrier before check 3
+        wait_until_pending_futures_and_users_flushed()
+        dist.barrier()
+
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertEqual(2, int(info["num_owner_rrefs"]))
+        # no pending users since the fork is finished
+        self.assertEqual(0, int(info["num_pending_users"]))
+
+        # barrier after check 3
+        dist.barrier()
+
+    @dist_init
+    def test_disable_gil_profiling(self):
+        # test that rpc.enable_gil_profiling(false) will result in
+        # GIL wait time not being recorded.
+
+        # GIL profiling should be disabled by default.
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc.rpc_sync(
+            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"])
+        rpc.enable_gil_profiling(True)
+        rpc.rpc_sync(
+            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertIn("agent.gil_average_wait_time_us", info)
+
+    @dist_init(setup_rpc=False)
+    def test_local_shutdown(self):
+        # test that we can start RPC and then immediately locally shutdown
+        # without sending any messages.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init
+    def test_debug_info(self):
+        # only test keys in this test case. Values should be covered by
+        # individual module debug info tests
+        import torch.distributed.autograd as dist_autograd
+
+        info = _get_debug_info()
+        rref_info = _rref_context_get_debug_info()
+        agent_info = rpc.api._get_current_rpc_agent().get_debug_info()
+        autograd_info = dist_autograd._get_debug_info()
+        common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys()
+        self.assertEqual(0, len(common_keys))
+        expected = {}
+        expected.update(rref_info)
+        expected.update(agent_info)
+        expected.update(autograd_info)
+        # NB: Key ordering is only preserved in python 3.6+. So here, we
+        # manually check keys are equal.
+        for key in expected:
+            self.assertIn(key, info.keys())
+
+        for key in info:
+            self.assertIn(key, expected.keys())
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_handle_send_exceptions(self):
+        # test that if a callee node has gone down, we raise an appropriate
+        # exception instead of just crashing.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc._set_rpc_timeout(10)
+        # This barrier is needed to ensure that some workers do not exit before
+        # others have been brought up.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+        if self.rank == 1:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker = worker_name(dst_rank)
+            # allow destination worker to exit without joining
+            error_str = self.get_shutdown_error_regex()
+            wait_until_node_failure(dst_rank, error_str)
+            fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3))
+            # Shutdown sequence is not very well defined and as a result
+            # we can see any of the error messages defined in get_shutdown_error_regex.
+            with self.assertRaisesRegex(RuntimeError, error_str):
+                fut.wait()
+        # exit all workers non-gracefully.
+        rpc.shutdown(graceful=False)
+
+    @dist_init
+    def test_deadlock(self):
+        # this test is copied from https://github.com/pytorch/pytorch/issues/45089
+        if self.rank == 1:
+            dst1 = worker_name((self.rank + 1) % self.world_size)
+            x = torch.ones(2)
+            y = torch.ones(2)
+            rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()
+
+        dist_initialized = dist.is_initialized()
+        if not dist_initialized:
+            dist.init_process_group(
+                backend="gloo",
+                init_method=self.file_init_method,
+                rank=self.rank,
+                world_size=self.world_size,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_local_shutdown_with_rpc(self):
+        # test that we can start RPC, send RPCs, and then run local shutdown.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # A barrier is needed to ensure that all RPCs are processed.
+        # Otherwise, some RPCs can timeout since the receiving end
+        # has terminated.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    def test_set_and_get_default_rpc_timeout(self):
+        timeout = 0.5
+
+        # A new `RpcBackendOptions` is constructed
+        # when accessing `self.rpc_backend_options`.
+        rpc_backend_options = self.rpc_backend_options
+        rpc_backend_options.rpc_timeout = timeout
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+        set_timeout = rpc.get_rpc_timeout()
+        self.assertEqual(timeout, set_timeout)
+        rpc.shutdown()
+
+    @dist_init
+    def test_default_timeout_used(self):
+        """
+        Tests that if no timeout is passed into rpc_async and rpc_sync, then the
+        default timeout is used.
+        """
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc._set_rpc_timeout(0.001)  # 1 ms
+        # futures should time out and be marked with an exception indicating it as such.
+        futs = [
+            rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=())
+            for _ in range(10)
+        ]
+        expected_error = self.get_timeout_error_regex()
+        for fut in futs:
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                fut.wait()
+
+        # ensure that if a new timeout is set old futures don't time out but new ones do.
+        rpc._set_rpc_timeout(200)  # 200 seconds
+        # create a longstanding RPC.
+        fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
+        # now, set a short timeout.
+        rpc._set_rpc_timeout(0.001)
+        # fut2 should time out, fut1 should not.
+        fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut2.wait()
+        fut1.wait()
+
+        # Zero timeout means infinity, so future should run to completion.
+        rpc._set_rpc_timeout(0)
+        rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait()
+
+        # reset to default timeout so shutdown messages can process cleanly.
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init
+    def test_rpc_timeouts(self):
+        # TODO: enable timeouts for rpc.remote/RRef (https://github.com/pytorch/pytorch/issues/33803)
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst_rank)
+        timeout = 0.1  # 100 ms
+        expected_error = self.get_timeout_error_regex()
+        # Test async UDF
+        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if there is no timeout and we use the default
+        # RPC timeout.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait()
+
+        # Test sync UDF
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
+
+        # Ensure run to completion if there is no timeout and we use the default
+        # RPC timeout.
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
+
+        # If we set a default timeout for RPCs, it should be respected, though
+        # still overridden if we pass in a different timeout to the APIs.
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
+
+        # The RPCs should run to completion since we override the timeout.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait()
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5)
+        # Passing in a zero timeout should ensure that the RPC won't time out.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait()
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0)
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    def test_dist_init_decorator(self):
+        @dist_init(setup_rpc=False)
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+        @dist_init
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+    def test_use_rpc_pickler(self):
+        class TestPickler:
+            pass
+
+        test_pickler = TestPickler()
+        with _use_rpc_pickler(test_pickler):
+            self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler)
+        self.assertTrue(
+            torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler
+        )
+
+    @dist_init
+    def test_wait_all(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+            self.assertTrue(len(_thread_local_var.future_list) == 1)
+            self.assertTrue(
+                isinstance(_thread_local_var.future_list[0], torch._C.Future)
+            )
+        self.assertTrue(fut.done())
+        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_multiple_call(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            for i in range(20):
+                fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1))
+                res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1))
+                self.assertEqual(res, torch.ones(i, i) + 1)
+                self.assertEqual(fut.wait(), torch.ones(i, i) + 1)
+            self.assertTrue(len(_thread_local_var.future_list) == 20)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_timeout(self):
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error), _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            timeout = 0.1  # 100 ms
+            rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_user_func(self):
+        with self.assertRaises(ValueError), _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            rpc.rpc_async(dst, raise_func)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_body(self):
+        with self.assertRaises(ValueError), _wait_all():
+            raise_func()
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_custom_exception_throw_during_reconstruction(self):
+        """
+        Test that we still throw info about the remote side exception even when
+        we cannot recreate it on client side.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        if self.rank != 0:
+            exc_caught = False
+            dst = worker_name(0)
+            try:
+                rpc.rpc_sync(dst, custom_raise_func, args=())
+            except RuntimeError as e:
+                exc_caught = True
+                msg = str(e)
+                print(f"Got msg {msg}")
+                self.assertTrue("Original exception on remote side was" in msg)
+                self.assertTrue("CustomException" in msg)
+            except BaseException as e:  # noqa: B036
+                raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
+            finally:
+                self.assertTrue(exc_caught)
+
+        dist.barrier()
+
+    timed_out_rpc_event = None
+
+    @staticmethod
+    def timed_out_rpc():
+        RpcTest.timed_out_rpc_event.wait()
+
+    @dist_init
+    def test_wait_all_exit_early_python(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, raise_func)
+        fut3 = rpc.rpc_async(dst, raise_func)
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(ValueError, expected_err):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_wait_all_exit_early_builtin(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
+        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(RuntimeError, "size of tensor"):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_wait_all_exit_early_script_function(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
+        fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(RuntimeError, expected_err):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_function_not_on_callee(self):
+        # test that if a function does not exist on a callee, we don't crash,
+        # instead we get an AttributeError indicating that the func does not exist.
+        this_module = sys.modules[__name__]
+        caller_worker = "worker0"
+        callee_worker = "worker1"
+
+        if self.rank == 1:
+            # Use delattr to remove the binding of a func on this nodes
+            delattr(this_module, "foo_add")
+            # notify remote end that we have removed it.
+            rpc.rpc_sync(caller_worker, set_value, args=(self.rank,))
+
+        if self.rank == 0:
+            # func exists on caller, but not callee.
+            # wait for remote end to remove the binding of foo_add func.
+            wait_for_value_future()
+            # Ensure that we have the attribute on this module. Otherwise, the test could fail due to a caller-side pickling error.
+            self.assertTrue(hasattr(this_module, "foo_add"))
+            with self.assertRaisesRegex(RuntimeError, "RPC pickler does not serialize"):
+                rpc.rpc_sync(callee_worker, foo_add, args=())
+
+    @dist_init
+    def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        a = MyClass(1)
+        b = MyClass(2)
+
+        # This is to make Python not garbage collect a and b.
+        a.other = b
+        b.other = a
+
+        n = self.rank
+        a.rref = rpc.remote(dst_worker_name, torch.add, args=(torch.ones(n, n), 2))
+
+    @dist_init(setup_rpc=False)
+    def test_use_rref_after_shutdown(self):
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # pass in graceful=True to ensure that local UserRRefs are deleted.
+        rpc.shutdown(graceful=True)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot call to_here\\(\\) on it after deletion."
+        ):
+            rref.to_here()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot call fork an UserRRef after deletion."
+        ):
+            import torch.distributed.rpc.internal as internal
+
+            internal.serialize(rref)
+
+    @staticmethod
+    def _return_gpu_tensor():
+        return torch.rand(3, 3).cuda(0)
+
+    @staticmethod
+    def _return_gpu_tensor_list():
+        return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)]
+
+    @staticmethod
+    def _gpu_tensor_list_arg(tensor_list):
+        return torch.rand(3, 3)
+
+    def _create_rref(self):
+        owner_rank = (self.rank + 2) % self.world_size
+        return rpc.remote(
+            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
+        )
+
+    @dist_init
+    def test_user_rrefs_confirmed(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret = rpc.rpc_sync(worker_name(dst_rank), check_rref_confirmed, args=(rref,))
+        self.assertEqual(ret, True)
+
+    @dist_init
+    def test_user_rrefs_confirmed_remote(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret_rref = rpc.remote(worker_name(dst_rank), check_rref_confirmed, args=(rref,))
+        self.assertEqual(ret_rref.to_here(), True)
+
+    @dist_init
+    def test_rref_py_pickle_not_supported(self):
+        local_rref = RRef(35)
+        with (
+            TemporaryFileName() as fname,
+            self.assertRaisesRegex(
+                RuntimeError, "Can not pickle rref in python pickler"
+            ),
+        ):
+            torch.save(local_rref, fname)
+
+    @dist_init
+    def test_remote_throw(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            raise_or_inc,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            rref.to_here()
+
+    @dist_init
+    def test_non_cont_tensors(self):
+        if self.rank == 0:
+            # Create a non-contiguous tensor.
+            t = torch.rand(5, 5)
+            t_view = t.narrow(1, 2, 2)
+            self.assertFalse(t_view.is_contiguous())
+            t_cont = t_view.contiguous()
+            self.assertTrue(t_cont.is_contiguous())
+            self.assertEqual(t_view, t_cont)
+
+            # Send non-cont tensor over RPC.
+            next_rank = (self.rank + 1) % self.world_size
+            t_ret = rpc.rpc_sync(
+                worker_name(next_rank), non_cont_test, args=(t_view, t_cont)
+            )
+
+            # Verify the returned tensor.
+            self.assertEqual(t_view, t_ret)
+            self.assertFalse(t_ret.is_contiguous())
+
+    @dist_init
+    def test_callback_simple(self):
+        set_by_cb = concurrent.futures.Future()
+        n = self.rank + 1
+
+        def callback(fut):
+            ret = fut.wait()
+            self.assertEqual(ret, torch.ones(n, n) * 2)
+            set_by_cb.set_result(ret.clone() + 1)
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+        self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_callback_wrong_arg_num(self):
+        n = self.rank + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        cb_fut = fut.then(my_function)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "my\\_function\\(\\) missing 2 required positional arguments"
+        ):
+            cb_fut.wait()
+
+    @dist_init
+    def test_callback_wrong_arg_type(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
+        fut1 = fut0.then(lambda x: x + 1)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "unsupported operand type\\(s\\) for \\+"
+        ):
+            fut1.wait()
+
+    @dist_init
+    def test_callback_multi(self):
+        num_cbs = 10
+        n = self.rank + 1
+
+        def callback(idx, fut):
+            ret = fut.wait()
+            self.assertEqual(ret, torch.ones(n, n) * 2)
+            return ret + idx
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        cb_futs = [fut.then(partial(callback, idx)) for idx in range(num_cbs)]
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        for idx in range(num_cbs):
+            self.assertEqual(cb_futs[idx].wait(), torch.ones(n, n) * 2 + idx)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_callback_chain(self):
+        n = self.rank + 1
+
+        def callback(fut):
+            return fut.wait() + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size), torch.add, args=(torch.ones(n, n), 1)
+        )
+
+        num_cbs = 20
+        for _ in range(num_cbs):
+            fut = fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
+
+    @dist_init
+    def test_callback_in_rpc(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, add_use_future_cb, args=(dst2, torch.ones(2, 2), 1, 2))
+        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
+
+    @dist_init
+    def test_callback_with_ret(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        def callback(fut0):
+            fut2 = rpc.rpc_async(dst, torch.add, args=(fut0.wait(), 1)).then(
+                lambda fut1: fut1.wait() + 1
+            )
+
+            return fut2.wait()
+
+        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)).then(callback)
+
+        self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_callback_with_error(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        def callback(fut0):
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut0.wait()
+            raise RuntimeError("Another expected error")
+
+        fut1 = rpc.rpc_async(dst, raise_func).then(callback)
+        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
+            fut1.wait()
+
+    @dist_init
+    def test_callback_none(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(TypeError, "incompatible function arguments."):
+            rpc.rpc_async(dst, raise_func).then(None)
+
+    @dist_init
+    def test_add_done_callback(self):
+        set_by_cb = False
+        n = self.rank + 1
+
+        def callback(fut):
+            nonlocal set_by_cb
+            fut.wait()
+            set_by_cb = True
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        fut.add_done_callback(callback)
+        fut_then = fut.then(lambda _: True)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        # We have no guarantee that the add_done_callback fn will execute before the test finishes.
+        # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback
+        fut_then.wait()
+        self.assertTrue(set_by_cb)
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_mark_future_twice(self):
+        fut = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            torch.add,
+            args=(torch.zeros(2, 2), 1),
+        )
+        self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1)
+        with self.assertRaisesRegex(
+            RuntimeError, "Future can only be marked completed once"
+        ):
+            fut.set_result(1)
+
+    @dist_init
+    def test_pickle_future(self):
+        fut = torch.futures.Future()
+        errMsg = "Can not pickle torch.futures.Future"
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg):
+            rpc.rpc_sync(dst, fail_on_fut, args=(fut,))
+
+        with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg):
+            rpc.rpc_async(dst, fail_on_fut, args=(fut,))
+
+        with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg):
+            rpc.remote(dst, fail_on_fut, args=(fut,))
+
+    @dist_init
+    def test_future_done(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
+        fut.wait()
+        self.assertTrue(fut.done())
+
+    @dist_init
+    def test_future_done_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut = rpc.rpc_async(dst, raise_func)
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            fut.wait()
+        self.assertTrue(fut.done())
+
+    def _test_future_cb(self, func):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, func, args=(dst2, torch.ones(2, 2), 1, 2))
+        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
+
+    @dist_init
+    def test_future_in_rpc(self):
+        self._test_future_cb(add_use_future_set_result)
+
+    @dist_init
+    def test_future_nested_callback(self):
+        self._test_future_cb(add_use_future_nested_cb)
+
+    def _test_async_function_raise(self, mode):
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), async_raise_func, mode
+            )
+
+    @dist_init
+    def test_async_function_raise(self):
+        self._test_async_function_raise(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_async_function_raise_async(self):
+        self._test_async_function_raise(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_raise_remote(self):
+        self._test_async_function_raise(RPCExecMode.REMOTE)
+
+    def _test_async_function_wrong_return_type(self, mode):
+        errMsg = (
+            "Functions decorated with @rpc\\.async_function must return a "
+            "torch\\.futures\\.Future object,"
+        )
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), async_wrong_type, mode
+            )
+
+    @dist_init
+    def test_async_function_wrong_return_type(self):
+        self._test_async_function_wrong_return_type(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_async_function_wrong_return_type_async(self):
+        self._test_async_function_wrong_return_type(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_wrong_return_type_remote(self):
+        self._test_async_function_wrong_return_type(RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_simple(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1))
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    def _test_async_function(self, fn, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        args = (dst2, torch.ones(2, 2), 1, 2)
+        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
+        self.assertEqual(ret, torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_async_function_with_future_ctor(self):
+        self._test_async_function(async_add_with_future_ctor)
+
+    @dist_init
+    def test_async_function_with_future_ctor_remote(self):
+        self._test_async_function(async_add_with_future_ctor, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_chained(self):
+        self._test_async_function(async_add_chained)
+
+    @dist_init
+    def test_async_function_chained_remote(self):
+        self._test_async_function(async_add_chained, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_nested(self):
+        self._test_async_function(async_add_nested)
+
+    @dist_init
+    def test_async_function_nested_remote(self):
+        self._test_async_function(async_add_nested, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_static_method(self):
+        self._test_async_function(AsyncExecutionClass.static_async_add)
+
+    @dist_init
+    def test_async_static_method_remote(self):
+        self._test_async_function(
+            AsyncExecutionClass.static_async_add, RPCExecMode.REMOTE
+        )
+
+    @dist_init
+    def test_async_class_method(self):
+        self._test_async_function(AsyncExecutionClass.class_async_add)
+
+    @dist_init
+    def test_async_class_method_remote(self):
+        self._test_async_function(
+            AsyncExecutionClass.class_async_add, RPCExecMode.REMOTE
+        )
+
+    def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+        rref = rpc.remote(dst1, AsyncExecutionClass)
+
+        x = torch.ones(2, 2)
+        y = torch.ones(2, 2) + 1
+        if mode == RPCExecMode.SYNC:
+            ret = rref.rpc_sync().static_async_add(dst2, x, x, y)
+            ret += rref.rpc_sync().class_async_add(dst2, x, x, y)
+            ret += rref.rpc_sync().bound_async_add(dst2, x, x, y)
+        elif mode == RPCExecMode.ASYNC:
+            ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait()
+            ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait()
+            ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait()
+        elif mode == RPCExecMode.REMOTE:
+            ret = rref.remote().static_async_add(dst2, x, x, y).to_here()
+            ret += rref.remote().class_async_add(dst2, x, x, y).to_here()
+            ret += rref.remote().bound_async_add(dst2, x, x, y).to_here()
+
+        self.assertEqual(ret, 3 * 4 * x)
+
+    @dist_init
+    def test_async_class_rref_proxy(self):
+        self._test_test_async_class_rref_proxy()
+
+    @dist_init
+    def test_async_class_rref_proxy_async(self):
+        self._test_test_async_class_rref_proxy(mode=RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_class_rref_proxy_remote(self):
+        self._test_test_async_class_rref_proxy(mode=RPCExecMode.REMOTE)
+
+    def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        num = 20
+        step = 3
+        args = (dst2, torch.ones(2, 2), num, step)
+        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
+        self.assertEqual(ret, torch.ones(2, 2) + num * step)
+
+    @dist_init
+    def test_async_function_multi_chained(self):
+        self._test_async_function_multi(async_add_chained_multi)
+
+    @dist_init
+    def test_async_function_multi_chained_async(self):
+        self._test_async_function_multi(async_add_chained_multi, RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_multi_chained_remote(self):
+        self._test_async_function_multi(async_add_chained_multi, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_multi_fanout(self):
+        self._test_async_function_multi(async_add_multi_fanout)
+
+    @dist_init
+    def test_async_function_multi_fanout_async(self):
+        self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_multi_fanout_remote(self):
+        self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.REMOTE)
+
+    def _test_return_future(self, mode):
+        with self.assertRaisesRegex(
+            RuntimeError, "Can not pickle torch.futures.Future"
+        ):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), return_future, mode
+            )
+
+    @dist_init
+    def test_return_future(self):
+        self._test_return_future(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_return_future_async(self):
+        self._test_return_future(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_return_future_remote(self):
+        self._test_return_future(RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_rref_timeout(self):
+        # This test is similar to ones in FaultyProcessGroupTest, but is meant to be
+        # run with other backends besides ProcessGroup.
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # 10 ms timeout
+        rref = rpc.remote(dst_worker, my_sleep_func, args=(2,), timeout=0.01)
+        # Future corresponding to the remote creation should time out.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref._get_future().wait()
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+
+        wait_until_owners_and_forks_on_rank(1, 1, rank=1)
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.",
+    )
+    def test_init_pg_then_rpc(self):
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        # Test RPC.
+        next_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test PG
+        dist.barrier()
+
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.",
+    )
+    def test_init_rpc_then_pg(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        # Test RPC.
+        next_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test PG
+        dist.barrier()
+
+        rpc.shutdown()
+
+    @dist_init
+    def test_wait_all_with_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [rpc.rpc_async(dst, raise_func) for _ in range(10)]
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+    @dist_init
+    def test_wait_all_with_partial_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [
+            rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1)) for _ in range(10)
+        ]
+
+        futs.append(rpc.rpc_async(dst, raise_func))
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491",
+    )
+    def test_init_rpc_twice(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # Use a different file name for the next initialization
+        new_backend_options = self.rpc_backend_options
+        new_backend_options.init_method += "init_2"
+
+        # Ensure rpc initialization works again.
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=new_backend_options,
+        )
+
+        # Verify RPCs work after re-init.
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+        rpc.rpc_sync(dst, foo_add, args=())
+
+        rpc.shutdown()
+
+    def test_wrong_types(self):
+        with self.assertRaisesRegex(
+            TypeError,
+            "Argument backend must be a member of BackendType",
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend="TENSORPIPE",
+            )
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Argument rpc_backend_options must be an instance of RpcBackendOptions",
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend=self.rpc_backend,
+                rpc_backend_options={"init_method": self.init_method},
+            )
+
+    def test_cannot_infer_backend_from_options(self):
+        # An exception should be raised if the backend isn't specified but
+        # options are given which are not an instance of any of the known
+        # agents' option classes.
+        rpc_backend_options = FooBackendOptions(self.init_method)
+
+        with self.assertRaisesRegex(TypeError, "Could not infer backend for options"):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                # Do _not_ pass backend.
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    @dist_init
+    def test_owner_rref_backward(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        t1 = torch.rand(10, 10, requires_grad=True)
+        rref = rpc.RRef(t1.sum() + t1.sum())
+        rref.backward()
+        expected_grad = torch.ones_like(t1) * 2
+        self.assertEqual(expected_grad, t1.grad)
+
+        with dist_autograd.context() as context_id:
+            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
+            rref = rpc.RRef(t2.sum())
+            rref.backward(context_id)
+            self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1])
+
+        # Double backward.
+        with dist_autograd.context() as context_id:
+            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
+            rref = rpc.RRef(t2.sum())
+            rref.backward(context_id, retain_graph=True)
+            rref.backward(context_id)
+            self.assertEqual(
+                expected_grad * 2, dist_autograd.get_gradients(context_id)[t1]
+            )
+
+        # Test errors.
+        with self.assertRaisesRegex(
+            RuntimeError, "tensors does not require grad and does not have a grad_fn"
+        ):
+            rpc.RRef(torch.rand(10)).backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "grad can be implicitly created only for scalar outputs"
+        ):
+            rpc.RRef(torch.rand(10, requires_grad=True)).backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Could not find autograd context with id: 100"
+        ):
+            rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "RRef should contain a tensor for .backward()"
+        ):
+            rpc.RRef("foo").backward()
+
+    @staticmethod
+    def _sum(x):
+        return x.sum()
+
+    @staticmethod
+    def _identity(x):
+        return x
+
+    @dist_init
+    def test_user_rref_backward(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        t = torch.rand(10, requires_grad=True)
+        with dist_autograd.context() as context_id:
+            rref = rpc.remote(dst, RpcTest._sum, args=(t,))
+            rref.backward(context_id, retain_graph=True)
+            rref.backward(context_id)
+            self.assertEqual(
+                torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t]
+            )
+
+        with dist_autograd.context() as context_id:
+            rref = rpc.remote(dst, RpcTest._identity, args=("foo",))
+            with self.assertRaisesRegex(
+                RuntimeError, "RRef should contain a tensor for .backward()"
+            ):
+                rref.backward(context_id)
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "User RRefs require 'dist_autograd_ctx_id' to be specified",
+            ):
+                rref.backward()
+
+    @dist_init(setup_rpc=False)
+    def test_shutdown_errors(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        if self.rank != 0:
+            og_func = rpc.api._broadcast_to_followers
+            og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs
+
+            # Monkey-patch _broadcast_to_followers to fail, which would ensure
+            # _all_gather on leader raises an exception.
+            def raise_error(sequence_id, objects_map):
+                og_func(sequence_id, objects_map)
+                raise RuntimeError("simulation")
+
+            # Monkey-patch _delete_all_user_and_unforked_owner_rrefs to fail,
+            # which would ensure barrier is not called on followers.
+            def rref_error():
+                raise RuntimeError("simulation rref")
+
+            try:
+                rpc.api._broadcast_to_followers = raise_error
+                rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error
+                with self.assertRaisesRegex(RuntimeError, "simulation rref"):
+                    rpc.shutdown()
+            finally:
+                rpc.api._broadcast_to_followers = og_func
+                rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func
+        else:
+            with self.assertRaisesRegex(RuntimeError, "timed out in _all_gather"):
+                rpc.shutdown()
+
+        dist.barrier()
+
+    @dist_init
+    def test_my_parameter_server(self):
+        self._my_parameter_server(False)
+
+
+class CudaRpcTest(RpcAgentTestFixture):
+    @skip_if_lt_x_gpu(2)
+    @dist_init
+    def test_profiler_remote_cuda(self):
+        if self.rank != 1:
+            return
+
+        dst_cuda_0 = (self.rank + 1) % self.world_size
+        dst_cuda_1 = (self.rank + 2) % self.world_size
+        dst_worker_cuda_0 = worker_name(dst_cuda_0)
+        dst_worker_cuda_1 = worker_name(dst_cuda_1)
+
+        with _profile(use_cuda=True) as p:
+            fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0,))
+            fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1,))
+            fut1.wait()
+            fut2.wait()
+
+        def get_name(event):
+            return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]
+
+        function_events = p.function_events
+        for event in function_events:
+            if event.is_async:
+                self.assertEqual(0, event.device_time_total)
+                self.assertEqual([], event.kernels)
+                self.assertEqual(0, event.device_time)
+            else:
+                if event.node_id == 1:
+                    continue
+                self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1])
+                if get_name(event) in EXPECTED_REMOTE_EVENTS:
+                    self.assertGreater(event.device_time_total, 0)
+                    self.assertEqual(1, len(event.kernels))
+                    kernel = event.kernels[0]
+                    if event.node_id == dst_cuda_0:
+                        self.assertEqual(kernel.device, 0)
+                    if event.node_id == dst_cuda_1:
+                        self.assertEqual(kernel.device, 1)
+                    self.assertGreater(event.device_time, 0)
+
+        # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled
+        # events.
+        remote_events = [event for event in function_events if event.is_remote]
+        remote_event_names = [
+            get_name(event)
+            for event in remote_events
+            if get_name(event) in EXPECTED_REMOTE_EVENTS
+        ]
+        self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS))
+
+
+class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon):
+    def test_mismatched_type_for_options(self):
+        # An exception should be raised if the options are not an instance of
+        # TensorPipeRpcBackendOptions.
+        rpc_backend_options = FooBackendOptions(self.init_method)
+
+        with self.assertRaisesRegex(
+            TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`"
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend=rpc.BackendType.TENSORPIPE,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    def test_infer_backend_from_options(self):
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.init_method, _transports=tp_transports()
+        )
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            rank=self.rank,
+            world_size=self.world_size,
+            # Do _not_ pass backend.
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent)
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_set_and_get_num_worker_threads(self):
+        NUM_THREADS = 27
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_worker_threads=NUM_THREADS,
+            _transports=tp_transports(),
+        )
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
+        rpc.shutdown()
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_tensorpipe_set_default_timeout(self):
+        # Set a high timeout since it doesn't affect test runtime and ensures
+        # the test doesn't erroneously timeout due to slow machines.
+        timeout = 100
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_worker_threads=self.rpc_backend_options.num_worker_threads,
+            rpc_timeout=timeout,
+            _transports=tp_transports(),
+        )
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        default_timeout = rpc.get_rpc_timeout()
+        self.assertEqual(default_timeout, timeout)
+        rpc.shutdown()
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_tensorpipe_options_throw_on_timedelta_timeout(self):
+        from datetime import timedelta
+
+        timeout = timedelta()
+        # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails
+        with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"):
+            rpc.TensorPipeRpcBackendOptions(
+                init_method=self.rpc_backend_options.init_method,
+                num_worker_threads=self.rpc_backend_options.num_worker_threads,
+                rpc_timeout=timeout,
+            )
+
+    @dist_init
+    def _test_rref_get_type_timeout(self, blocking):
+        # Test where we try to get the type of a RRef from an owner, but RRef
+        # creation is slower than timeout passed into _get_type.
+        dst_rank = (self.rank + 1) % self.world_size
+        dst = worker_name(dst_rank)
+        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
+        timeout = 0.5
+        expected_err = self.get_timeout_error_regex()
+        # Blocking: blocks on inline call
+        if blocking:
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                slow_rref._get_type(timeout=timeout, blocking=blocking)
+        # Non-blocking: blocks on wait
+        else:
+            fut = slow_rref._get_type(timeout=timeout, blocking=blocking)
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                fut.wait()
+
+        # FIXME We wait until the remote completed creating the OwnerRRef
+        # because there's currently a race if we shut down RPC before that.
+        slow_rref.to_here()
+
+    def test_rref_get_type_timeout_blocking(self):
+        self._test_rref_get_type_timeout(blocking=True)
+
+    def test_rref_get_type_timeout_non_blocking(self):
+        self._test_rref_get_type_timeout(blocking=False)
+
+    @dist_init
+    def test_op_with_invalid_args(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Overloaded torch operator invoked from Python failed to match any schema",
+        ):
+            rpc.rpc_sync(dst, torch.add, args=())
+
+    def _test_rref_proxy_timeout(self, rref_proxy_api):
+        dst_rank = (self.rank + 1) % self.world_size
+        dst = worker_name(dst_rank)
+        rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2),))
+        # Ensure RRef is created on remote node.
+        rref.to_here()
+        rref_api = getattr(rref, rref_proxy_api)
+        self.assertTrue(
+            rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}"
+        )
+        expected_error = self.get_timeout_error_regex()
+        timeout = 2
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2))
+            if rref_api == rref.rpc_async:
+                result.wait()
+            elif rref_api == rref.remote:
+                result._get_future().wait()
+
+        # Case where rpc.remote() is stuck and exceeds timeout
+        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
+        timeout = 0.01
+        rref_api = getattr(slow_rref, rref_proxy_api)
+        # Note that even when we call rref.rpc_async() in this case, we
+        # time out in future creation, not waiting for future. This is because
+        # rref proxy function calls rref._get_type before returning future,
+        # which blocks on the RRef being created on owner node, until the
+        # specified timeout.
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
+            # rpc_async returns immediately and surface a timeout through wait()
+            if rref_api == slow_rref.rpc_async:
+                result.wait()
+
+        # FIXME We wait until the remote completed creating the OwnerRRef
+        # because there's currently a race if we shut down RPC before that.
+        slow_rref.to_here()
+
+    @dist_init
+    def test_rref_proxy_timeout(self):
+        for rpc_api in ["rpc_sync", "rpc_async", "remote"]:
+            self._test_rref_proxy_timeout(rpc_api)
+
+    @dist_init
+    def test_send_to_rank_sparse(self):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        # Test sparse tensor
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor()
+            y = build_sparse_tensor()
+            expected_tensor = x + y
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor(coalesce=True)
+            y = build_sparse_tensor(coalesce=True)
+            expected_tensor = x + y
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+    @dist_init
+    def test_self_py_udf_remote_sparse(self):
+        self._self_py_udf_remote(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg_sparse(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_rpc_arg(
+            dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg_sparse(self):
+        self._self_remote_rref_as_rpc_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg_sparse(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_remote_arg(
+            dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg_sparse(self):
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    def test_world_size_one_sparse(self):
+        self._world_size_one(build_sparse_tensor(), build_sparse_tensor())
+
+    @dist_init
+    def test_multi_rpc_sparse(self):
+        self._multi_rpc(True)
+
+    def test_wait_all_workers_sparse(self):
+        self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())
+
+    def test_wait_all_workers_twice_sparse(self):
+        self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor())
+
+    @dist_init
+    def test_py_sparse_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [build_sparse_tensor(), build_sparse_tensor()]
+        ret = rpc.rpc_sync(worker_name(dst_rank), my_container_sum, args=(a,))
+        self.assertEqual(ret, my_container_sum(a))
+
+    @dist_init
+    def test_nested_rpc_sparse(self):
+        self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2)
+
+    @dist_init
+    def test_stress_heavy_rpc_sparse(self):
+        self._stress_test_rpc(
+            heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),)
+        )
+
+    @dist_init
+    def test_builtin_remote_ret_sparse(self):
+        self._builtin_remote_ret(
+            build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self_sparse(self):
+        self._builtin_remote_self(
+            build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_multi_builtin_remote_ret_sparse(self):
+        self._test_multi_remote_call(torch.add, True, args_fn=RpcTest._multi_args_fn)
+
+    @dist_init
+    def test_multi_py_udf_remote_sparse(self):
+        self._test_multi_remote_call(
+            my_function, True, kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    @dist_init
+    def test_py_rref_args_sparse(self):
+        self._py_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 4,
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share_sparse(self):
+        self._py_rref_args_user_share(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6,
+        )
+
+    @dist_init
+    def test_py_rpc_rref_args_sparse(self):
+        self._py_rpc_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6,
+        )
+
+    @dist_init
+    def test_nested_remote_sparse(self):
+        self._nested_remote(
+            nested_remote_sparse, build_sparse_tensor() + build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_nested_rref_sparse(self):
+        self._nested_rref(
+            nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_nested_rref_stress_sparse(self):
+        self._nested_rref_stress(
+            nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_my_parameter_server_sparse(self):
+        self._my_parameter_server(True)
+
+    # Test init_rpc without world_size argument
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_init_rpc(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+    # Dynamic RPC new ranks communicate with existing ranks
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_new_rank_can_communicated_with_existing_rank(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # Rank 0 will be initialized with RPC after this barrier
+        dist.barrier()
+
+        if self.rank != 0:
+            # Newly joined ranks will be able to communicate with rank 0, since that was created first
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            result = rpc.rpc_sync(
+                worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1))
+            )
+            self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    # Dynamic RPC existing ranks can communicate with new ranks
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # Rank 0 will be initialized with RPC after this barrier
+        dist.barrier()
+
+        # Rest of ranks join after barrier
+        if self.rank != 0:
+            # Newly joined ranks will be able to communicate with rank 0, since that was created first
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        dist.barrier()
+        if self.rank == 0:
+            for i in range(1, self.world_size):
+                result = rpc.rpc_sync(
+                    worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1))
+                )
+                self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    # Dynamic RPC existing ranks can communicate with new ranks using CUDA rpc
+    @skip_if_lt_x_gpu(2)
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            for i in range(1, self.world_size):
+                dst = worker_name(i)
+                options.set_device_map(dst, {1: 0})
+                options.set_device_map(dst, {0: 1})
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=options,
+            )
+
+        # Rank 0 will be initialized with RPC after this barrier
+        dist.barrier()
+
+        # Rest of ranks join after barrier
+        if self.rank != 0:
+            # Newly joined ranks will be able to communicate with rank 0, since that was created first
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # TODO: Cuda RPC is failing due to:
+        # terminate called after throwing an instance of 'c10::Error'
+        # what():  0 <= device && static_cast(device) < device_allocator.size()
+        # INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1937,
+        # please report a bug to PyTorch. Allocator not initialized for device 1: did you call init?
+        # dist.barrier()
+        # if self.rank == 0:
+        #     for i in range(1, self.world_size):
+        #         x = torch.ones(2)
+        #         result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1))
+        #         result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1))
+        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0)
+        #         self.assertEqual(torch.device('cuda:0'), result_on_device_0.device)
+        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1)
+        #         self.assertEqual(torch.device('cuda:1'), result_on_device_1.device)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_init_rpc_without_rank(self):
+        # default initialization uses file init
+        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # env init
+        with self.assertRaisesRegex(ValueError, "environment variable RANK expected"):
+            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://")
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+        # tcp init
+        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
+            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+                init_method="tcp://127.0.0.1:23456"
+            )
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_dynamic_and_static_init_rpc_together(self):
+        # Initialize a static rpc group with size = self.world_size - 1
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.file_init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        world_size_minus_one = self.world_size - 1
+        if self.rank < world_size_minus_one:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=world_size_minus_one,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        dist.barrier()
+
+        # Attempt to add an additional dynamic group member
+        if self.rank == world_size_minus_one:
+            # Expect error message to be thrown
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "RPC group mixes statically and dynamically\
+ initialized members which is not supported.",
+            ):
+                rpc.init_rpc(
+                    name=worker_name(self.rank),
+                    backend=self.rpc_backend,
+                    rank=self.rank,
+                    rpc_backend_options=self.rpc_backend_options,
+                )
+
+
+class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon):
+    def _test_device_maps(self, options, errMsg):
+        with self.assertRaisesRegex(ValueError, errMsg):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+        self.assertFalse(rpc.api._is_current_rpc_agent_set())
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_wrong_worker_name(self):
+        options = self.rpc_backend_options
+        options.set_device_map("none_exist", {0: 1})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has invalid target node names in its device maps",
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_max_local_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {torch.cuda.device_count(): 0})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has source devices with invalid indices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_max_remote_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {0: torch.cuda.device_count()})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has target devices with invalid indices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_many_to_one(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {1: 0})
+        options.set_device_map(dst, {0: 0})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has duplicated target devices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_one_to_many(self):
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            dst = worker_name((self.rank + 1) % self.world_size)
+            options.set_device_map(dst, {0: 1})
+            with self.assertRaisesRegex(
+                ValueError, "`set_device_map` only supports 1-to-1 mapping"
+            ):
+                options.set_device_map(dst, {0: 0})
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_min_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"):
+            options.set_device_map(dst, {-1: 0})
+
+        with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"):
+            options.set_device_map(dst, {0: -1})
+
+    @staticmethod
+    def _gpu_add(x, y):
+        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]):
+            return (x + y).to(0)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_gpu(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {0: 1, 1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        ret = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add,
+            args=(torch.zeros(2).to(0), torch.ones(2).to(0)),
+        )
+        self.assertEqual(ret.device, torch.device(1))
+        self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1))
+        rpc.shutdown()
+
+    @staticmethod
+    def _gpu_add_given_devices(x, y, x_to, y_to, z_to):
+        x_device = "cpu" if x.device.type == "cpu" else x.device.index
+        y_device = "cpu" if y.device.type == "cpu" else y.device.index
+        if x_device == x_to and y_device == y_to:
+            return x.to(z_to) + y.to(z_to)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    def _test_device_maps_gpu(
+        self, x_from, y_from, z_to, device_map, dst=None, fn=None
+    ):
+        fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn
+        x_to = device_map[x_from]
+        y_to = device_map[y_from]
+
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst
+        options.set_device_map(dst, device_map)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(x_from)
+        y = torch.ones(2).to(y_from)
+
+        ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to))
+
+        reverse_device_map = {device_map[k]: k for k in device_map}
+        z_from = reverse_device_map[z_to]
+
+        ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index
+        self.assertEqual(ret_device, z_from)
+        self.assertEqual(ret, torch.ones(2).to(z_from))
+
+        rpc.shutdown()
+
+    def test_device_map_cpu(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to="cpu",
+            device_map={"cpu": "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_cpu_to_gpu_default(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=0,
+            device_map={"cpu": 0},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_cpu_to_gpu_non_default(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=1,
+            device_map={"cpu": 1},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_gpu_to_cpu_default(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=0,
+            z_to="cpu",
+            device_map={0: "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_to_cpu_non_default(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=1,
+            z_to="cpu",
+            device_map={1: "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_default(self):
+        self._test_device_maps_gpu(x_from=0, y_from=0, z_to=0, device_map={0: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_non_default(self):
+        self._test_device_maps_gpu(x_from=1, y_from=1, z_to=1, device_map={1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_default_to_non_default(self):
+        self._test_device_maps_gpu(x_from=0, y_from=0, z_to=1, device_map={0: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_non_default_to_default(self):
+        self._test_device_maps_gpu(x_from=1, y_from=1, z_to=0, device_map={1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_1(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_2(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_3(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_4(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_5(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_6(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_7(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_8(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_1(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=0,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_2(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=1,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_3(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=0,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_4(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=1,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_5(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=0,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_6(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=1,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_7(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=0,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_8(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=1,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @staticmethod
+    def _gpu_add_multi_gpu(x, y):
+        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]):
+            return x.to(0) + y, x - y.to(1)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    def _test_device_maps_multi_gpu(self, dst):
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 1})
+        options.set_device_map(dst, {1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(0)
+        y = torch.ones(2).to(1)
+        rets = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, args=(x, y)
+        )
+
+        self.assertEqual(rets[0].device, torch.device(1))
+        self.assertEqual(rets[1].device, torch.device(0))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_multi_gpu(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._test_device_maps_multi_gpu(dst)
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_multi_gpu_self(self):
+        dst = worker_name(self.rank)
+        self._test_device_maps_multi_gpu(dst)
+
+    @staticmethod
+    def _gpu_add_return_to_gpu(x, y):
+        if x.device.type == "cpu" and y.device.type == "cpu":
+            return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_in_options(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+                init_method=options.init_method,
+                num_worker_threads=options.num_worker_threads,
+                device_maps={dst: {0: 1, 1: 0}},
+                _transports=tp_transports(),
+            ),
+        )
+
+        rets = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
+            args=(torch.zeros(2).to(0), torch.ones(2).to(1)),
+        )
+        self.assertEqual(rets[0].device, torch.device(1))
+        self.assertEqual(rets[1].device, torch.device(0))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        rpc.shutdown()
+
+    def _test_device_maps_return_to_gpu(self, dst):
+        options = self.rpc_backend_options
+
+        options.set_device_map(dst, {0: 1})
+        options.set_device_map(dst, {1: 2})
+        options.set_device_map(dst, {2: 3})
+        options.set_device_map(dst, {3: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rets = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu,
+            args=(torch.zeros(2), torch.ones(2)),
+        )
+        for i in range(len(rets)):
+            self.assertEqual(rets[i].device, torch.device((3 + i) % 4))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1))
+        self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2))
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_return_to_gpu(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._test_device_maps_return_to_gpu(dst)
+
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_return_to_gpu_self(self):
+        dst = worker_name(self.rank)
+        self._test_device_maps_return_to_gpu(dst)
+
+    @staticmethod
+    def _add_to_gpu(x, y):
+        return (x + y).to(0)
+
+    def _test_device_maps_missing_config(self, mode):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        errMsg = (
+            "TensorPipe RPC backend only supports CPU tensors by default.*"
+            "`set_device_map` on `TensorPipeRpcBackendOptions`"
+        )
+
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            if mode == RPCExecMode.SYNC:
+                rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1))
+            elif mode == RPCExecMode.REMOTE:
+                rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here()
+            else:
+                raise ValueError(f"unexpected mode {mode}")
+
+        # make sure RPC is still functioning
+        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
+        self.assertEqual(ret, torch.ones(2) + 1)
+
+    def _test_device_maps_missing_config_response(self, mode):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        errMsg = "Response device mapping is not available"
+
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            if mode == RPCExecMode.SYNC:
+                rpc.rpc_sync(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._add_to_gpu,
+                    args=(torch.zeros(2), 1),
+                )
+            elif mode == RPCExecMode.REMOTE:
+                rpc.remote(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._add_to_gpu,
+                    args=(torch.zeros(2), 1),
+                ).to_here()
+            else:
+                raise ValueError(f"unexpected mode {mode}")
+
+        # make sure RPC is still functioning
+        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
+        self.assertEqual(ret, torch.ones(2) + 1)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config(self):
+        self._test_device_maps_missing_config(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_missing_config_not_timeout(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        timeout = rpc.get_rpc_timeout()
+
+        tik = time.time()
+        self._test_device_maps_missing_config(RPCExecMode.SYNC)
+        rpc.shutdown()
+        tok = time.time()
+
+        self.assertTrue(tok - tik < timeout)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_loop(self):
+        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
+            self._test_device_maps_missing_config(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_response(self):
+        self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_response_loop(self):
+        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
+            self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_remote(self):
+        self._test_device_maps_missing_config(RPCExecMode.REMOTE)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_remote_response(self):
+        self._test_device_maps_missing_config_response(RPCExecMode.REMOTE)
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_remote(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rref = rpc.remote(
+            dst, TensorPipeAgentCudaRpcTest._add_to_gpu, args=(torch.zeros(2), 1)
+        )
+
+        self.assertEqual(rref.to_here().device.index, 1)
+        self.assertEqual(rref.to_here(), torch.ones(2).to(1))
+
+        rpc.shutdown()
+
+    @staticmethod
+    def _slow_add_on_user_stream(x, y):
+        s0 = torch.cuda.current_stream(x.device)
+        s1 = torch.cuda.Stream(device=x.device)
+        s1.wait_stream(s0)
+        x.record_stream(s1)
+        y.record_stream(s1)
+        with torch.cuda.stream(s1):
+            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+            z = x + y
+        s0.wait_stream(s1)
+        z.record_stream(s0)
+        return z
+
+    def _test_custom_stream(self, fn, device_map):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, device_map)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        fn(dst)
+
+        rpc.shutdown()
+
+    def _test_stream_sync(self, dst):
+        x = torch.ones(2, 2).to(0)
+        ret = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, x)
+        )
+        self.assertEqual(ret, 2 * x)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream(self):
+        self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"})
+
+    def _test_stream_multi_async(self, dst):
+        futs = []
+        for i in range(20):
+            x = torch.ones(2, 2).to(0) * i
+            futs.append(
+                rpc.rpc_async(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
+                    args=(x, x),
+                )
+            )
+
+        for i in range(20):
+            self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_multi(self):
+        self._test_custom_stream(self._test_stream_multi_async, {"cuda:0": "cuda:1"})
+
+    @staticmethod
+    def _nested_slow_add_on_user_stream(dst, x, y, z):
+        ret = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, y)
+        )
+
+        return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z)
+
+    def _test_stream_nested_sync(self, dst):
+        x = torch.ones(2, 2).to(0)
+        y = torch.ones(2, 2).to(0) * 2
+        z = torch.ones(2, 2).to(0) * 3
+        nested_dst = worker_name((self.rank + 2) % self.world_size)
+        ret = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
+            args=(nested_dst, x, y, z),
+        )
+        self.assertEqual(ret, 6 * x)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_nested(self):
+        self._test_custom_stream(
+            self._test_stream_nested_sync, {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
+        )
+
+    def _test_stream_nested_multi_async(self, dst):
+        if self.rank == 0:
+            futs = []
+            n = 5
+            xs, ys, zs = [], [], []
+            for i in range(n):
+                x = torch.ones(2, 2).to(0) * (i - 1)
+                y = torch.ones(2, 2).to(0) * i
+                z = torch.ones(2, 2).to(0) * (i + 1)
+                xs.append(x)
+                ys.append(y)
+                zs.append(z)
+                nested_dst = worker_name((self.rank + 2) % self.world_size)
+                futs.append(
+                    rpc.rpc_async(
+                        dst,
+                        TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
+                        args=(nested_dst, x, y, z),
+                    )
+                )
+
+            for i in range(n):
+                self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i])
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_nested_multi(self):
+        self._test_custom_stream(
+            self._test_stream_nested_multi_async,
+            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"},
+        )
+
+    @staticmethod
+    def _gpu_add_wrong_gpus(x, y):
+        if x.is_cuda and y.is_cuda:
+            return x.cpu() + y.cuda()
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_mismatch(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(0)
+        y = torch.ones(2).to(0)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Expected all tensors to be on the same device, but found at least two devices",
+        ):
+            rpc.rpc_sync(
+                dst, TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus, args=(x, y)
+            )
+
+        rpc.shutdown()
+
+    def _test_rref_synchronization(self, local_device, remote_device):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {local_device: remote_device})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 1:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN of MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                x = torch.randn(200, 1, 28, 28).to(local_device)
+                actual = rref.remote().forward(x).to_here()
+                expected = rref.rpc_sync().forward(x)
+                self.assertEqual(actual, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_to_here_synchronization1(self):
+        self._test_rref_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization2(self):
+        self._test_rref_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization3(self):
+        self._test_rref_synchronization("cuda:1", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization4(self):
+        self._test_rref_synchronization("cuda:0", "cuda:1")
+
+    def _test_rref_as_arg_synchronization(
+        self, local_device, remote_device, devicesOptions=None
+    ):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {local_device: remote_device})
+
+        input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size)
+        options.set_device_map(input_src, {remote_device: local_device})
+
+        if devicesOptions is not None:
+            options.set_devices(devicesOptions[self.rank])
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 1:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN of MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device))
+                actual = rref.remote().forward(rref_x, True).to_here()
+                expected = rref.rpc_sync().forward(rref_x, True)
+                self.assertEqual(actual, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_as_arg_synchronization1(self):
+        self._test_rref_as_arg_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization2(self):
+        self._test_rref_as_arg_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization3(self):
+        self._test_rref_as_arg_synchronization("cuda:1", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization4(self):
+        self._test_rref_as_arg_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_as_arg_synchronization5(self):
+        self._test_rref_as_arg_synchronization(
+            "cuda:0",
+            "cuda:0",
+            [["cuda:0"] for _ in range(4)],  # devicesOptions
+        )
+
+    @staticmethod
+    def _rref_relay(rref):
+        return rref.to_here()
+
+    def _test_rref_forward_synchronization(self, local_device, remote_device):
+        options = self.rpc_backend_options
+
+        input_src = worker_name(0)
+        model_dst = worker_name(1)
+        out_relay = worker_name(2)
+
+        if self.rank == 0:
+            # for 1) model construction 2) forward execution
+            options.set_device_map(model_dst, {local_device: remote_device})
+
+            # Forward output will be first copied to the relay node before
+            # returning to the worker. This is intentional, to test RRef
+            # forward CUDA stream synchronizations.
+            options.set_device_map(out_relay, {local_device: local_device})
+        elif self.rank == 1:
+            # worker1 hosts the model and runs forward. The forward functions
+            # calls RRef.to_here(), hence needs to configure the device map
+            options.set_device_map(input_src, {remote_device: local_device})
+        elif self.rank == 2:
+            # worker2 will get the out RRef and call to_here() and hence, needs
+            # to configure device map.
+            options.set_device_map(model_dst, {local_device: remote_device})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 0:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN of MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device))
+                rref_out = rref.remote().forward(rref_input, True)
+                out = rpc.remote(
+                    out_relay, TensorPipeAgentCudaRpcTest._rref_relay, args=(rref_out,)
+                ).to_here()
+                expected = rref.rpc_sync().forward(rref_input, True)
+                self.assertEqual(out, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_forward_synchronization1(self):
+        self._test_rref_forward_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization2(self):
+        self._test_rref_forward_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization3(self):
+        self._test_rref_forward_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization4(self):
+        self._test_rref_forward_synchronization("cuda:1", "cuda:1")
+
+    def _test_owner_rref_forward_synchronization(self, local_device, remote_device):
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            options.set_device_map("w0", {local_device: remote_device})
+            rpc.init_rpc("w0", rank=0, world_size=1, rpc_backend_options=options)
+
+            model = (
+                rpc.remote("w0", torch.nn.Linear, (2048, 20000))
+                .remote()
+                .to(remote_device)
+            )
+            for _ in range(30):
+                data = torch.rand(2048, 2048).to(local_device)
+                output = model.rpc_sync().forward(data)
+                # to_here() internally calls localValue as the caller is
+                # the owner of the RRef.
+                v0 = rpc.RRef(output).remote().sum().to_here().item()
+                v1 = output.sum().item()
+                self.assertEqual(v0, v1)
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_owner_rref_forward_synchronization1(self):
+        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization2(self):
+        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization3(self):
+        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization4(self):
+        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1")
+
+    @staticmethod
+    def _return_tensor_view(i):
+        x = torch.ones(1000, 200).cuda(0) * i
+        torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+        # serialization of the return value will create a new tensor from the
+        # view, which is done outside of the user function.
+        return x.split(100)[0]
+
+    @skip_if_lt_x_gpu(1)
+    def test_tensor_view_as_return_value(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        futs = [
+            rpc.rpc_async(
+                dst, TensorPipeAgentCudaRpcTest._return_tensor_view, args=(i,)
+            )
+            for i in range(5)
+        ]
+
+        for i in range(5):
+            self.assertEqual(torch.ones(100, 200) * i, futs[i].wait())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_devices_option_mismatch(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Node worker0 has unexpected source devices in its device map for worker1",
+        ):
+            dst = worker_name((self.rank + 1) % self.world_size)
+            options = self.rpc_backend_options
+            options.set_device_map(dst, {0: 0})
+            options.set_devices([1])
+
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_devices_option_mismatch_reverse(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Node worker0 has unexpected target devices in its device map for worker1",
+        ):
+            dst = worker_name((self.rank + 1) % self.world_size)
+
+            options = rpc.TensorPipeRpcBackendOptions(
+                init_method=self.rpc_backend_options.init_method,
+                num_worker_threads=self.rpc_backend_options.num_worker_threads,
+                device_maps={dst: {0: 1}},
+                devices=[0],
+            )
+
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_int(self):
+        Future(devices=[0])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_str(self):
+        Future(devices=["cuda:0"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_device(self):
+        Future(devices=[torch.device("cuda", 0)])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_not_cuda(self):
+        with self.assertRaisesRegex(
+            ValueError, "Expected devices to have indices, got cpu"
+        ):
+            Future(devices=["cpu"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_callback_changes_devices(self):
+        # We check proper CUDA stream synchronization by filling the tensor with
+        # the expected value in one stream, and reading it from another stream.
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:0", "cuda:1"])
+
+        def cb(fut):
+            t0 = fut.value()
+            tensor1.copy_(t0, non_blocking=True)
+            return tensor1
+
+        child_future = parent_future.then(cb)
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_value_on_bad_device(self):
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:1"])
+
+        # As a plus, we test that futures still invoke callbacks even in case of
+        # error, and that the child futures are successful if those callbacks
+        # don't access the parent future.
+        def cb(fut):
+            with torch.cuda.device("cuda:1"):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor1.fill_(1)
+                return tensor1
+
+        child_future = parent_future.then(cb)
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with self.assertRaisesRegex(
+            ValueError,
+            r"The result contained tensors residing on device\(s\) cuda:0 "
+            r"which are not among the expected device\(s\) cuda:1",
+        ):
+            parent_future.wait()
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(1)
+    def test_async_execution_with_cuda_future(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        t = torch.zeros((100,), device="cuda:0")
+        fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,))
+        another_stream = torch.cuda.Stream("cuda:0")
+        with torch.cuda.stream(another_stream):
+            self.assertTrue(torch.eq(fut.wait(), 1).all().item())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_async_execution_nested_with_cuda_future(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        nested_dst = worker_name((self.rank + 2) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        a = torch.ones((100,), device="cuda:0")
+        b = torch.ones((100,), device="cuda:0")
+        c = torch.ones((100,), device="cuda:0")
+        fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c))
+        another_stream = torch.cuda.Stream("cuda:0")
+        with torch.cuda.stream(another_stream):
+            self.assertTrue(torch.eq(fut.wait(), 3).all().item())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_modify_tensor_inplace(self):
+        tensor = torch.zeros((100,), device="cuda:0")
+        future = Future(devices=["cuda:0"])
+        future.set_result(tensor)
+        # It's weird to modify the value of a future once it's complete, but
+        # technically possible. Currently this is considered undefined behavior
+        # (in practice the future will ignore the modification and still
+        # synchronize with the original value). We could one day add logic to
+        # detect and warn or throw in such cases, but for now we just check that
+        # this doesn't crash.
+        tensor.fill_(1)
+        future.wait()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_replace_tensor(self):
+        tensor_list = [torch.zeros((100,), device="cuda:0")]
+        future = Future(devices=["cuda:0"])
+        future.set_result(tensor_list)
+        # It's weird to modify the value of a future once it's complete, but
+        # technically possible. Currently this is considered undefined behavior
+        # (in practice the future will ignore the modification and still
+        # synchronize with the original value). We could one day add logic to
+        # detect and warn or throw in such cases, but for now we just check that
+        # this doesn't crash.
+        # We set things up so that the original tensor contained in the list
+        # gets deleted once we replace it with the other one. This will
+        # invalidate any cached information held by the future.
+        tensor_list[0] = torch.ones((100,), device="cuda:0")
+        future.wait()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_with_unpickleable_attributes(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),))
+        rref.rpc_sync().increase(1)
+        ret = rref.rpc_sync().sum()
+        self.assertEqual(ret, 42)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=True
+        )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..021ae60468009d2fd4fa947c90455d99c1c6d54e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py
@@ -0,0 +1,28 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed.rpc as rpc
+from torch.testing._internal.common_distributed import tp_transports
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture):
+    @property
+    def rpc_backend(self):
+        return rpc.backend_registry.BackendType["TENSORPIPE"]
+
+    @property
+    def rpc_backend_options(self):
+        return rpc.backend_registry.construct_rpc_backend_options(
+            self.rpc_backend, init_method=self.init_method, _transports=tp_transports()
+        )
+
+    def get_shutdown_error_regex(self):
+        # FIXME Once we consolidate the error messages returned by the
+        # TensorPipe agent put some more specific regex here.
+        error_regexes = [".*"]
+        return "|".join([f"({error_str})" for error_str in error_regexes])
+
+    def get_timeout_error_regex(self):
+        return "RPC ran for more than"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24e4f97f05df22396dc08e3e6bc381085477882
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py
@@ -0,0 +1,188 @@
+# mypy: allow-untyped-defs
+
+import os
+import sys
+import unittest
+
+from torch.testing._internal.common_distributed import MultiProcessTestCase
+from torch.testing._internal.common_utils import (
+    find_free_port,
+    IS_SANDCASTLE,
+    TEST_WITH_DEV_DBG_ASAN,
+)
+from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
+    CudaDdpComparisonTest,
+    DdpComparisonTest,
+    DdpUnderDistAutogradTest,
+)
+from torch.testing._internal.distributed.nn.api.remote_module_test import (
+    CudaRemoteModuleTest,
+    RemoteModuleTest,
+    ThreeWorkersRemoteModuleTest,
+)
+from torch.testing._internal.distributed.rpc.dist_autograd_test import (
+    CudaDistAutogradTest,
+    DistAutogradTest,
+    FaultyAgentDistAutogradTest,
+    TensorPipeAgentDistAutogradTest,
+    TensorPipeCudaDistAutogradTest,
+)
+from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
+    DistOptimizerTest,
+)
+from torch.testing._internal.distributed.rpc.examples.parameter_server_test import (
+    ParameterServerTest,
+)
+from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
+    ReinforcementLearningRpcTest,
+)
+from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
+    FaultyAgentRpcTest,
+)
+from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
+    JitDistAutogradTest,
+)
+from torch.testing._internal.distributed.rpc.jit.rpc_test import JitRpcTest
+from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
+    JitFaultyAgentRpcTest,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+from torch.testing._internal.distributed.rpc.rpc_test import (
+    CudaRpcTest,
+    RpcTest,
+    TensorPipeAgentCudaRpcTest,
+    TensorPipeAgentRpcTest,
+)
+
+
+def _check_and_set_tcp_init():
+    # if we are running with TCP init, set main address and port
+    # before spawning subprocesses, since different processes could find
+    # different ports.
+    use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+    if use_tcp_init == "1":
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(find_free_port())
+
+
+def _check_and_unset_tcp_init():
+    use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+    if use_tcp_init == "1":
+        del os.environ["MASTER_ADDR"]
+        del os.environ["MASTER_PORT"]
+
+
+# The tests for the RPC module need to cover multiple possible combinations:
+# - different aspects of the API, each one having its own suite of tests;
+# - different agents (ProcessGroup, TensorPipe, ...);
+# To avoid a combinatorial explosion in code size, and to prevent forgetting to
+# add a combination, these are generated automatically by the code in this file.
+# Here, we collect all the test suites that we need to cover.
+# We then have one separate file for each agent, from which
+# we call the generate_tests function of this file, passing to it a fixture for
+# the agent, which then gets mixed-in with each test suite.
+
+
+@unittest.skipIf(
+    TEST_WITH_DEV_DBG_ASAN,
+    "Skip ASAN as torch + multiprocessing spawn have known issues",
+)
+class SpawnHelper(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        _check_and_set_tcp_init()
+        self._spawn_processes()
+
+    def tearDown(self):
+        _check_and_unset_tcp_init()
+        super().tearDown()
+
+
+# This list contains test suites that are agent-agnostic and that only verify
+# compliance with the generic RPC interface specification. These tests should
+# *not* make use of implementation details of a specific agent (options,
+# attributes, ...). These test suites will be instantiated multiple times, once
+# for each agent (except the faulty agent, which is special).
+GENERIC_TESTS = [
+    RpcTest,
+    ParameterServerTest,
+    DistAutogradTest,
+    DistOptimizerTest,
+    JitRpcTest,
+    JitDistAutogradTest,
+    RemoteModuleTest,
+    ThreeWorkersRemoteModuleTest,
+    DdpUnderDistAutogradTest,
+    DdpComparisonTest,
+    ReinforcementLearningRpcTest,
+]
+GENERIC_CUDA_TESTS = [
+    CudaRpcTest,
+    CudaDistAutogradTest,
+    CudaRemoteModuleTest,
+    CudaDdpComparisonTest,
+]
+
+
+# This list contains test suites that will only be run on the TensorPipeAgent.
+# These suites should be standalone, and separate from the ones in the generic
+# list (not subclasses of those!).
+TENSORPIPE_TESTS = [
+    TensorPipeAgentRpcTest,
+    TensorPipeAgentDistAutogradTest,
+]
+TENSORPIPE_CUDA_TESTS = [
+    TensorPipeAgentCudaRpcTest,
+    TensorPipeCudaDistAutogradTest,
+]
+
+
+# This list contains test suites that will only be run on the faulty RPC agent.
+# That agent is special as it's only used to perform fault injection in order to
+# verify the error handling behavior. Thus the faulty agent will only run the
+# suites in this list, which were designed to test such behaviors, and not the
+# ones in the generic list.
+FAULTY_AGENT_TESTS = [
+    FaultyAgentRpcTest,
+    FaultyAgentDistAutogradTest,
+    JitFaultyAgentRpcTest,
+]
+
+
+def generate_tests(
+    prefix: str,
+    mixin: type[RpcAgentTestFixture],
+    tests: list[type[RpcAgentTestFixture]],
+    module_name: str,
+) -> dict[str, type[RpcAgentTestFixture]]:
+    """Mix in the classes needed to autogenerate the tests based on the params.
+
+    Takes a series of test suites, each written against a "generic" agent (i.e.,
+    derived from the abstract RpcAgentTestFixture class), as the `tests` args.
+    Takes a concrete subclass of RpcAgentTestFixture, which specializes it for a
+    certain agent, as the `mixin` arg. Produces all combinations of them.
+    Returns a dictionary of class names to class type
+    objects which can be inserted into the global namespace of the calling
+    module. The name of each test will be a concatenation of the `prefix` arg
+    and the original name of the test suite.
+    The `module_name` should be the name of the calling module so
+    that the classes can be fixed to make it look like they belong to it, which
+    is necessary for pickling to work on them.
+    """
+    ret: dict[str, type[RpcAgentTestFixture]] = {}
+    for test_class in tests:
+        if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
+            print(
+                f"Skipping test {test_class} on sandcastle for the following reason: "
+                "Skip dev-asan as torch + multiprocessing spawn have known issues",
+                file=sys.stderr,
+            )
+            continue
+
+        name = f"{prefix}{test_class.__name__}"
+        class_ = type(name, (test_class, mixin, SpawnHelper), {})
+        class_.__module__ = module_name
+        ret[name] = class_
+    return ret
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dynamo_pytree_test_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dynamo_pytree_test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..737b7d27a1561477c8a3781926453f90cf622c8c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/dynamo_pytree_test_utils.py
@@ -0,0 +1,28 @@
+import torch
+import torch._dynamo.test_case
+import torch.utils._pytree as pytree
+
+
+class PytreeRegisteringTestCase(torch._dynamo.test_case.TestCase):
+    """TestCase that prunes all temporary pytree registrations and resets Dynamo."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._registered_pytree_nodes: list[type] = []
+        self._registered_constant_nodes: list[type] = []
+
+    def tearDown(self) -> None:
+        for cls in reversed(self._registered_pytree_nodes):
+            pytree._deregister_pytree_node(cls)
+        for cls in reversed(self._registered_constant_nodes):
+            pytree._deregister_pytree_node(cls)
+        torch._dynamo.reset()
+        super().tearDown()
+
+    def register_pytree_node(self, cls, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
+        pytree.register_pytree_node(cls, *args, **kwargs)
+        self._registered_pytree_nodes.append(cls)
+
+    def register_constant(self, cls: type) -> None:
+        pytree.register_constant(cls)
+        self._registered_constant_nodes.append(cls)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff4118438e74cd2354997b0f3a76c4d59370b8bc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py
@@ -0,0 +1,11 @@
+import sys
+from typing import Callable, Optional  # noqa: UP035
+
+from torch.utils._config_module import install_config_module
+
+
+e_list = [1]
+e_set = {1}
+e_func: Optional[Callable] = None
+
+install_config_module(sys.modules[__name__])
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..739e76d9191a0d2508572e7af8a996acd2d8ec20
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/generated/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d707d22ab81a6c191283a09ef9bfd54ae80cdc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py
@@ -0,0 +1,513 @@
+# mypy: ignore-errors
+
+import functools
+import unittest
+
+import torch
+from functorch.experimental.control_flow import map
+from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import onlyCUDA
+from torch.testing._internal.common_dtype import all_types_and, custom_types
+from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput
+from torch._higher_order_ops.invoke_subgraph import mark_compile_region
+from torch._higher_order_ops import InvokeQuant, invoke_quant_packed
+
+
+def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(
+        [make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
+        args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)),
+    )
+
+
+def inner_f(x, y0, y1):
+    return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
+
+
+def simple_map(xs, y0, y1):
+    def f(x, y0, y1):
+        return inner_f(x, y0, y1)
+
+    return map(f, xs, y0, y1)
+
+
+def nested_map(xs, y0, y1):
+    def f1(xx, y0, y1):
+        def f2(x, y0, y1):
+            return inner_f(x, y0, y1)
+
+        return map(f2, xx, y0, y1)
+
+    return map(f1, xs, y0, y1)
+
+
+def triple_nested_map(xs, y0, y1):
+    def f0(xs, y0, y1):
+        def f1(xx, y0, y1):
+            def f2(x, y0, y1):
+                return inner_f(x, y0, y1)
+
+            return map(f2, xx, y0, y1)
+
+        return map(f1, xs, y0, y1)
+
+    return map(f0, xs, y0, y1)
+
+
+# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST,
+# and do add an OpInfo for your HOP.
+# The OpInfo lets us do automated testing for the HOP to check that
+# your HOP will work correctly with PyTorch!
+#
+# Your new HOP may fail some automated testing. That's OK. If you don't
+# care about certain features (like torch.export), it's fine to xfail those
+# failing tests. It is less fine to xfail a more critical check (like checking
+# if torch.compile works with your HOP, or if your HOP has a docstring).
+# If you don't know if a test is fine to xfail, please ask.
+#
+# There are legitimate reasons why something cannot be added to this list
+# (e.g. it uses executorch which is not in PyTorch). If that's the case then
+# please leave a comment.
+FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
+    "custom_function_call",
+    "autograd_function_apply",
+    "run_and_save_rng_state",
+    "run_with_rng_state",
+    "graphsafe_run_with_rng_state",
+    "out_dtype",
+    "trace_wrapped",
+    'tag_activation_checkpoint',
+    'executorch_call_delegate',
+    'wrap',
+    'wrap_with_set_grad_enabled',
+    'auto_functionalized_v2',
+    'associative_scan',
+    'flat_apply',  # is WIP, doesn't pass any of the tests yet
+    'wrap_with_autocast',
+    'wrap_activation_checkpoint',
+    'run_const_graph',
+    'auto_functionalized',
+    "map",  # T183144629
+    "map_impl",
+    "with_effects",
+    "strict_mode",
+    "_export_tracepoint",
+    "call_torchbind",
+    "triton_kernel_wrapper_mutation",
+    "triton_kernel_wrapper_functional",
+    "hints_wrapper",
+    "dynamo_bypassing_wrapper",  # TODO(soulitzer)
+    "foreach_map",
+    "aoti_call_delegate",
+    "print",
+    "inductor_compiled_code",  # Tested separately in test_inductor_wrap_inductor_compile_regions
+]
+
+torch.library.define(
+    "testlib::mutating_custom_op",
+    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)
+
+
+@torch.library.impl("testlib::mutating_custom_op", "cpu")
+def foo_impl_cpu(x, z):
+    x.add_(5)
+    z.add_(5)
+    return x, z, x + z
+
+
+@torch.library.impl("testlib::mutating_custom_op", "cuda")
+def foo_impl_cuda(x, z):
+    x.add_(5)
+    z.add_(5)
+    return x, z, x + z
+
+
+@torch.library.register_fake("testlib::mutating_custom_op")
+def foo_impl_abstract(x, z):
+    return x, z, x + z
+
+
+def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
+
+
+def simple_cond(x):
+    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])
+
+
+def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
+
+
+@mark_compile_region
+def fn_for_invoke_subgraph(x):
+    return torch.sin(x)
+
+
+def simple_invoke_subgraph(x):
+    return fn_for_invoke_subgraph(x)
+
+
+def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)
+    )
+
+
+def simple_auto_functionalize(x, z):
+    return torch.ops.testlib.mutating_custom_op(x, z)
+
+
+def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    def score_mod(score, b, h, m, n):
+        return score + h
+
+    q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
+    block_mask = _create_empty_block_mask(q, k)
+    yield SampleInput(q, k, v, score_mod, block_mask)
+
+
+def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        torch.tensor(3),
+        make_arg(2, 3, 4, low=0.1, high=2),
+    )
+
+
+def simple_while_loop(iter_t, x):
+    def cond_fn(iter_t, x):
+        return iter_t > 0
+
+    def body_fn(iter_t, x):
+        return iter_t - 1, x.cos()
+
+    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
+
+
+def simple_while_loop_stack_output(iter_t, x):
+    def cond_fn(iter_t, x):
+        return iter_t > 0
+
+    def body_fn(iter_t, x):
+        return iter_t - 1, x.cos()
+
+    return torch._higher_order_ops.while_loop_stack_output(
+        cond_fn, body_fn, (iter_t, x), tuple()
+    )
+
+
+def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
+    # TODO: once HOPs support DTensor inputs, we should also test DTensors
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        make_arg(2, 3, 4, low=0.1, high=2),
+        make_arg(2, 3, 4, low=0.1, high=2),
+    )
+
+
+def simple_local_map_hop(inp1, inp2):
+    def body_gm(inp1, inp2):
+        return inp1.cos() + inp2.sin()
+
+    gm = torch.fx.symbolic_trace(body_gm)
+
+    assert torch.distributed.is_available()
+    from torch.distributed.tensor.placement_types import Replicate
+
+    gm.meta["local_map_kwargs"] = {
+        "in_placements": (Replicate(), Replicate(), Replicate()),
+        "out_placements": ((Replicate(), Replicate(), Replicate()),),
+    }
+
+    # TODO: Dynamo would rewrite this op differently
+    return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)
+
+
+def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(
+        make_arg(2, 2, low=0.1, high=2),
+        make_arg(2, 2, 2, low=0.1, high=2),
+    )
+
+
+def simple_scan(init, xs):
+    def combine_fn(carry, x):
+        result = carry @ x + x
+        return result, carry.clone()
+
+    return torch._higher_order_ops.scan(combine_fn, init, xs)
+
+
+quant_tracer = InvokeQuant()
+
+
+def simple_invoke_quant(x):
+    def fn(x, y):
+        return (torch.sin(x) * y,)
+
+    return quant_tracer(fn, x, x)[0] * 2.0
+
+
+def simple_invoke_quant_packed(x):
+    def fn(x):
+        return (torch.sin(x),)
+
+    return invoke_quant_packed(fn, x)[0] * 2.0
+
+
+hop_db = [
+    OpInfo(
+        name="scan",
+        variant_test_name="simple",
+        op=simple_scan,
+        sample_inputs_func=sample_inputs_scan,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_subgraph",
+        variant_test_name="simple",
+        op=simple_invoke_subgraph,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="simple",
+        op=simple_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="nested",
+        op=nested_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="triple_nested",
+        op=triple_nested_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="cond",
+        variant_test_name="simple",
+        op=simple_cond,
+        sample_inputs_func=sample_inputs_cond,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_quant",
+        variant_test_name="simple",
+        op=simple_invoke_quant,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_quant_packed",
+        variant_test_name="simple",
+        op=simple_invoke_quant_packed,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="while_loop",
+        variant_test_name="simple",
+        op=simple_while_loop,
+        sample_inputs_func=sample_inputs_while_loop,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="while_loop_stack_output",
+        variant_test_name="simple",
+        op=simple_while_loop_stack_output,
+        sample_inputs_func=sample_inputs_while_loop,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="auto_functionalize",
+        variant_test_name="simple",
+        op=simple_auto_functionalize,
+        sample_inputs_func=sample_inputs_auto_functionalize,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="flex_attention",
+        variant_test_name="simple",
+        op=flex_attention,
+        sample_inputs_func=sample_inputs_flex_attention,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[onlyCUDA],
+    ),
+    OpInfo(
+        name="flex_attention_backward",
+        variant_test_name="simple",
+        op=flex_attention,
+        sample_inputs_func=sample_inputs_flex_attention,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[onlyCUDA],
+    ),
+    OpInfo(
+        name="local_map_hop",
+        variant_test_name="simple",
+        op=simple_local_map_hop,
+        sample_inputs_func=sample_inputs_local_map_hop,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[
+            onlyCUDA,
+            unittest.skipIf(
+                not torch.distributed.is_available(), "requires distributed build"
+            ),
+        ],
+    ),
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a00e1e1a048a0e12c3e081da4415a980cfd97608
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py
@@ -0,0 +1,379 @@
+# mypy: ignore-errors
+
+from collections import defaultdict
+from collections.abc import Iterable
+import numpy as np
+import torch
+
+import hypothesis
+from functools import reduce
+from importlib.metadata import version
+from hypothesis import assume
+from hypothesis import settings
+from hypothesis import strategies as st
+from hypothesis.extra import numpy as stnp
+from hypothesis.strategies import SearchStrategy
+
+from torch.testing._internal.common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams
+
+# Setup for the hypothesis tests.
+# The tuples are (torch_quantized_dtype, zero_point_enforce), where the last
+# element is enforced zero_point. If None, any zero_point point within the
+# range of the data type is OK.
+
+# Tuple with all quantized data types.
+_ALL_QINT_TYPES = (
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+)
+
+# Enforced zero point for every quantized data type.
+# If None, any zero_point point within the range of the data type is OK.
+_ENFORCED_ZERO_POINT = defaultdict(lambda: None, {
+    torch.quint8: None,
+    torch.qint8: None,
+    torch.qint32: 0
+})
+
+def _get_valid_min_max(qparams):
+    scale, zero_point, _quantized_type = qparams
+    adjustment = 1 + torch.finfo(torch.float).eps
+    _long_type_info = torch.iinfo(torch.long)
+    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
+    # make sure intermediate results are within the range of long
+    min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point))
+    max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point))
+    return np.float32(min_value), np.float32(max_value)
+
+# This wrapper wraps around `st.floats` and checks the version of `hypothesis`, if
+# it is too old, removes the `width` parameter (which was introduced)
+# in 3.67.0
+def _floats_wrapper(*args, **kwargs):
+    if 'width' in kwargs and hypothesis.version.__version_info__ < (3, 67, 0):
+        # As long as nan, inf, min, max are not specified, reimplement the width
+        # parameter for older versions of hypothesis.
+        no_nan_and_inf = (
+            (('allow_nan' in kwargs and not kwargs['allow_nan']) or
+             'allow_nan' not in kwargs) and
+            (('allow_infinity' in kwargs and not kwargs['allow_infinity']) or
+             'allow_infinity' not in kwargs))
+        min_and_max_not_specified = (
+            len(args) == 0 and
+            'min_value' not in kwargs and
+            'max_value' not in kwargs
+        )
+        if no_nan_and_inf and min_and_max_not_specified:
+            if kwargs['width'] == 16:
+                kwargs['min_value'] = torch.finfo(torch.float16).min
+                kwargs['max_value'] = torch.finfo(torch.float16).max
+            elif kwargs['width'] == 32:
+                kwargs['min_value'] = torch.finfo(torch.float32).min
+                kwargs['max_value'] = torch.finfo(torch.float32).max
+            elif kwargs['width'] == 64:
+                kwargs['min_value'] = torch.finfo(torch.float64).min
+                kwargs['max_value'] = torch.finfo(torch.float64).max
+        kwargs.pop('width')
+    return st.floats(*args, **kwargs)
+
+def floats(*args, **kwargs):
+    if 'width' not in kwargs:
+        kwargs['width'] = 32
+    return _floats_wrapper(*args, **kwargs)
+
+"""Hypothesis filter to avoid overflows with quantized tensors.
+
+Args:
+    tensor: Tensor of floats to filter
+    qparams: Quantization parameters as returned by the `qparams`.
+
+Returns:
+    True
+
+Raises:
+    hypothesis.UnsatisfiedAssumption
+
+Note: This filter is slow. Use it only when filtering of the test cases is
+      absolutely necessary!
+"""
+def assume_not_overflowing(tensor, qparams):
+    min_value, max_value = _get_valid_min_max(qparams)
+    assume(tensor.min() >= min_value)
+    assume(tensor.max() <= max_value)
+    return True
+
+"""Strategy for generating the quantization parameters.
+
+Args:
+    dtypes: quantized data types to sample from.
+    scale_min / scale_max: Min and max scales. If None, set to 1e-3 / 1e3.
+    zero_point_min / zero_point_max: Min and max for the zero point. If None,
+        set to the minimum and maximum of the quantized data type.
+        Note: The min and max are only valid if the zero_point is not enforced
+              by the data type itself.
+
+Generates:
+    scale: Sampled scale.
+    zero_point: Sampled zero point.
+    quantized_type: Sampled quantized type.
+"""
+@st.composite
+def qparams(draw, dtypes=None, scale_min=None, scale_max=None,
+            zero_point_min=None, zero_point_max=None):
+    if dtypes is None:
+        dtypes = _ALL_QINT_TYPES
+    if not isinstance(dtypes, (list, tuple)):
+        dtypes = (dtypes,)
+    quantized_type = draw(st.sampled_from(dtypes))
+
+    _type_info = torch.iinfo(quantized_type)
+    qmin, qmax = _type_info.min, _type_info.max
+
+    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
+    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
+    if _zp_enforced is not None:
+        zero_point = _zp_enforced
+    else:
+        _zp_min = qmin if zero_point_min is None else zero_point_min
+        _zp_max = qmax if zero_point_max is None else zero_point_max
+        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))
+
+    if scale_min is None:
+        scale_min = torch.finfo(torch.float).eps
+    if scale_max is None:
+        scale_max = torch.finfo(torch.float).max
+    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))
+
+    return scale, zero_point, quantized_type
+
+"""Strategy to create different shapes.
+Args:
+    min_dims / max_dims: minimum and maximum rank.
+    min_side / max_side: minimum and maximum dimensions per rank.
+
+Generates:
+    Possible shapes for a tensor, constrained to the rank and dimensionality.
+
+Example:
+    # Generates 3D and 4D tensors.
+    @given(Q = qtensor(shapes=array_shapes(min_dims=3, max_dims=4))
+    some_test(self, Q):...
+"""
+@st.composite
+def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
+    """Return a strategy for array shapes (tuples of int >= 1)."""
+    assert min_dims < 32
+    if max_dims is None:
+        max_dims = min(min_dims + 2, 32)
+    assert max_dims < 32
+    if max_side is None:
+        max_side = min_side + 5
+    candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
+    if max_numel is not None:
+        candidate = candidate.filter(lambda x: reduce(int.__mul__, x, 1) <= max_numel)
+    return draw(candidate.map(tuple))
+
+
+"""Strategy for generating test cases for tensors.
+The resulting tensor is in float32 format.
+
+Args:
+    shapes: Shapes under test for the tensor. Could be either a hypothesis
+            strategy, or an iterable of different shapes to sample from.
+    elements: Elements to generate from for the returned data type.
+              If None, the strategy resolves to float within range [-1e6, 1e6].
+    qparams: Instance of the qparams strategy. This is used to filter the tensor
+             such that the overflow would not happen.
+
+Generates:
+    X: Tensor of type float32. Note that NaN and +/-inf is not included.
+    qparams: (If `qparams` arg is set) Quantization parameters for X.
+        The returned parameters are `(scale, zero_point, quantization_type)`.
+        (If `qparams` arg is None), returns None.
+"""
+@st.composite
+def tensor(draw, shapes=None, elements=None, qparams=None, dtype=np.float32):
+    if isinstance(shapes, SearchStrategy):
+        _shape = draw(shapes)
+    else:
+        _shape = draw(st.sampled_from(shapes))
+    if qparams is None:
+        if elements is None:
+            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
+        X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
+        assume(not (np.isnan(X).any() or np.isinf(X).any()))
+        return X, None
+    qparams = draw(qparams)
+    if elements is None:
+        min_value, max_value = _get_valid_min_max(qparams)
+        elements = floats(min_value, max_value, allow_infinity=False,
+                          allow_nan=False, width=32)
+    X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
+    # Recompute the scale and zero_points according to the X statistics.
+    scale, zp = _calculate_dynamic_qparams(X, qparams[2])
+    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
+    if enforced_zp is not None:
+        zp = enforced_zp
+    return X, (scale, zp, qparams[2])
+
+@st.composite
+def per_channel_tensor(draw, shapes=None, elements=None, qparams=None):
+    if isinstance(shapes, SearchStrategy):
+        _shape = draw(shapes)
+    else:
+        _shape = draw(st.sampled_from(shapes))
+    if qparams is None:
+        if elements is None:
+            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
+        X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
+        assume(not (np.isnan(X).any() or np.isinf(X).any()))
+        return X, None
+    qparams = draw(qparams)
+    if elements is None:
+        min_value, max_value = _get_valid_min_max(qparams)
+        elements = floats(min_value, max_value, allow_infinity=False,
+                          allow_nan=False, width=32)
+    X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
+    # Recompute the scale and zero_points according to the X statistics.
+    scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2])
+    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
+    if enforced_zp is not None:
+        zp = enforced_zp
+    # Permute to model quantization along an axis
+    axis = int(np.random.randint(0, X.ndim, 1))
+    permute_axes = np.arange(X.ndim)
+    permute_axes[0] = axis
+    permute_axes[axis] = 0
+    X = np.transpose(X, permute_axes)
+
+    return X, (scale, zp, axis, qparams[2])
+
+"""Strategy for generating test cases for tensors used in Conv.
+The resulting tensors is in float32 format.
+
+Args:
+    spatial_dim: Spatial Dim for feature maps. If given as an iterable, randomly
+                 picks one from the pool to make it the spatial dimension
+    batch_size_range: Range to generate `batch_size`.
+                      Must be tuple of `(min, max)`.
+    input_channels_per_group_range:
+        Range to generate `input_channels_per_group`.
+        Must be tuple of `(min, max)`.
+    output_channels_per_group_range:
+        Range to generate `output_channels_per_group`.
+        Must be tuple of `(min, max)`.
+    feature_map_range: Range to generate feature map size for each spatial_dim.
+                       Must be tuple of `(min, max)`.
+    kernel_range: Range to generate kernel size for each spatial_dim. Must be
+                  tuple of `(min, max)`.
+    max_groups: Maximum number of groups to generate.
+    elements: Elements to generate from for the returned data type.
+              If None, the strategy resolves to float within range [-1e6, 1e6].
+    qparams: Strategy for quantization parameters. for X, w, and b.
+             Could be either a single strategy (used for all) or a list of
+             three strategies for X, w, b.
+Generates:
+    (X, W, b, g): Tensors of type `float32` of the following drawen shapes:
+        X: (`batch_size, input_channels, H, W`)
+        W: (`output_channels, input_channels_per_group) + kernel_shape
+        b: `(output_channels,)`
+        groups: Number of groups the input is divided into
+Note: X, W, b are tuples of (Tensor, qparams), where qparams could be either
+      None or (scale, zero_point, quantized_type)
+
+
+Example:
+    @given(tensor_conv(
+        spatial_dim=2,
+        batch_size_range=(1, 3),
+        input_channels_per_group_range=(1, 7),
+        output_channels_per_group_range=(1, 7),
+        feature_map_range=(6, 12),
+        kernel_range=(3, 5),
+        max_groups=4,
+        elements=st.floats(-1.0, 1.0),
+        qparams=qparams()
+    ))
+"""
+@st.composite
+def tensor_conv(
+    draw, spatial_dim=2, batch_size_range=(1, 4),
+    input_channels_per_group_range=(3, 7),
+    output_channels_per_group_range=(3, 7), feature_map_range=(6, 12),
+    kernel_range=(3, 7), max_groups=1, can_be_transposed=False,
+    elements=None, qparams=None
+):
+
+    # Resolve the minibatch, in_channels, out_channels, iH/iW, iK/iW
+    batch_size = draw(st.integers(*batch_size_range))
+    input_channels_per_group = draw(
+        st.integers(*input_channels_per_group_range))
+    output_channels_per_group = draw(
+        st.integers(*output_channels_per_group_range))
+    groups = draw(st.integers(1, max_groups))
+    input_channels = input_channels_per_group * groups
+    output_channels = output_channels_per_group * groups
+
+    if isinstance(spatial_dim, Iterable):
+        spatial_dim = draw(st.sampled_from(spatial_dim))
+
+    feature_map_shape = [draw(st.integers(*feature_map_range)) for _ in range(spatial_dim)]
+
+    kernels = [draw(st.integers(*kernel_range)) for _ in range(spatial_dim)]
+
+    tr = False
+    weight_shape = (output_channels, input_channels_per_group) + tuple(kernels)
+    bias_shape = output_channels
+    if can_be_transposed:
+        tr = draw(st.booleans())
+        if tr:
+            weight_shape = (input_channels, output_channels_per_group) + tuple(kernels)
+            bias_shape = output_channels
+
+    # Resolve the tensors
+    if qparams is not None:
+        if isinstance(qparams, (list, tuple)):
+            assert len(qparams) == 3, "Need 3 qparams for X, w, b"
+        else:
+            qparams = [qparams] * 3
+
+    X = draw(tensor(shapes=(
+        (batch_size, input_channels) + tuple(feature_map_shape),),
+        elements=elements, qparams=qparams[0]))
+    W = draw(tensor(shapes=(weight_shape,), elements=elements,
+                    qparams=qparams[1]))
+    b = draw(tensor(shapes=(bias_shape,), elements=elements,
+                    qparams=qparams[2]))
+
+    return X, W, b, groups, tr
+
+
+# We set the deadline in the currently loaded profile.
+# Creating (and loading) a separate profile overrides any settings the user
+# already specified.
+hypothesis_version = tuple(map(int, version("hypothesis").split(".")[:3]))
+
+if (3, 16, 0) <= hypothesis_version < (3, 27, 0):
+    # Hypothesis 3.16 → 3.26: use `timeout` instead of `deadline`
+    settings.register_profile("no_deadline", timeout=hypothesis.unlimited)
+else:
+    # Hypothesis >=3.27: use `deadline=None`
+    settings.register_profile("no_deadline", deadline=None)
+
+# Activate the profile
+settings.load_profile("no_deadline")
+
+
+def assert_deadline_disabled():
+    """Check that deadlines are effectively disabled across Hypothesis versions."""
+    if hypothesis_version < (3, 27, 0):
+        import warnings
+
+        warning_message = (
+            "Your version of hypothesis is outdated. "
+            "To avoid `DeadlineExceeded` errors, please update. "
+            f"Current hypothesis version: {hypothesis.__version__}"
+        )
+        warnings.warn(warning_message, stacklevel=2)
+    else:
+        assert settings().deadline is None
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3dbb95f4ba9c3c430e27a677fc3850aee2b3549
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -0,0 +1,725 @@
+# mypy: ignore-errors
+
+# Torch
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+import torch.nn.functional as F
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
+from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
+
+import collections
+from copy import deepcopy
+from typing import Any, Union
+import math  # noqa: F401
+
+# Testing utils
+from torch import inf
+
+assert torch.get_default_dtype() == torch.float32
+
+L = 20
+M = 10
+S = 5
+
+
+def unpack_variables(args):
+    if isinstance(args, tuple):
+        return tuple(unpack_variables(elem) for elem in args)
+    else:
+        return args
+
+class dont_convert(tuple):
+    __slots__ = ()
+
+non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
+
+def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.float, device=None):
+    if not isinstance(call_args, tuple):
+        call_args = (call_args,)
+
+    def map_arg(arg):
+        def maybe_non_contig(tensor):
+            if not non_contiguous or tensor.numel() < 2:
+                return tensor.clone()
+
+            return noncontiguous_like(tensor)
+
+        def conjugate(tensor):
+            return tensor.conj()
+
+        if isinstance(arg, (torch.Size, dont_convert)):
+            return arg
+        elif isinstance(arg, tuple) and len(arg) == 0:
+            var = conjugate(torch.randn((), dtype=dtype, device=device))
+            var.requires_grad = requires_grad
+            return var
+        elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
+            return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
+        # double check casting
+        elif isinstance(arg, non_differentiable):
+            if isinstance(arg.tensor, torch.Tensor):
+                return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+        elif isinstance(arg, torch.Tensor):
+            if arg.is_complex() != dtype.is_complex:
+                raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
+                                   "which is not supported for now")
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
+            v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
+            v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
+            return v
+        elif callable(arg):
+            return map_arg(arg(dtype=dtype, device=device))
+        else:
+            return arg
+    args_out = tuple(map_arg(arg) for arg in call_args)
+    kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
+    return args_out, kwargs_out
+
+# NB: JIT script tests for all nn functional interfaces, script mode does
+# not support in_place operations yet, so no inplace operation tests added.
+# removed all the deprecated functions
+#
+# (
+#   method name,
+#   input size/constructing fn,
+#   args (tuple represents shape of a tensor arg),
+#   test variant name(will be used at test name suffix,
+#       'inplace' skips grad tests),                         // optional
+#   (True, nonfusible_nodes, fusible_nodes) for autodiff     // optional
+#   fn to determine if test should be skipped,               // optional
+#   fn mapping output to part that should be gradcheck'ed,   // optional
+#   kwargs for function,                                     // optional
+# )
+def get_nn_functional_tests():
+    nn_functional_tests = [
+        ('conv1d', (S, S, S), ((S, S, S),)),
+        ('conv2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_transpose1d', (S, S, S), ((S, S, S),)),
+        ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
+        ('avg_pool1d', (S, S, S), (3,)),
+        ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
+        ('avg_pool3d', (S, S, S, S, S), (3,)),
+        ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
+        ('max_pool1d', (S, S, S), (2, 1)),
+        ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
+        ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool3d', (S, S, S, S, S), (2, 1)),
+        ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
+        ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
+        ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
+        ('lp_pool1d', (S, S, S), (2., 3, 2,)),
+        ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
+        ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
+        ('adaptive_max_pool1d', (S, S, S), (5,)),
+        ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
+        ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
+        ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
+        ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
+        ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
+        ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
+        ('alpha_dropout', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
+        ('dropout3d', (S, S, S, S), (0.5,)),
+        ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
+        ('feature_alpha_dropout', (S, S, S), (0.5,)),
+        ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
+        ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
+        ('relu', (S, S, S), (), '', (True,)),
+        ('relu', (S, S, S), (), 'inplace'),
+        ('glu', (S - 1, S - 1, S - 1), (),),
+        ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
+        ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
+        ('relu6', (S, S, S), (), '', (True,)),
+        ('relu6', (S, S, S), (True), 'inplace'),
+        ('elu', (S, S, S), (0.9,),),
+        ('elu', (S, S, S), (0.9, True), 'inplace'),
+        ('selu', (S, S, S), (),),
+        ('selu', (S, S, S), (True), 'inplace'),
+        ('celu', (S, S, S), (0.9,),),
+        ('celu', (S, S, S), (0.9, True), 'inplace'),
+        ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
+        ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
+        ('rrelu', (S, S), (0.1, 0.3, False),),
+        ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
+        ('hardshrink', (S, S, S), (0.4,), '', (True,)),
+        ('tanhshrink', (S, S, S), (),),
+        ('softsign', (S, S, S), (),),
+        ('softplus', (S, S, S), (), '', (True,)),
+        ('softmin', (S, S, S), (0,),),
+        ('softmax', (S, S, S), (0,), '', (True,)),
+        ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
+        ('tanh', (S, S, S), (), '', (True,)),
+        ('sigmoid', (S, S, S), (), '', (True,)),
+        ('silu', (S, S, S), (), '', (True,)),
+        ('log_softmax', (S, S, S), (0,), '', (True,)),
+        ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
+        ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
+        ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
+        ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
+        ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+            'training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), True, ),
+            'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, True, ),
+            'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, None, False, ),
+            'inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+            'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), False, ),
+            'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, False, ),
+            'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
+        ('layer_norm', (S, S, S, S), ([5],), '',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                      non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
+        ('group_norm', (S, S, S), (1, torch.rand(5),),),
+        ('local_response_norm', (S, S, S), (2, ),),
+        ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
+        ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
+        ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
+        ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
+        ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('margin_ranking_loss', (S,), ((S,), (S,)),),
+        ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
+        ('pixel_shuffle', (1, 9, 4, 4), (3,),),
+        ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
+        ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
+        ('pad', (3, 3, 4, 2), ([1, 1],),),
+        ('pairwise_distance', (S, S), ((S, S),),),
+        ('pdist', (S, S), (),),
+        ('cosine_similarity', (S, S), ((S, S),),),
+        ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
+        ('normalize', (S, S, S), (),),
+        ('unfold', (S, S, S, S), ([2, 3]),),
+        ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
+        ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
+        ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
+        ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
+                                       1, 1., non_differentiable(torch.randn(S))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
+                                                               non_differentiable(torch.randn(3, 2))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
+            (non_differentiable(torch.rand(3, 2)),
+             non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
+        ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
+         (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
+          torch.randint(1, S, (S,), dtype=torch.long))),
+        ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
+        ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
+         'nearest_4d_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
+         'nearest_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
+         'bilinear_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
+         'bilinear_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
+         'bicubic_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
+         'bicubic_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
+         'nearest_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
+         'nearest_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
+         'linear_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
+         'linear_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
+         'nearest_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
+         'nearest_5d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
+         'trilinear_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
+         'trilinear_5d_with_size_not_recompute_scale_factor'),
+    ]
+    return nn_functional_tests
+
+script_template = '''
+def the_method({}):
+    return {}
+'''
+
+def value_to_literal(value):
+    if isinstance(value, str):
+        # Quotes string and escapes special characters
+        return ascii(value)
+    if isinstance(value, torch.Tensor):
+        return 'torch.' + str(value)
+    else:
+        return str(value)
+
+def get_call(method_name, func_type, args, kwargs):
+    kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
+    self_arg = args[0]
+    if func_type == 'method':
+        args = args[1:]
+
+    argument_str = ', '.join(args)
+    argument_str += ', ' if len(args) and len(kwargs) else ''
+    argument_str += kwargs_str
+
+    if func_type == 'functional' or func_type == 'function':
+        call = f'torch.{method_name}({argument_str})'
+    elif func_type == 'method':
+        call = f'{self_arg}.{method_name}({argument_str})'
+    elif func_type == 'nn_functional':
+        call = f'torch.nn.functional.{method_name}({argument_str})'
+    else:
+        raise TypeError('Unsupported function type')
+
+    return call
+
+def get_constant(x):
+    if x == inf:
+        return 'math.inf'
+    if x == -inf:
+        return '-math.inf'
+    return x
+
+def get_script_args(args):
+    formals: list[str] = []
+    tensors: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+    actuals: list[str] = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            name = f'i{len(formals)}'
+            formals.append(name)
+            actuals.append(name)
+            tensors.append(arg)
+        elif is_iterable_of_tensors(arg):
+            name = f'i{len(formals)}'
+            formals.append(name + ': List[torch.Tensor]')
+            actuals.append(name)
+            tensors.append(list(arg))
+        elif isinstance(arg, str):
+            actuals.append(f"'{arg}'")
+        else:
+            actuals.append(str(get_constant(arg)))
+    return (formals, tensors, actuals)
+
+# create a script function from (name, func_type, output_process_fn),
+# and returns the compiled function and example inputs
+def gen_script_fn_and_args(method_name, func_type, *args, **kwargs):
+    formals, tensors, actuals = get_script_args(args)
+    call = get_call(method_name, func_type, actuals, kwargs)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    return CU.the_method, tensors
+
+# create a script function from (name, func_type),
+# returns a function takes in (args, kwargs) and runs the compiled function
+def create_script_fn(self, method_name, func_type):
+    # function returns tuple containing original output and
+    # filtered output to be used in checking gradients
+    def script_fn(*args, **kwargs):
+        fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
+        self.assertExportImport(fn.graph, tensors)
+        output = fn(*tensors)
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        script_fn.last_graph = fn.graph_for(*tensors)  # type: ignore[attr-defined]
+        return output
+    return script_fn
+
+class SplitInputs:
+    all_tensors: list[Any]
+    tensor_args: list[Any]
+    nontensor_args: list[Any]
+    arg_types: list[str]
+    tensor_kwargs: dict[str, Any]
+    kwarg_order: list[str]
+    nontensor_kwargs: dict[str, Any]
+    kwarg_types: dict[str, Any]
+
+    @staticmethod
+    def _is_tensor_input(arg):
+        return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)
+
+    def __init__(self, args, kwargs):
+        self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args]
+        self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()}
+        self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)]
+        self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)]
+        self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)}
+        self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)}
+        self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]]
+        self.kwarg_order = [k for k, v in kwargs.items()]
+
+    def nontensors_match(self, other: 'SplitInputs'):
+        if self.arg_types != other.arg_types:
+            return False
+        if self.kwarg_types != other.kwarg_types:
+            return False
+        if self.kwarg_order != other.kwarg_order:
+            return False
+        if self.nontensor_args != other.nontensor_args:
+            return False
+        if self.nontensor_kwargs != other.nontensor_kwargs:
+            return False
+        return True
+
+# make a new function where all non-tensor arguments in 'args' have been partially
+# applied, and all tensor arguments remain.
+# used to trace functions when some arguments are not tensors
+def partial_apply_nontensors(fn, args, kwargs):
+    inputs = SplitInputs(args, kwargs)
+
+    def new_fn(*tensors_):
+        tensors = iter(tensors_)
+        full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)]
+        full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()}
+        return fn(*full_args, **full_kwargs)
+
+    return new_fn, inputs
+
+# create a trace function from input fn
+def create_traced_fn(self, fn, cache_traced_fn=False):
+    def traced_fn(*inputs, **kwargs):
+        # `check_trace` is set to False because check_trace is run with @no_grad
+        # Also, `check_against_reference` already does all the checks
+        # against python function
+        fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs)
+        if not cache_traced_fn or not hasattr(traced_fn, 'traced'):
+            traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False)
+            self.assertExportImport(traced.graph, split_inputs.all_tensors)
+            output = traced(*split_inputs.all_tensors)
+            if cache_traced_fn:
+                traced_fn.traced = traced
+                traced_fn.split_inputs = split_inputs
+        else:
+            # Guard to check that nontensor inputs are the same as during tracing
+            self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs))
+            output = traced_fn.traced(*split_inputs.all_tensors)
+            traced = traced_fn.traced
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors)  # type: ignore[attr-defined]
+        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
+        return output
+    return traced_fn
+
+# known to be failing in script
+EXCLUDE_SCRIPT = {
+    'test_norm_fro_default',
+    'test_norm_fro_cpu',
+    'test_norm_nuc',
+    'test_norm_fro',
+    'test_norm_nuc_batched',
+
+    # aten op has additional cudnn argument
+    'test_nn_unfold',
+
+    # flaky test - TODO fix
+    'test_nn_ctc_loss',
+
+    # unknown builtin op
+    'test_nn_fold',
+
+    # jit doesn't support sparse tensors.
+    'test_to_sparse',
+    'test_to_sparse_dim',
+}
+
+# generates a script function and set of example inputs
+# from a specified test in the format of nn_functional_tests
+def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
+    test_name = 'test_nn_' + name
+
+    if variant_name != '':
+        test_name = test_name + '_' + variant_name
+
+    self_variable = create_input((self_size,))[0][0]
+
+    # need to record this because methods can change the size (e.g. unsqueeze)
+    args_variable, _kwargs_variable = create_input(args)
+
+    self_tensor = deepcopy(self_variable.data)
+    args_tensor = deepcopy(unpack_variables(args_variable))
+
+    f_args_variable = (self_variable,) + args_variable
+    f_args_tensor = (self_tensor,) + args_tensor  # noqa: F841
+    with torch._jit_internal._disable_emit_hooks():
+        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
+    return script_fn, inputs
+
+
+
+EXCLUDE_SCRIPT_MODULES = {
+    'test_nn_AdaptiveAvgPool2d_tuple_none',
+    'test_nn_AdaptiveAvgPool3d_tuple_none',
+    'test_nn_AdaptiveMaxPool2d_tuple_none',
+    'test_nn_AdaptiveMaxPool3d_tuple_none',
+
+    # Doesn't use future division, so this is not supported
+    'test_nn_CrossMapLRN2d',
+    # Derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented
+    'test_nn_TransformerDecoderLayer_gelu_activation',
+    'test_nn_TransformerDecoderLayer_relu_activation',
+    'test_nn_TransformerEncoderLayer_gelu_activation',
+    'test_nn_TransformerEncoderLayer_relu_activation',
+    'test_nn_Transformer_multilayer_coder',
+}
+
+script_method_template = '''
+def forward({}):
+    return {}
+'''
+
+def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
+    def script_module(*args, **kwargs):
+        _formals, tensors, actuals = get_script_args(args)
+
+        method_args = ', '.join(['self'] + actuals)
+        call_args_str = ', '.join(actuals)
+        call = f"self.submodule({call_args_str})"
+        script = script_method_template.format(method_args, call)
+
+        submodule_constants = []
+        if kwargs.get('is_constant'):
+            submodule_constants = ['submodule']
+
+        # Create module to use the script method
+        class TheModule(torch.jit.ScriptModule):
+            __constants__ = submodule_constants
+
+            def __init__(self) -> None:
+                super().__init__()
+                self.submodule = nn_module(*constructor_args)
+
+        def make_module(script):
+            module = TheModule()
+            # check __repr__
+            str(module)
+            module.define(script)
+            return module
+
+        module = make_module(script)
+        if self:
+            self.assertExportImportModule(module, tensors)
+            module(*args)
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        create_script_module.last_graph = module.graph  # type: ignore[attr-defined]
+        return module
+    return script_module
+
+def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'):
+    formals, tensors, actuals = get_script_args(args)
+    call = get_call(method_name, func_type, actuals, kwargs)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    # to clean up IR
+    torch._C._jit_pass_inline(CU.the_method.graph)
+    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
+    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)
+
+def get_nn_module_name_from_kwargs(**kwargs):
+    if 'module_name' in kwargs:
+        return kwargs['module_name']
+    elif 'fullname' in kwargs:
+        return kwargs['fullname']
+    elif 'constructor' in kwargs:
+        return kwargs['constructor'].__name__
+
+def get_nn_mod_test_name(**kwargs):
+    if 'fullname' in kwargs:
+        test_name = kwargs['fullname']
+    else:
+        test_name = get_nn_module_name_from_kwargs(**kwargs)
+        if 'desc' in kwargs:
+            test_name = f"{test_name}_{kwargs['desc']}"
+    return f'test_nn_{test_name}'
+
+def get_nn_module_class_from_kwargs(**kwargs):
+    name = get_nn_module_name_from_kwargs(**kwargs)
+    index = name.find("_")
+    if index == -1:
+        return name
+    else:
+        return name[0:name.find("_")]
+
+def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
+    name = get_nn_module_name_from_kwargs(**kwargs)
+
+    if 'desc' in kwargs and 'eval' in kwargs['desc']:
+        # eval() is not supported, so skip these tests
+        return
+
+    test_name = name
+    if 'desc' in kwargs:
+        test_name = f"{test_name}_{kwargs['desc']}"
+    test_name = get_nn_mod_test_name(**kwargs)
+
+    if test_name in EXCLUDE_SCRIPT_MODULES:
+        return
+    if 'constructor' in kwargs:
+        nn_module = kwargs['constructor']
+    else:
+        nn_module = getattr(torch.nn, name)
+
+    if "FunctionalModule" in str(nn_module):
+        return
+
+    if 'constructor_args_fn' in kwargs:
+        constructor_args = kwargs['constructor_args_fn']()
+    else:
+        constructor_args = kwargs.get('constructor_args', ())
+
+    # Set up inputs from tuple of sizes or constructor fn
+    input_dtype = torch.double
+    if 'input_fn' in kwargs:
+        input = kwargs['input_fn']()
+        if isinstance(input, torch.Tensor):
+            input = (input,)
+
+        if all(tensor.is_complex() for tensor in input):
+            input_dtype = torch.cdouble
+    else:
+        input = (kwargs['input_size'],)
+
+    # Extra parameters to forward()
+    if 'extra_args' in kwargs:
+        input = input + kwargs['extra_args']
+
+    if 'target_size' in kwargs:
+        input = input + (kwargs['target_size'],)
+    elif 'target_fn' in kwargs:
+        if torch.is_tensor(input):
+            input = (input,)
+        input = input + (kwargs['target_fn'](),)
+
+    args_variable, _kwargs_variable = create_input(input, dtype=input_dtype)
+    f_args_variable = deepcopy(unpack_variables(args_variable))
+    out_var = deepcopy(f_args_variable)
+
+
+    _args, mod = f_args_variable, create_script_module(
+        None, nn_module, constructor_args, *f_args_variable
+    )(*f_args_variable)
+
+    return mod, out_var
+
+
+def get_all_nn_module_tests():
+    # additional modules test
+    # TODO: delete this list once we make all nn_tests work
+    additional_module_tests = [
+        {
+            'module_name': 'Bilinear',
+            'constructor_args': (S, S, M),
+            'input_size': (S, S),
+            'extra_args': ((S, S),)
+        },
+        {
+            'module_name': 'RNNCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'LSTMCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'GRUCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'MultiheadAttention',
+            'constructor_args': (128, 8),
+            'input_size': (10, 8, 128),
+            'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
+            'slowTest': True
+        },
+        {
+            'module_name': 'Transformer',
+            'constructor_args': (1, 1, 1, 1, 2),
+            'input_size': (3, 1, 1),
+            'extra_args': (torch.randn(1, 1, 1),),
+            'slowTest': True
+        }
+    ]
+
+    return module_tests + get_new_module_tests() + additional_module_tests
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aab838e8c87b229a824f1b4548f035cea614bfb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py
@@ -0,0 +1,896 @@
+# mypy: ignore-errors
+
+# Torch
+from torch.autograd import Variable
+from torch.autograd.function import _nested_map
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+
+from torch.onnx import OperatorExportTypes
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+import torch.jit.quantized
+import zipfile
+import functools
+
+# Testing utils
+from torch.testing import FileCheck
+from torch.testing._internal.common_utils import IS_WINDOWS, \
+    freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \
+    is_iterable_of_tensors
+from torch.testing._internal.common_jit import JitCommonTestCase
+from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401
+
+# Standard library
+from contextlib import contextmanager
+from functools import reduce
+from io import StringIO
+from collections import defaultdict
+
+import importlib.util
+import inspect
+import io
+import math
+import os
+import pickle
+import sys
+import tempfile
+import textwrap
+from importlib.abc import Loader
+from typing import Any, Union
+
+RUN_CUDA = torch.cuda.is_available()
+RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
+RUN_CUDA_HALF = RUN_CUDA
+# HIP supports half, no version check necessary
+if torch.cuda.is_available() and not torch.version.hip:
+    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
+    for d in range(torch.cuda.device_count()):
+        major = torch.cuda.get_device_capability(d)[0]
+        if (major < 6):
+            RUN_CUDA_HALF = False
+
+def execWrapper(code, glob, loc):
+    exec(code, glob, loc)
+
+def do_input_map(fn, input):
+    return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+def clear_class_registry():
+    torch._C._jit_clear_class_registry()
+    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
+    torch.jit._state._clear_class_state()
+
+def get_execution_plan(graph_executor_state):
+    execution_plans = list(graph_executor_state.execution_plans.values())
+    num_plans = len(execution_plans)
+    if num_plans != 1:
+        raise RuntimeError('This test assumes this GraphExecutor should '
+                           f'only have one execution plan, got: {num_plans}')
+    return execution_plans[0]
+
+class _AssertRaisesRegexWithHighlightContext:
+    """
+    A context manager that is useful for checking that error messages highlight
+    the correct part of the source code.
+    """
+
+    def __init__(self, test_case, exception, regex, highlight):
+        self.test_case = test_case
+        self.exception_type = exception
+        self.regex = regex
+        self.highlight = highlight
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        with self.test_case.assertRaisesRegex(self.exception_type, self.regex):
+            if type:
+                raise value
+
+        if self.highlight:
+            FileCheck().check_source_highlighted(self.highlight).run(str(value))
+
+        return True
+
+FUSION_GROUP = "prim::TensorExprGroup"
+
+class JitTestCase(JitCommonTestCase):
+    _do_cuda_memory_leak_check = True
+    _restored_warnings = False
+
+    class capture_stdout(list):
+        """
+        Replace sys.stdout with a temporary StringIO
+        """
+        def __enter__(self):
+            self.sys_stdout = sys.stdout
+            self.stringio = StringIO()
+            sys.stdout = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stdout = self.sys_stdout
+
+    class capture_stderr(list):
+        """
+        Replace sys.stderr with a temporary StringIO
+        """
+        def __enter__(self):
+            self.sys_stderr = sys.stderr
+            self.stringio = StringIO()
+            sys.stderr = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stderr = self.sys_stderr
+
+    def setHooks(self):
+        torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)
+
+    def clearHooks(self):
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def setUp(self):
+        super().setUp()
+        # unittest overrides all warning filters and forces all of them to show up
+        # after we install our own to silence those coming from inside PyTorch.
+        # This will ensure that our filter still takes precedence.
+        if not JitTestCase._restored_warnings:
+            torch.jit.TracerWarning.ignore_lib_warnings()
+            JitTestCase._restored_warnings = True
+        self.setHooks()
+
+    def tearDown(self):
+        super().tearDown()
+        # needs to be cleared because python might be unloaded before
+        # the callback gets destructed
+        self.clearHooks()
+        clear_class_registry()
+
+    def assertAllFused(self, graph, except_for=()):
+
+        # note this helper collects nodes on 'fast path' only
+        # i.e. the true blocks of specialized checks
+        def get_nodes_and_parents_recursively(block, kind, acc):
+            for node in block.nodes():
+                if node.kind() == kind:
+                    acc[block].append(node)
+                elif node.kind() == 'prim::DifferentiableGraph':
+                    get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
+                elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or
+                                                    node.inputs().__next__().node().kind() == 'prim::TypeCheck' or
+                                                    node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'):
+                    get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc)
+                else:
+                    for inner_block in node.blocks():
+                        get_nodes_and_parents_recursively(inner_block, kind, acc)
+
+        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
+                         'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for)
+
+        fusion_groups : dict[torch._C.Block, list[torch._C.Node]] = defaultdict(list)
+        get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
+        self.assertTrue(len(fusion_groups) == 1, f'got {graph}')
+        (graph, fusion_nodes) = next(iter(fusion_groups.items()))
+        # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes`
+        self.assertTrue(len(fusion_nodes) == 1, f'got {graph}')
+        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
+                        f'got {graph}')
+
+    def _isHookExceptionOk(self, e):
+        se = str(e)
+        allowed = ("Could not export Python function",
+                   "closures are not exportable")
+        for a in allowed:
+            if a in se:
+                return True
+        return False
+
+    def _compared_saved_loaded(self, m):
+        def extract_files(buffer):
+            # crack open the zip format to get at the main module code
+            with zipfile.ZipFile(buffer) as archive:
+                # check that we have no duplicate names
+                self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
+                files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
+                # unwrap all the code files into strings
+                code_files_str = filter(lambda x: x.endswith('.py'), files)
+                code_files = []
+                for f in code_files_str:
+                    with archive.open(f) as stream:
+                        code_files.append("".join([line.decode() for line in stream]))
+
+                # unpickled all the debug files
+                debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
+                debug_files = []
+                for f in debug_files_str:
+                    with archive.open(f) as stream:
+                        debug_files.append(pickle.load(stream))
+                return code_files, debug_files
+
+        # disable the hook while we parse code, otherwise we will re-enter the hook
+        with torch._jit_internal._disable_emit_hooks():
+            try:
+                # short-circuit if this is an empty function or module
+                if len(m.code) == 0:
+                    return
+                if isinstance(m, torch._C.ScriptModule):
+                    if len(m._method_names()) == 0:
+                        return
+
+                # save the module to a buffer
+                buffer = io.BytesIO()
+                torch.jit.save(m, buffer)
+                # copy the data in the buffer so we can restore it later. This
+                # is because py2 and py3 have different semantics with zipfile
+                # and it's easier to just work with a fresh copy each time.
+                buffer_copy = buffer.getvalue()
+
+                code_files, _debug_files = extract_files(buffer)
+
+            except RuntimeError as e:
+                if not self._isHookExceptionOk(e):
+                    raise
+                else:
+                    return
+
+            # import the model again (from a the copy we made of the original)
+            buffer2 = io.BytesIO(buffer_copy)
+            imported = torch.jit.load(buffer2)
+
+            # save it again
+            saved_module_buffer_2 = io.BytesIO()
+            torch.jit.save(imported, saved_module_buffer_2)
+
+            saved_module_buffer_2.seek(0)
+            code_files_2, _debug_files_2 = extract_files(saved_module_buffer_2)
+
+            for a, b in zip(code_files, code_files_2, strict=True):
+                self.assertMultiLineEqual(a, b)
+
+            if isinstance(m, torch._C.ScriptModule):
+                self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))
+
+
+    def emitFunctionHook(self, func):
+        # func has invalid names for export, skip the jitter check
+        if func.name == "" or "aten::" in func.name:
+            return
+        self._compared_saved_loaded(func)
+
+    def emitModuleHook(self, module):
+        self._compared_saved_loaded(module)
+
+
+    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
+        buffer = io.BytesIO()
+        m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None)
+        torch.jit.save(m, buffer)
+        m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer, map_location=map_location)
+        imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+
+        if not also_test_file:
+            return imported
+
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            try:
+                f.close()
+                imported.save(f.name)
+                result = torch.jit.load(f.name, map_location=map_location)
+            finally:
+                os.unlink(f.name)
+
+        result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+        return result
+
+    def assertGraphContains(self, graph, kind, consider_subgraphs=False):
+
+        if consider_subgraphs:
+            strgraph = str(graph)
+            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
+            self.assertTrue(count > 0)
+            return
+
+        def nodes(block):
+            out = []
+            for node in block.nodes():
+                if node.kind() == kind:
+                    out.append(node)
+                for block in node.blocks():
+                    out += nodes(block)
+            return out
+
+        out_nodes = nodes(graph)
+        self.assertTrue(len(out_nodes) > 0)
+
+    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
+        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
+            if actual == expected:
+                return
+            subgraph = 'including' if consider_subgraphs else 'excluding'
+            raise AssertionError(
+                f'{graph}\nError: graph contains {actual} {kind} nodes ({subgraph} subgraphs) but expected {expected}')
+
+        if consider_subgraphs:
+            strgraph = str(graph)
+            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
+            perform_assert(graph, kind, count, num_kind_nodes,
+                           consider_subgraphs)
+            return
+
+        def nodes(block):
+            out = []
+            for node in block.nodes():
+                if node.kind() == kind:
+                    out.append(node)
+                for block in node.blocks():
+                    out += nodes(block)
+            return out
+
+        out_nodes = nodes(graph)
+        perform_assert(graph, kind, len(out_nodes), num_kind_nodes,
+                       consider_subgraphs)
+
+    def assertExpectedONNXGraph(self, g, *args, **kwargs):
+        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
+        self.assertExpectedGraph(g, *args, **kwargs)
+
+    def assertExpectedGraph(self, trace, *args, **kwargs):
+        if isinstance(trace, torch._C.Graph):
+            graph = trace
+        else:
+            graph = trace.graph()
+
+        torch._C._jit_pass_lint(graph)
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+        graph = torch._C._jit_pass_canonicalize(graph)
+        torch._C._jit_pass_lint(graph)
+        self.assertExpected(str(graph), *args, **kwargs)
+
+    def run_pass(self, name, trace):
+        if isinstance(trace, torch._C.Graph):
+            graph = trace
+            set_graph = False
+        else:
+            set_graph = True
+            graph = trace.graph()
+
+        torch._C._jit_pass_lint(graph)
+        result = getattr(torch._C, '_jit_pass_' + name)(graph)
+        if result is not None and not isinstance(result, bool):
+            graph = result
+        torch._C._jit_pass_lint(graph)
+
+        if set_graph:
+            trace.set_graph(graph)
+        return graph
+
+    def get_frame_vars(self, frames_up):
+        frame = inspect.currentframe()
+        if not frame:
+            raise RuntimeError("failed to inspect frame")
+        i = 0
+        while i < frames_up + 1:
+            frame = frame.f_back
+            if not frame:
+                raise RuntimeError("failed to get frame")
+            i += 1
+        defined_vars: dict[str, Any] = {}
+        defined_vars.update(frame.f_locals)
+        defined_vars.update(frame.f_globals)
+        return defined_vars
+
+    def assertRaisesRegexWithHighlight(self, exception, regex, highlight):
+        return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)
+
+    def checkScriptRaisesRegex(self, script, inputs, exception, regex,
+                               name=None, outputs=None, capture_output=False,
+                               frames_up=1, profiling=ProfilingMode.PROFILING):
+        """
+        Checks that a given function will throw the correct exception,
+        when executed with normal python, the string frontend, and the
+        AST frontend. Logic taken from `checkScript` (see comments there
+        for details)
+        """
+        with enable_profiling_mode_for_profiling_tests():
+            # Normal Python
+            with self.assertRaisesRegex(exception, regex):
+                if isinstance(script, str):
+                    frame = self.get_frame_vars(frames_up)
+                    the_locals: dict[str, Any] = {}
+                    execWrapper(script, glob=frame, loc=the_locals)
+                    frame.update(the_locals)
+
+                    python_fn = frame[name]
+                else:
+                    python_fn = script
+
+                python_fn(*inputs)
+
+            # String frontend
+            with self.assertRaisesRegex(exception, regex):
+                if isinstance(script, str):
+                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+                    string_frontend = getattr(cu, name)
+                else:
+                    source = textwrap.dedent(inspect.getsource(script))
+                    cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
+                    string_frontend = getattr(cu, script.__name__)
+
+                string_frontend(*inputs)
+
+            # Python AST frontend
+            if not isinstance(script, str):
+                with self.assertRaisesRegex(exception, regex):
+                    ge = torch.jit.script(python_fn)
+                    ge(*inputs)
+
+    def checkBailouts(self, model, inputs, expected):
+        state = model.get_debug_state()
+        plan = get_execution_plan(state)
+        num_bailouts = plan.code.num_bailouts()
+        for i in range(num_bailouts):
+            plan.code.request_bailout(i)
+            bailout_outputs = model(*inputs)
+            self.assertEqual(bailout_outputs, expected)
+
+    def checkScript(self,
+                    script,
+                    inputs,
+                    name='func',
+                    optimize=True,
+                    inputs_requires_grad=False,
+                    capture_output=False,
+                    frames_up=1,
+                    profiling=ProfilingMode.PROFILING,
+                    atol=None,
+                    rtol=None):
+        """
+        Checks that a given script generates the same output as the Python
+        version using the given inputs.
+        """
+        with torch.jit.optimized_execution(optimize), enable_profiling_mode_for_profiling_tests():
+            extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs)
+            if isinstance(script, str):
+                # Compile the string to a Script function
+                # with enable_profiling_mode():
+                cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+
+                # Execute the Python function so we can run it later and get its
+                # outputs
+
+                frame = self.get_frame_vars(frames_up)
+                the_locals: dict[str, Any] = {}
+                execWrapper(script, glob=frame, loc=the_locals)
+                frame.update(the_locals)
+
+                python_fn = frame[name]
+                scripted_fn = getattr(cu, name)
+            else:
+
+                # Check the string frontend first
+                source = textwrap.dedent(inspect.getsource(script))
+                self.checkScript(
+                    source,
+                    inputs,
+                    script.__name__,
+                    optimize=optimize,
+                    inputs_requires_grad=inputs_requires_grad,
+                    capture_output=capture_output,
+                    profiling=profiling,
+                    frames_up=2)
+
+                # Continue checking the Python frontend
+                scripted_fn = torch.jit.script(script, _frames_up=1)
+                python_fn = script
+
+            if inputs_requires_grad:
+                recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
+            else:
+                recording_inputs = inputs
+
+            if capture_output:
+                with self.capture_stdout() as script_stdout:
+                    script_outputs = scripted_fn(*recording_inputs)
+                with self.capture_stdout():
+                    opt_script_outputs = scripted_fn(*recording_inputs)
+                with self.capture_stdout():
+                    python_outputs = python_fn(*inputs)
+                if not IS_WINDOWS:
+                    self.assertExpected(script_stdout[0], subname='stdout')
+                self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol)
+            else:
+                # profiling run
+                script_outputs = scripted_fn(*recording_inputs)
+                if inputs_requires_grad or extra_profile_runs:
+                    opt_script_outputs = scripted_fn(*recording_inputs)
+                # optimized run
+                opt_script_outputs = scripted_fn(*recording_inputs)
+                if TEST_BAILOUTS:
+                    self.checkBailouts(scripted_fn, inputs, opt_script_outputs)
+                python_outputs = python_fn(*inputs)
+            self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol)
+            self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol)
+            return scripted_fn
+
+    def checkTrace(self, func, reference_tensors, input_tensors=None,
+                   drop=None, allow_unused=False, verbose=False,
+                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
+                   _force_outplace=False, grad_atol=None, grad_rtol=None):
+
+        # TODO: check gradients for parameters, not just inputs
+        def allSum(vs):
+            # drop allows us to remove some values from ever being used
+            # to test unused outputs
+            if drop is not None:
+                vs = vs[:-drop]
+            # we don't want all the grad for all the outputs to be the same
+            # so we multiply each by a constant
+            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
+        if input_tensors is None:
+            input_tensors = reference_tensors
+
+        def flatten_inputs(inputs):
+            def input_reduce(input, fn, acc):
+                if isinstance(input, torch.Tensor):
+                    fn(input, acc)
+                elif isinstance(input, dict):
+                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+                else:
+                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+                return acc
+            return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), []))
+
+        nograd_inputs = reference_tensors
+        if inputs_require_grads:
+            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
+        else:
+            recording_inputs = reference_tensors
+
+        # `check_trace` is set to False because check_trace is run with @no_grad
+        # Also, `checkTrace` already does all the checks
+        # against python function
+        ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance,
+                             _force_outplace=_force_outplace, check_trace=False)
+
+        if export_import:
+            ge = self.getExportImportCopy(ge)
+
+        if verbose:
+            print(ge.graph)
+
+        # test no gradients case
+        outputs = func(*nograd_inputs)
+        outputs_ge = ge(*nograd_inputs)
+        self.assertEqual(outputs, outputs_ge)
+
+        # test gradients case
+        outputs = func(*recording_inputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
+                                        allow_unused=allow_unused)
+
+        outputs_ge = ge(*recording_inputs)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
+                                           allow_unused=allow_unused)
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
+
+        # test the grad grad case
+        outputs = func(*recording_inputs)
+        l1 = allSum(outputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
+                                        allow_unused=allow_unused)
+        if inputs_require_grads:
+            l2 = (allSum(grads) * l1)
+            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
+
+        outputs_ge = ge(*recording_inputs)
+        l1_ge = allSum(outputs_ge)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(
+                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            l2_ge = (allSum(grads_ge) * l1_ge)
+            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
+
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
+            for g2, g2_ge in zip(grads2, grads2_ge, strict=True):
+                if g2 is None and g2_ge is None:
+                    continue
+                self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4)
+
+        return ge
+
+    def checkModule(self, nn_module, args):
+        """
+        Check that a nn.Module's results in Script mode match eager and that it
+        can be exported
+        """
+        sm = torch.jit.script(nn_module)
+
+        with freeze_rng_state():
+            eager_out = nn_module(*args)
+
+        with freeze_rng_state():
+            script_out = sm(*args)
+
+        self.assertEqual(eager_out, script_out)
+        self.assertExportImportModule(sm, args)
+
+        return sm
+
+class NoTracerWarnContextManager:
+    def __enter__(self):
+        self.prev = torch._C._jit_get_tracer_state_warn()
+        torch._C._jit_set_tracer_state_warn(False)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_tracer_state_warn(self.prev)
+
+@contextmanager
+def inline_everything_mode(should_inline):
+    old = torch._C._jit_get_inline_everything_mode()
+    torch._C._jit_set_inline_everything_mode(should_inline)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_inline_everything_mode(old)
+
+@contextmanager
+def set_fusion_group_inlining(inlining):
+    old = torch._C._debug_get_fusion_group_inlining()
+    torch._C._debug_set_fusion_group_inlining(inlining)
+    try:
+        yield
+    finally:
+        torch._C._debug_set_fusion_group_inlining(old)
+
+# note: not re-entrant, use unnested only
+@contextmanager
+def disable_autodiff_subgraph_inlining(enabled=True):
+    torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
+    try:
+        yield
+    finally:
+        torch._C._debug_set_autodiff_subgraph_inlining(True)
+
+def _inline_everything(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with inline_everything_mode(True):
+            fn(*args, **kwargs)
+    return wrapper
+
+# this exists for forward compatibility reasons temporarily.
+# TODO(suo) remove
+def _tmp_donotuse_dont_inline_everything(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with inline_everything_mode(False):
+            fn(*args, **kwargs)
+    return wrapper
+
+# make it easy to quickly define/trace a function for these tests
+def _trace(*args, **kwargs):
+    def wrapper(func):
+        return torch.jit.trace(func, args, **kwargs)
+    return wrapper
+
+
+def enable_cpu_fuser(fn):
+    def wrapper(*args, **kwargs):
+        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_set_te_must_use_llvm_cpu(False)
+        try:
+            fn(*args, **kwargs)
+        finally:
+            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
+            torch._C._jit_override_can_fuse_on_cpu(False)
+            torch._C._jit_set_te_must_use_llvm_cpu(True)
+    return wrapper
+
+
+def enable_cpu_fuser_if(cond):
+    if cond:
+        return enable_cpu_fuser
+    else:
+        def noop_fuser(fn):
+            def wrapper(*args, **kwargs):
+                return fn(*args, **kwargs)
+            return wrapper
+        return noop_fuser
+
+def get_forward(c):
+    return c._get_method('forward')
+
+def get_forward_graph(c):
+    return c._get_method('forward').graph
+
+def get_module_method(m, module, method):
+    return m._c.getattr(module)._get_method(method)
+
+def attrs_with_prefix(module, prefix):
+    return [x for x, _ in module._modules._c.items()
+            if x.startswith(prefix)]
+
+def warmup_backward(f, *args):
+    profiling_count = 3
+    results = []
+    for _ in range(profiling_count):
+        if len(args) > 0:
+            r = torch.autograd.grad(f, *args)
+            results.append(r)
+        else:
+            f.backward(retain_graph=True)
+
+    return results
+
+# TODO: Remove me once https://bugs.python.org/issue42666 is resolved
+def make_global(*args):
+    for arg in args:
+        setattr(sys.modules[arg.__module__], arg.__name__, arg)
+
+# Helper function to eval Python3 code without causing a syntax error for
+# this file under py2
+def _get_py3_code(code, fn_name):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        script_path = os.path.join(tmp_dir, 'script.py')
+        with open(script_path, 'w') as f:
+            f.write(code)
+        spec = importlib.util.spec_from_file_location(fn_name, script_path)
+        module = importlib.util.module_from_spec(spec)
+        loader = spec.loader
+        assert isinstance(loader, Loader)  # Assert type to meet MyPy requirement
+        loader.exec_module(module)
+        fn = getattr(module, fn_name)
+        return fn
+
+class TensorExprTestOptions:
+    def __init__(self) -> None:
+        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
+        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
+
+        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
+        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
+        torch._C._jit_set_texpr_fuser_enabled(True)
+        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
+        torch._C._debug_set_fusion_group_inlining(False)
+        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
+        torch._C._jit_set_te_must_use_llvm_cpu(False)
+
+    def restore(self):
+        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
+        torch._C._get_graph_executor_optimize(self.old_profiling_mode)
+
+        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
+        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
+        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
+        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
+        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
+
+def clone_inputs(args):
+    inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            inputs.append(arg.detach().clone())
+        elif is_iterable_of_tensors(arg):
+            inputs.append([t.detach().clone() for t in arg])
+        else:
+            inputs.append(arg)
+
+    return inputs
+
+def get_traced_sample_variant_pairs(device, dtype, op):
+    # tuples of (variant, sample)
+    outputs: list[tuple[Any, Any]] = []
+
+    samples = op.sample_inputs(device, dtype)
+
+    # Acquires variants to test
+    func = op.get_op()
+    method = op.get_method()
+    variants = {
+        # TODO: inplace tests currently fail, fix and add inplace variant
+        'function': func, 'method': method,
+    }
+
+    # TODO: find better way to standardize on op registration itself..
+    has_fake_function = op.name in ["resize_", 'resize_as_']
+
+    if has_fake_function:
+        variants = {'method': getattr(torch.Tensor, op.name)}
+
+    # In eager mode, these ops can take (Tensor, bool) args; but in
+    # JIT they can only take (Tensor, Scalar), and bool is not a
+    # scalar in the JIT type system. So to test these in JIT, the bool
+    # is converted to an int for the test.
+    ops_with_unsupported_bool_args = [
+        {
+            "name": "div_floor_rounding",
+            "arg_idx": [0],
+        },
+        {
+            "name": "div_no_rounding_mode",
+            "arg_idx": [0],
+        },
+        {
+            "name": "div_trunc_rounding",
+            "arg_idx": [0],
+        },
+        {
+            "name": "index_fill",
+            "arg_idx": [2],
+        },
+        {
+            "name": "full_like",
+            "arg_idx": [0],
+        },
+        {
+            "name": "mul",
+            "arg_idx": [0],
+        },
+        {
+            "name": "new_full",
+            "arg_idx": [1],
+        },
+    ]
+
+    # doesn't support tracing
+    if has_fake_function:
+        return outputs
+
+    for sample in samples:
+        for variant in variants.values():
+            if variant is None:
+                continue
+
+            if is_lambda(variant):
+                continue
+
+            matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args)
+            for op_data in matching_ops:
+                for idx in op_data["arg_idx"]:
+                    args = list(sample.args)
+                    if len(sample.args) > idx and isinstance(sample.args[idx], bool):
+                        args[idx] = int(args[idx])
+                    sample.args = tuple(args)
+
+            outputs.append((variant, sample))
+
+    return outputs
+
+# types.LambdaType gave false positives
+def is_lambda(lamb):
+    LAMBDA = lambda: 0  # noqa: E731
+    return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a9b21471b28b9e66058c510f66f9cef12e8bd5b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9767455e976b02f8541e8d258e50f3fc81537d3c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/refs.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afb4e394e88c0de7d9446b70db842caa8330ef0a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__pycache__/utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26d3f402e741a54f21a5fca48beded5b0a58aec
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py
@@ -0,0 +1,26 @@
+# mypy: ignore-errors
+
+from torch.testing._internal.opinfo.core import OpInfo
+from torch.testing._internal.opinfo.definitions import (
+    _masked,
+    fft,
+    linalg,
+    signal,
+    special,
+)
+
+
+# Operator database
+op_db: list[OpInfo] = [
+    *fft.op_db,
+    *linalg.op_db,
+    *signal.op_db,
+    *special.op_db,
+    *_masked.op_db,
+]
+
+python_ref_db: list[OpInfo] = [
+    *fft.python_ref_db,
+    *linalg.python_ref_db,
+    *special.python_ref_db,
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0619574182201745df1c436096fbf8051c72bed9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f48e28643fab9ab62360e8a39f8235afa2db3f6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cc1e53d995f190ea958d72b121e8a0e5889423c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eeb4e6a7ba304e664a0251add6635890ff789bb0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea720dfb8097072261c33a16389da99a84e659ad
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c6e3afbb31799baa1c4783f159f657280130076
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..deb0ddab2d80a6cb73db7b640ceefd347af46be1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bce354d1f9a14e72a752b6cc5f85d393f436cace
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65fbef658a4545ae9459fc5ad561572865d96f3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -0,0 +1,1212 @@
+# mypy: ignore-errors
+
+import unittest
+from collections.abc import Sequence
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import tol, toleranceOverride
+from torch.testing._internal.common_dtype import (
+    all_types_and,
+    all_types_and_complex_and,
+    complex_types,
+    floating_and_complex_types_and,
+    floating_types_and,
+    integral_types,
+)
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    gradcheck_wrapper_masked_operation,
+    gradcheck_wrapper_masked_pointwise_operation,
+    M,
+    OpInfo,
+    ReductionOpInfo,
+    S,
+    sample_inputs_reduction,
+    SampleInput,
+)
+from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy
+
+
+# Used for log_softmax, softmax, softmin
+def sample_inputs_softmax_variant(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    with_dtype=False,
+    use_zero_dimensions=True,
+    **kwargs,
+):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    cases = [
+        ((S,), (0,)),
+        ((S, S), (0,)),
+        ((S, S), (1,)),
+        ((S, S), (-1,)),
+        ((S, M, S), (2,)),
+        *([((S, 0, 0), (-1,))] if use_zero_dimensions else []),
+    ]
+    kwargs = dict(dtype=torch.float64) if with_dtype else None
+
+    # PyTorch on XLA throws an error when passed with dim argument for 0d tensor.
+    # See https://github.com/pytorch/xla/issues/3061 for more details.
+    if torch.device(device).type != "xla":
+        cases.append(((), (0,)))
+
+    return (
+        SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
+    )
+
+
+def _generate_masked_op_mask(input_shape, device, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=torch.bool, device=device, requires_grad=False
+    )
+    yield None
+    yield make_arg(input_shape)
+    if len(input_shape) > 2:
+        # broadcast last mask dimension:
+        yield make_arg(input_shape[:-1] + (1,))
+        # broadcast middle mask dimension:
+        yield make_arg(input_shape[:1] + (1,) + input_shape[2:])
+        # broadcast first mask dimension:
+        yield make_arg((1,) + input_shape[1:])
+        # mask.ndim < input.ndim
+        yield make_arg(input_shape[1:])
+        # mask.ndim == 1
+        yield make_arg(input_shape[-1:])
+        # masks that require broadcasting of inputs (mask.ndim >
+        # input.ndim) will not be supported, however, we may
+        # reconsider this if there will be demand on this kind of
+        # degenerate cases.
+
+
+def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked reduction operators.
+
+    Masked reduction operator is a reduction operator with trailing
+    mask optional argument. A mask is a bool tensor with the same
+    shape as input or a shape that is broadcastable to input shape.
+    """
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+
+    for sample_input in sample_inputs_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            sample_input.input.shape, device, **kwargs
+        ):
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
+            )
+            yield SampleInput(
+                sample_input.input.detach().requires_grad_(requires_grad),
+                args=sample_input_args,
+                kwargs=sample_input_kwargs,
+            )
+            if (
+                not requires_grad
+                and dtype.is_floating_point
+                and sample_input.input.ndim == 2
+                and mask is not None
+                and mask.shape == sample_input.input.shape
+            ):
+                for v in [torch.inf, -torch.inf, torch.nan]:
+                    t = sample_input.input.detach()
+                    t.diagonal(0, -2, -1).fill_(v)
+                    yield SampleInput(
+                        t.requires_grad_(requires_grad),
+                        args=sample_input_args,
+                        kwargs=sample_input_kwargs,
+                    )
+
+
+def sample_inputs_sparse_coo_masked_reduction(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    """Sample inputs for masked reduction operators that support inputs
+    with sparse coo layouts.
+    """
+    if op_info.supports_sparse:
+        op_name = op_info.name.replace("masked.", "")
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            mask = sample_input.kwargs.get("mask")
+            if mask is not None:
+                sample_input_kwargs = sample_input.kwargs.copy()
+                sample_input_kwargs.update(mask=mask.to_sparse())
+                yield SampleInput(
+                    sample_input.input.to_sparse(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+            else:
+                if op_name in {"prod", "amax", "amin"}:
+                    # FIXME: for now reductions with non-zero reduction identity and
+                    # unspecified mask are not supported for sparse COO
+                    # tensors, see torch.masked.prod implementation
+                    # for details.
+                    continue
+                yield SampleInput(
+                    sample_input.input.to_sparse(),
+                    args=sample_input.args,
+                    kwargs=sample_input.kwargs,
+                )
+
+
+def sample_inputs_sparse_csr_masked_reduction(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    """Sample inputs for masked reduction operators that support inputs
+    with sparse csr layouts.
+    """
+    if op_info.supports_sparse_csr:
+        op_name = op_info.name.replace("masked.", "")
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            if not (
+                sample_input.input.ndim == 2 and sample_input.kwargs.get("keepdim")
+            ):
+                # - sparse CSR tensors are always 2-D tensors
+                # - masked reduction on CSR tensors are defined only if keepdim is True.
+                continue
+            mask = sample_input.kwargs.get("mask")
+            if mask is not None:
+                sample_input_kwargs = sample_input.kwargs.copy()
+                sample_input_kwargs.update(mask=mask.to_sparse_csr())
+                new_sample = SampleInput(
+                    sample_input.input.to_sparse_csr(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+            else:
+                if op_name in ["prod", "amax", "amin", "mean"]:
+                    # reductions with non-zero reduction identity and
+                    # unspecified mask is not supported for sparse CSR
+                    # tensors, see torch.masked.prod implementation
+                    # for details.
+                    continue
+                new_sample = SampleInput(
+                    sample_input.input.to_sparse_csr(),
+                    args=sample_input.args,
+                    kwargs=sample_input.kwargs,
+                )
+            yield new_sample
+            if sample_input.kwargs["dim"] == 0:
+                # Reductions of CSR tensors use different implementations for
+                # inner and/or outer dimensions. So, as a minimum of testing CSR
+                # implementations the following kwargs must be generated:
+                #   dict(dim=0, keepdim=True)
+                #   dict(dim=1, keepdim=True)
+                #   dict(dim=(0, 1), keepdim=True)
+                # Here we generate the dim=1 case from the dim=0 case.
+                sample_input_kwargs = new_sample.kwargs.copy()
+                sample_input_kwargs.update(dim=1)
+                yield SampleInput(
+                    new_sample.input.clone(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+
+
+def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked norm."""
+    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            sample_input_args, sample_input_kwargs = (
+                (ord,) + sample_input.args,
+                sample_input.kwargs.copy(),
+            )
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                args=sample_input_args,
+                kwargs=sample_input_kwargs,
+            )
+
+
+def reference_masked_std_var(
+    numpy_fn,
+):
+    ref = reference_reduction_numpy(numpy_fn)
+
+    # Translate unbiased or correction arguments into ddof
+    def func(
+        input,
+        dim=None,
+        unbiased=None,
+        *,
+        correction=None,
+        **kwargs,
+    ):
+        ddof = 1
+        if unbiased is not None:
+            ddof = 1 if unbiased else 0
+        if correction is not None:
+            ddof = correction
+
+        if isinstance(dim, Sequence):
+            dim = tuple(dim)
+
+        return ref(input, dim, ddof=ddof, **kwargs)
+
+    return func
+
+
+def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked std/var."""
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+    from torch.testing._internal.common_methods_invocations import sample_inputs_std_var
+
+    def masked_samples():
+        for sample_input in sample_inputs_std_var(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            if len(sample_input.args) and isinstance(sample_input.args[0], bool):
+                continue  # masked.{std, var} doesn't support `.var(unbiased)`
+
+            for mask in _generate_masked_op_mask(
+                sample_input.input.shape, device, **kwargs
+            ):
+                sample_input_args, sample_input_kwargs = (
+                    sample_input.args,
+                    dict(mask=mask, **sample_input.kwargs),
+                )
+                yield SampleInput(
+                    sample_input.input.detach().requires_grad_(requires_grad),
+                    args=sample_input_args,
+                    kwargs=sample_input_kwargs,
+                )
+                if (
+                    not requires_grad
+                    and dtype.is_floating_point
+                    and sample_input.input.ndim == 2
+                    and mask is not None
+                    and mask.shape == sample_input.input.shape
+                ):
+                    for v in [torch.inf, -torch.inf, torch.nan]:
+                        t = sample_input.input.detach()
+                        t.diagonal(0, -2, -1).fill_(v)
+                        yield SampleInput(
+                            t.requires_grad_(requires_grad),
+                            args=sample_input_args,
+                            kwargs=sample_input_kwargs,
+                        )
+
+    for sample_input in masked_samples():
+        correction = sample_input.kwargs.get("correction")
+        if correction is None:
+            correction = int(sample_input.kwargs.get("unbiased", True))
+
+        dim = sample_input.kwargs.get("dim", None)
+
+        if sample_input.kwargs.get("mask") is None:
+            orig_count = torch.masked.sum(
+                torch.ones(sample_input.input.shape, dtype=torch.int64),
+                dim,
+                keepdim=True,
+            )
+        else:
+            inmask = torch.masked._input_mask(
+                sample_input.input, *sample_input.args, **sample_input.kwargs
+            )
+            orig_count = torch.masked.sum(
+                inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
+                dim,
+                keepdim=True,
+                mask=inmask,
+            )
+        if orig_count.min() <= correction + 1:
+            # Skip samples that lead to nans in var computation
+            continue
+
+        yield sample_input
+
+
+def sample_inputs_masked_softmax(
+    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
+):
+    """Sample inputs for masked softmax, log_softmax, and softmin.
+
+    Masked normalization operator is a reduction operator with
+    trailing mask optional argument. A mask is a bool tensor with the
+    same shape as input or a shape that is broadcastable to input
+    shape.
+    """
+    for sample_input in sample_inputs_softmax_variant(
+        op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            sample_input.input.shape, device, **kwargs
+        ):
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                *sample_input.args,
+                mask=mask,
+                **sample_input.kwargs,
+            )
+
+
+def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked cumsum and cumprod."""
+    for sample_input in sample_inputs_softmax_variant(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            sample_input.input.shape, device, **kwargs
+        ):
+            if type(mask) is not torch.Tensor:
+                continue
+            sample_input_args, sample_input_kwargs = (
+                sample_input.args,
+                dict(mask=mask, **sample_input.kwargs),
+            )
+            if "keepdim" in sample_input_kwargs:
+                sample_input_kwargs.pop("keepdim")
+            # dimension is required
+            if sample_input_args:
+                dim = sample_input.args[0]
+            else:
+                if "dim" not in sample_input_kwargs:
+                    continue
+                dim = sample_input_kwargs.pop("dim")
+                sample_input_args = (dim,)
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                *sample_input_args,
+                **sample_input_kwargs,
+            )
+
+
+def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked logaddexp."""
+    shapes = [(S,), (S, S), (S, M, S)]
+    input_mask_lists = [
+        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
+    ]
+    other_mask_lists = [
+        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
+    ]
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    for shape, input_masks, other_masks in zip(
+        shapes, input_mask_lists, other_mask_lists, strict=True
+    ):
+        for input_mask, other_mask in zip(input_masks, other_masks, strict=True):
+            yield SampleInput(
+                make_arg(shape),
+                make_arg(shape),
+                input_mask=input_mask,
+                other_mask=other_mask,
+            )
+
+
+def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked normalize."""
+    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
+        for sample_input in sample_inputs_softmax_variant(
+            op_info, device, dtype, requires_grad, use_zero_dimensions=False, **kwargs
+        ):
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                ord,
+                *sample_input.args,
+                **sample_input.kwargs,
+            )
+
+
+op_db: list[OpInfo] = [
+    ReductionOpInfo(
+        "masked.sum",
+        ref=reference_reduction_numpy(np.sum),
+        method_variant=None,
+        identity=0,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_reference_masked",
+                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-03, rtol=5e-2),
+                        torch.float16: tol(atol=1e-03, rtol=5e-3),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=0.1, rtol=0.1),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestMasked",
+                "test_mask_layout",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+    ),
+    ReductionOpInfo(
+        "masked.prod",
+        ref=prod_numpy,
+        method_variant=None,
+        identity=1,
+        nan_policy="propagate",
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_reference_masked",
+                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
+            ),
+            DecorateInfo(
+                "TestReductions",
+                "test_ref_small_input",
+                dtypes=(torch.int8, torch.int16, torch.int32),
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cuda",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_duplicate_values",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
+                "TestOperators",
+                "test_jvp",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+    ),
+    OpInfo(
+        "masked.cumsum",
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        # Can reuse the same inputs; dim is required in both
+        sample_inputs_func=sample_inputs_masked_cumops,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    OpInfo(
+        "masked.cumprod",
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                "TestCompositeCompliance",
+                "test_backward",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ),
+        # Can reuse the same inputs; dim is required in both
+        sample_inputs_func=sample_inputs_masked_cumops,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.amax",
+        nan_policy="propagate",
+        supports_out=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        supports_sparse=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse_csr=True,
+        ref=reference_reduction_numpy(np.amax),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: amax reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: Unknown builtin op: aten::iinfo
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.amin",
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        ref=reference_reduction_numpy(np.amin),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: amax reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: Unknown builtin op: aten::iinfo
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.argmax",
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # initial is not a keyword for argmax
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_reference_masked"
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.argmin",
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # initial is not a keyword for argmin
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_reference_masked"
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.mean",
+        ref=reference_reduction_numpy(np.mean)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_sparse_csr=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-03, rtol=0.05),
+                        torch.float16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=2e-03)}),
+                "TestSparseCompressed",
+                "test_consistency",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    OpInfo(
+        "masked.median",
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        method_variant=None,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=partial(
+            sample_inputs_masked_softmax, use_zero_dimensions=False
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.norm",
+        identity=0,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        promotes_int_to_float=True,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # torch.jit.frontend.NotSupportedError: Compiled functions
+            # can't take variable number of arguments or use
+            # keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_masked_norm,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.var",
+        ref=reference_masked_std_var(np.var)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=4e-5, rtol=2e-2),
+                    }
+                ),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_std_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        check_batched_grad=True,
+    ),
+    ReductionOpInfo(
+        "masked.std",
+        ref=reference_masked_std_var(np.std)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-02, rtol=1e-02),
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=5e-03, rtol=5e-04),
+                    }
+                ),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_std_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        check_batched_grad=True,
+    ),
+    OpInfo(
+        "masked.softmax",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.log_softmax",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+        ],
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.softmin",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME:
+            # Mismatched elements: 2 / 2 (100.0%)
+            # Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed)
+            # Greatest relative difference: nan at index (0,) (up to 0.0001 allowed
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestOperators",
+                "test_vmapvjpvjp",
+                device_type="cpu",
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.normalize",
+        method_variant=None,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_normalize,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.logaddexp",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestFwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestBwdGradients", "test_fn_gradgrad"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_logaddexp,
+        gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation,
+    ),
+    ReductionOpInfo(
+        "masked.logsumexp",
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # Identity can't be -torch.inf without overflow
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestReductions",
+                "test_empty_tensor_empty_slice",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            # all the values are the same except for -inf vs nan
+            DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
+            # FIXME:
+            # Mismatched elements: 2 / 12 (16.7%)
+            # Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0)
+            # Greatest relative difference: 0.0 at index (0, 0, 0, 1)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cpu",
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py
new file mode 100644
index 0000000000000000000000000000000000000000..8293fca978f262d7bf6eea6b546b2c3cd500f227
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py
@@ -0,0 +1,809 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import SM53OrLater
+from torch.testing._internal.common_device_type import precisionOverride
+from torch.testing._internal.common_dtype import (
+    all_types_and,
+    all_types_and_complex_and,
+)
+from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    ErrorInput,
+    OpInfo,
+    sample_inputs_spectral_ops,
+    SampleInput,
+    SpectralFuncInfo,
+    SpectralFuncType,
+)
+from torch.testing._internal.opinfo.refs import (
+    _find_referenced_opinfo,
+    _inherit_constructor_args,
+    PythonRefInfo,
+)
+
+
+has_scipy_fft = False
+if TEST_SCIPY:
+    try:
+        import scipy.fft
+
+        has_scipy_fft = True
+    except ModuleNotFoundError:
+        pass
+
+
+class SpectralFuncPythonRefInfo(SpectralFuncInfo):
+    """
+    An OpInfo for a Python reference of an elementwise unary operation.
+    """
+
+    def __init__(
+        self,
+        name,  # the stringname of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch. if None
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant="",
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
+        )
+        assert isinstance(self.torch_opinfo, SpectralFuncInfo)
+
+        inherited = self.torch_opinfo._original_spectral_func_args
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
+
+        super().__init__(**ukwargs)
+
+
+def error_inputs_fft(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    # Zero-dimensional tensor has no dimension to take FFT of
+    yield ErrorInput(
+        SampleInput(make_arg()),
+        error_type=IndexError,
+        error_regex="Dimension specified as -1 but tensor has no dimensions",
+    )
+
+
+def error_inputs_fftn(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    # Specifying a dimension on a zero-dimensional tensor
+    yield ErrorInput(
+        SampleInput(make_arg(), dim=(0,)),
+        error_type=IndexError,
+        error_regex="Dimension specified as 0 but tensor has no dimensions",
+    )
+
+
+def sample_inputs_fft_with_min(
+    op_info, device, dtype, requires_grad=False, *, min_size, **kwargs
+):
+    yield from sample_inputs_spectral_ops(
+        op_info, device, dtype, requires_grad, **kwargs
+    )
+    if TEST_WITH_ROCM:
+        # FIXME: Causes floating point exception on ROCm
+        return
+
+    # Check the "Invalid number of data points" error isn't too strict
+    # https://github.com/pytorch/pytorch/pull/109083
+    a = make_tensor(min_size, dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(a)
+
+
+def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
+    def mt(shape, **kwargs):
+        return make_tensor(
+            shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+        )
+
+    yield SampleInput(mt((9, 10)))
+    yield SampleInput(mt((50,)), kwargs=dict(dim=0))
+    yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,)))
+    yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1)))
+    yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2)))
+
+
+# Operator database
+op_db: list[OpInfo] = [
+    SpectralFuncInfo(
+        "fft.fft",
+        aten_name="fft_fft",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.fft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.fft2",
+        aten_name="fft_fft2",
+        ref=np.fft.fft2,
+        decomp_aten_name="_fft_c2c",
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+                device_type="cuda",
+                dtypes=[torch.complex32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.fftn",
+        aten_name="fft_fftn",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.fftn,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
+    ),
+    SpectralFuncInfo(
+        "fft.hfft",
+        aten_name="fft_hfft",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.hfft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=2),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        check_batched_gradgrad=False,
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.hfft2",
+        aten_name="fft_hfft2",
+        decomp_aten_name="_fft_c2r",
+        ref=scipy.fft.hfft2 if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_gradgrad=False,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+        ],
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+            ),
+            # FIXME: errors are too large; needs investigation
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.hfftn",
+        aten_name="fft_hfftn",
+        decomp_aten_name="_fft_c2r",
+        ref=scipy.fft.hfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_gradgrad=False,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+        ],
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.rfft",
+        aten_name="fft_rfft",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        skips=(),
+        check_batched_gradgrad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.rfft2",
+        aten_name="fft_rfft2",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfft2,
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            precisionOverride({torch.float: 1e-4}),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.rfftn",
+        aten_name="fft_rfftn",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfftn,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            precisionOverride({torch.float: 1e-4}),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ifft",
+        aten_name="fft_ifft",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.ifft2",
+        aten_name="fft_ifft2",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifft2,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ifftn",
+        aten_name="fft_ifftn",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifftn,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ihfft",
+        aten_name="fft_ihfft",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.ihfft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fft,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        skips=(),
+        check_batched_grad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.ihfft2",
+        aten_name="fft_ihfft2",
+        decomp_aten_name="_fft_r2c",
+        ref=scipy.fft.ihfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=(
+            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
+            ),
+            # Mismatched elements!
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warnings"),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.ihfftn",
+        aten_name="fft_ihfftn",
+        decomp_aten_name="_fft_r2c",
+        ref=scipy.fft.ihfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archss
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+            # Mismatched elements!
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
+            ),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.irfft",
+        aten_name="fft_irfft",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.irfft2",
+        aten_name="fft_irfft2",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfft2,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.irfftn",
+        aten_name="fft_irfftn",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfftn,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    OpInfo(
+        "fft.fftshift",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.half, torch.chalf
+        ),
+        sample_inputs_func=sample_inputs_fftshift,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    OpInfo(
+        "fft.ifftshift",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.half, torch.chalf
+        ),
+        sample_inputs_func=sample_inputs_fftshift,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft",
+        torch_opinfo_name="fft.fft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft",
+        torch_opinfo_name="fft.ifft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft",
+        torch_opinfo_name="fft.rfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft",
+        torch_opinfo_name="fft.irfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft",
+        torch_opinfo_name="fft.hfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft",
+        torch_opinfo_name="fft.ihfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fftn",
+        torch_opinfo_name="fft.fftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifftn",
+        torch_opinfo_name="fft.ifftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfftn",
+        torch_opinfo_name="fft.rfftn",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfftn",
+        torch_opinfo_name="fft.irfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfftn",
+        torch_opinfo_name="fft.hfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfftn",
+        torch_opinfo_name="fft.ihfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+            # AssertionError: Reference result was farther (0.09746177145360499) from the precise
+            # computation than the torch result was (0.09111555632069855)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_torch_fallback",
+                dtypes=(torch.float16,),
+                device_type="cuda",
+            ),
+            # AssertionError: Reference result was farther (0.0953431016138116) from the precise
+            # computation than the torch result was (0.09305490684430734)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_executor",
+                dtypes=(torch.float16,),
+                device_type="cuda",
+            ),
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft2",
+        torch_opinfo_name="fft.fft2",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft2",
+        torch_opinfo_name="fft.ifft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft2",
+        torch_opinfo_name="fft.rfft2",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft2",
+        torch_opinfo_name="fft.irfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft2",
+        torch_opinfo_name="fft.hfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft2",
+        torch_opinfo_name="fft.ihfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+            # FIXME:
+            # Reference result was farther (0.0953431016138116) from the precise computation
+            # than the torch result was (0.09305490684430734)!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_executor",
+                device_type="cuda",
+            ),
+        ],
+    ),
+    PythonRefInfo(
+        "_refs.fft.fftshift",
+        op_db=op_db,
+        torch_opinfo_name="fft.fftshift",
+    ),
+    PythonRefInfo(
+        "_refs.fft.ifftshift",
+        op_db=op_db,
+        torch_opinfo_name="fft.ifftshift",
+    ),
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py
new file mode 100644
index 0000000000000000000000000000000000000000..f41cadad67eb780aa6980306002a27cacfd2eb30
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -0,0 +1,2392 @@
+# mypy: ignore-errors
+
+import itertools
+import random
+import unittest
+from collections.abc import Iterable
+from functools import partial
+from itertools import chain, product
+
+import numpy as np
+from numpy import inf
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import _get_magma_version, with_tf32_off
+from torch.testing._internal.common_device_type import (
+    has_cusolver,
+    skipCPUIfNoLapack,
+    skipCUDAIfNoCusolver,
+    skipCUDAIfNoMagma,
+    skipCUDAIfNoMagmaAndNoCusolver,
+    skipCUDAIfNoMagmaAndNoLinalgsolver,
+    skipCUDAIfRocm,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import (
+    all_types_and_complex,
+    all_types_and_complex_and,
+    floating_and_complex_types,
+    floating_and_complex_types_and,
+)
+from torch.testing._internal.common_utils import (
+    GRADCHECK_NONDET_TOL,
+    make_fullrank_matrices_with_distinct_singular_values,
+    skipIfSlowGradcheckEnv,
+    slowTest,
+    TEST_WITH_ROCM,
+    TEST_XPU,
+)
+from torch.testing._internal.opinfo.core import (
+    clone_sample,
+    DecorateInfo,
+    ErrorInput,
+    gradcheck_wrapper_hermitian_input,
+    L,
+    M,
+    OpInfo,
+    ReductionOpInfo,
+    S,
+    SampleInput,
+)
+from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo
+
+
+def sample_kwargs_vector_norm(t, **kwargs):
+    # orders with / without identity
+    def ords():
+        has_id = (6, 4, 2, 1, 0, 0.9)
+        no_id = (inf, -2.1, -inf)
+        if t.numel() == 0:
+            dim = kwargs.get("dim")
+            if dim is None:
+                return has_id
+            if not isinstance(dim, Iterable):
+                dim = (dim,)
+            for d in dim:
+                if t.size(d) == 0:
+                    return has_id
+        return has_id + no_id
+
+    return (((), dict(ord=o)) for o in ords())
+
+
+def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    is_linalg_svd = "linalg.svd" in op_info.name
+    batches = [(), (0,), (3,)]
+    ns = [0, 3, 5]
+
+    def uniformize(usv):
+        S = usv[1]
+        k = S.shape[-1]
+        U = usv[0][..., :k]
+        Vh = usv[2] if is_linalg_svd else usv[2].mH
+        Vh = Vh[..., :k, :]
+        return U, S, Vh
+
+    def fn_U(usv):
+        U, _, _ = uniformize(usv)
+        return U.abs()
+
+    def fn_S(usv):
+        return uniformize(usv)[1]
+
+    def fn_Vh(usv):
+        # We also return S to test
+        _, S, Vh = uniformize(usv)
+        return S, Vh.abs()
+
+    def fn_UVh(usv):
+        U, S, Vh = uniformize(usv)
+        return U @ Vh, S
+
+    fns = (fn_U, fn_S, fn_Vh, fn_UVh)
+
+    fullmat = "full_matrices" if is_linalg_svd else "some"
+
+    for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns):
+        shape = batch + (n, k)
+        yield SampleInput(
+            make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn
+        )
+
+
+def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
+    yield SampleInput(
+        make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
+    )
+    yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
+
+
+def error_inputs_cross(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
+    err = "inputs dimension -1 must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
+    err = "inputs must have the same number of dimensions"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
+    err = "must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(
+        input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
+    )
+    err = "Dimension out of range"
+    yield ErrorInput(sample, error_regex=err, error_type=IndexError)
+
+
+def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
+    """
+    This function generates input for torch.linalg.householder_product (torch.orgqr).
+    The first argument should be a square matrix or batch of square matrices, the second argument is a vector or batch of vectors.
+    Empty, square, rectangular, batched square and batched rectangular input is generated.
+    """
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        low=-2,
+        high=2,
+    )
+    # Each column of the matrix is getting multiplied many times leading to very large values for
+    # the Jacobian matrix entries and making the finite-difference result of grad check less accurate.
+    # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here.
+    yield SampleInput(make_arg((S, S)), make_arg((S,)))
+    yield SampleInput(make_arg((S + 1, S)), make_arg((S,)))
+    yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S)))
+    yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S)))
+    yield SampleInput(
+        make_arg((0, 0), low=None, high=None),
+        make_arg((0,), low=None, high=None),
+    )
+    yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None))
+    # m = n = S, k = S - 2
+    yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None))
+    # m = S, n = S -1, k = S - 2
+    yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None))
+
+
+def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    make_arg_fullrank = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    # (, ())
+    test_sizes = [
+        (1, ()),
+        (2, (0,)),
+        (2, (2,)),
+    ]
+
+    for matrix_size, batch_sizes in test_sizes:
+        size = batch_sizes + (matrix_size, matrix_size)
+        for n in (0, 3, 5):
+            yield SampleInput(make_arg(size), args=(n,))
+        for n in [-4, -2, -1]:
+            yield SampleInput(make_arg_fullrank(*size), args=(n,))
+
+
+def sample_inputs_linalg_det_logdet_slogdet(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    batches = [(), (0,), (3,)]
+    ns = [0, 1, 5]
+
+    is_logdet = op_info.name == "logdet"
+
+    for (
+        batch,
+        n,
+    ) in product(batches, ns):
+        shape = batch + (n, n)
+        A = make_arg(*shape)
+        # Need to make the matrices in A have positive determinant for autograd
+        # To do so, we multiply A by its determinant to flip the sign of its determinant
+        if is_logdet and not A.is_complex() and A.numel() > 0:
+            s = torch.linalg.slogdet(A).sign
+            A = A * s.unsqueeze(-1).unsqueeze(-1)
+            A.requires_grad_(requires_grad)
+        yield SampleInput(A)
+
+
+def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    """Samples the inputs for both linalg.lu_solve and lu_solve"""
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(make_fn, dtype=dtype, device=device)
+    make_b = partial(make_tensor, dtype=dtype, device=device)
+
+    def clone(X, requires_grad):
+        Y = X.clone()
+        Y.requires_grad_(requires_grad)
+        return Y
+
+    is_linalg_lu_solve = op_info.name == "linalg.lu_solve"
+
+    batches = ((), (0,), (2,))
+    ns = (3, 1, 0)
+    nrhs = (4, 1, 0)
+
+    for n, batch, rhs in product(ns, batches, nrhs):
+        A = make_a(*(batch + (n, n)))
+        if torch.device(device).type == "mps":
+            # TODO: Fix lu_factor for MPS, because it does not work for all of
+            # these cases. So we resort to the CPU impl here and move the
+            # outputs back to MPS.
+            LU, pivots = (x.to(device) for x in torch.linalg.lu_factor(A.cpu()))
+        else:
+            LU, pivots = torch.linalg.lu_factor(A)
+
+        B = make_b(batch + (n, rhs))
+
+        grads = (False,) if not requires_grad else (True, False)
+        # we try all possible combinations of requires_grad for each input
+        for LU_grad, B_grad in product(grads, grads):
+            # when requires_grad == True, at least one input has to have requires_grad enabled
+            if requires_grad and not LU_grad and not B_grad:
+                continue
+
+            if is_linalg_lu_solve:
+                for adjoint, left in product((True, False), repeat=2):
+                    yield SampleInput(
+                        clone(LU, LU_grad),
+                        args=(pivots, clone(B if left else B.mT, B_grad)),
+                        kwargs=dict(adjoint=adjoint, left=left),
+                    )
+            else:
+                yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots))
+
+
+def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
+    # Each test case consists of the sizes in the chain of multiplications
+    # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
+    test_cases = [
+        [1, 2, 1],
+        [2, 0, 2],
+        [0, 2, 2],
+        [2, 2, 2, 2],
+        [2, 3, 4, 5],
+        [5, 4, 0, 2],
+        [2, 4, 3, 5, 3, 2],
+    ]
+
+    for sizes in test_cases:
+        tensors = []
+        for size in itertools.pairwise(sizes):
+            t = make_tensor(
+                size, dtype=dtype, device=device, requires_grad=requires_grad
+            )
+            tensors.append(t)
+        yield SampleInput(tensors)
+
+
+def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs):
+    low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    sizes = ((2, 2), (2, 3, 2))
+    if dtype in low_precision_dtypes:
+        # svdvals not supported for low precision dtypes
+        ords = ("fro", inf, -inf, 1, -1)
+    else:
+        ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2)
+    dims = ((-2, -1), (-1, 0))
+
+    for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]):
+        yield SampleInput(make_arg(size), args=(ord, dim, keepdim))
+
+
+def sample_inputs_linalg_norm(
+    op_info, device, dtype, requires_grad, *, variant=None, **kwargs
+):
+    if variant is not None and variant != "subgradient_at_zero":
+        raise ValueError(
+            f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
+        )
+
+    test_sizes = [
+        (S,),
+        (0,),
+        (S, S),
+        (0, 0),
+        (S, 0),
+        (0, S),
+        (S, S, S),
+        (0, S, S),
+        (S, 0, S),
+        (0, 0, 0),
+    ]
+
+    vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf)
+    if dtype in {torch.float16, torch.bfloat16, torch.complex32}:
+        # svdvals not supported for low precision dtypes
+        matrix_ords = ("fro", inf, -inf, 1, -1)
+    else:
+        matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2)
+
+    make_arg = partial(
+        make_tensor,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+        low=None,
+        high=None,
+    )
+
+    for test_size in test_sizes:
+        is_vector_norm = len(test_size) == 1
+        is_matrix_norm = len(test_size) == 2
+
+        # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
+        is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0)
+
+        for keepdim in [False, True]:
+            if variant != "subgradient_at_zero" and is_valid_for_p2:
+                yield SampleInput(make_arg(test_size), keepdim=keepdim)
+
+            if not (is_vector_norm or is_matrix_norm):
+                continue
+
+            ords = vector_ords if is_vector_norm else matrix_ords
+
+            for ord in ords:
+                if is_vector_norm and test_size[-1] == 0:
+                    if ord == np.inf or (ord is not None and ord < 0):
+                        # RuntimeError: linalg.vector_norm cannot compute the
+                        # {ord} norm on an empty tensor because the operation
+                        # does not have an identity
+                        continue
+                elif is_matrix_norm:
+                    dims_to_check = {
+                        None: (0,),
+                        -1: (1,),
+                        -2: (0, 1),
+                        -np.inf: (0,),
+                    }.get(ord, ())
+
+                    if any(test_size[d] == 0 for d in dims_to_check):
+                        # IndexError: amax(): Expected reduction dim {dim} to
+                        # have non-zero size.
+                        continue
+
+                    no_grad_dims_to_check = {
+                        np.inf: (0,),
+                        2: (0, 1),
+                        1: (1,),
+                    }.get(ord, ())
+
+                    if (
+                        any(test_size[d] == 0 for d in no_grad_dims_to_check)
+                        and requires_grad
+                    ):
+                        continue
+
+                if variant == "subgradient_at_zero":
+                    yield SampleInput(
+                        torch.zeros(
+                            test_size,
+                            dtype=dtype,
+                            device=device,
+                            requires_grad=requires_grad,
+                        ),
+                        ord,
+                        keepdim=keepdim,
+                    )
+                else:
+                    yield SampleInput(make_arg(test_size), ord, keepdim=keepdim)
+
+                    if ord in ["nuc", "fro"]:
+                        yield SampleInput(
+                            make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1)
+                        )
+
+
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    batches = ((), (0,), (1,), (5,))
+    ns = (0, 1, 3, 5)
+    for b, n in product(batches, ns):
+        shape = b + (n,)
+        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
+        for i in range(len(shape)):
+            yield SampleInput(
+                make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i)
+            )
+
+
+def sample_inputs_linalg_invertible(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates invertible inputs for linear algebra ops
+    The input is generated as the itertools.product of 'batches' and 'ns'.
+    In total this function generates 8 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (1, 1) - 1x1 batch of matrices
+    'ns' gives 0x0 and 5x5 matrices.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    """
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 0]
+
+    for batch, n in product(batches, ns):
+        yield SampleInput(make_arg(*batch, n, n))
+
+
+def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function produces inputs for matrix rank that test
+    all possible combinations for atol and rtol
+    """
+
+    def make_tol_arg(kwarg_type, inp):
+        if kwarg_type == "none":
+            return None
+        if kwarg_type == "float":
+            return 1.0
+        assert kwarg_type == "tensor"
+        return torch.ones(inp.shape[:-2], device=device)
+
+    for tol_type in ["float", "tensor"]:
+        for atol_type, rtol_type in product(["none", tol_type], repeat=2):
+            if (
+                not atol_type and not rtol_type
+            ):  # default behavior, so skipped here so it's not tested 2 extra times
+                continue
+            for sample in sample_inputs_linalg_invertible(
+                op_info, device, dtype, requires_grad
+            ):
+                assert sample.kwargs == {}
+                sample.kwargs = {
+                    "atol": make_tol_arg(atol_type, sample.input),
+                    "rtol": make_tol_arg(rtol_type, sample.input),
+                }
+                yield sample
+
+    # default kwargs
+    yield from sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+
+
+def sample_inputs_linalg_pinv_singular(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to
+    test the backward method of `linalg_pinv`. That way we always preserve the rank of the
+    input no matter the perturbations applied to it by the gradcheck.
+    Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood.
+    """
+    batches = [(), (0,), (2,), (1, 1)]
+    # the size of at least 30 is required to cause failures for the previous implicit implementation
+    # of the pinv's backward method, albeit it is slow.
+    size = [0, 3, 50]
+
+    for batch, m, n in product(batches, size, size):
+        for k in range(min(3, m, n)):
+            # Note that by making the columns of `a` and `b` orthonormal we make sure that
+            # the product matrix `a @ b.t()` has condition number 1 when restricted to its image
+            a = (
+                torch.rand(*batch, m, k, device=device, dtype=dtype)
+                .qr()
+                .Q.requires_grad_(requires_grad)
+            )
+            b = (
+                torch.rand(*batch, n, k, device=device, dtype=dtype)
+                .qr()
+                .Q.requires_grad_(requires_grad)
+            )
+            yield SampleInput(a, args=(b,))
+
+
+def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # autograd is not supported for inputs with zero number of elements
+    shapes = (
+        (S, S),
+        (2, S, S),
+        (2, 1, S, S),
+    )
+
+    for shape in shapes:
+        yield SampleInput(make_arg(shape))
+
+
+def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    shapes = (
+        (),
+        (1,),
+        (S,),
+        (2, S),
+    )
+
+    for shape in shapes:
+        if len(shape) > 0 and shape[-1] > 1:
+            yield SampleInput(make_arg(shape))
+        n = shape[-1] if len(shape) > 0 else 1
+        for i in range(3):
+            # n-1, n, n+1
+            N = n + i - 1
+            if N < 2:
+                continue
+            yield SampleInput(make_arg(shape), kwargs=dict(N=N))
+
+
+def np_vander_batched(x, N=None):
+    # Wrapper around np.vander that supports batches of 1 dimension (enough for the tests)
+    if x.ndim == 0:
+        x = x[np.newaxis]
+    if x.ndim == 1:
+        y = np.vander(x, N=N, increasing=True)
+        return y
+    else:
+        if N is None:
+            N = x.shape[-1]
+        y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N))
+        return y
+
+
+def sample_inputs_linalg_cholesky_inverse(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    from torch.testing._internal.common_utils import random_well_conditioned_matrix
+
+    # Cholesky factorization is for positive-definite matrices
+    single_well_conditioned_matrix = random_well_conditioned_matrix(
+        S, S, dtype=dtype, device=device
+    )
+    batch_well_conditioned_matrices = random_well_conditioned_matrix(
+        2, S, S, dtype=dtype, device=device
+    )
+    single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH
+    batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH
+
+    inputs = (
+        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
+        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
+        single_pd,
+        batch_pd,
+    )
+    test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs)
+    for l in test_cases:
+        # generated lower-triangular samples
+        l.requires_grad = requires_grad
+        yield SampleInput(l)  # upper=False by default
+        yield SampleInput(
+            l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False)
+        )
+
+        # generate upper-triangular inputs
+        u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad)
+        yield SampleInput(u, kwargs=dict(upper=True))
+
+
+def sample_inputs_linalg_ldl_factor(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    from torch.testing._internal.common_utils import (
+        random_hermitian_pd_matrix,
+        random_symmetric_pd_matrix,
+    )
+
+    device = torch.device(device)
+
+    # Symmetric inputs
+    yield SampleInput(
+        random_symmetric_pd_matrix(S, dtype=dtype, device=device),
+        kwargs=dict(hermitian=False),
+    )  # single matrix
+    yield SampleInput(
+        random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device),
+        kwargs=dict(hermitian=False),
+    )  # batch of matrices
+    yield SampleInput(
+        torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False)
+    )  # 0x0 matrix
+    yield SampleInput(
+        torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False)
+    )  # zero batch of matrices
+
+    # Hermitian inputs
+    # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
+    magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4)
+    if dtype.is_complex and (device.type == "cpu" or magma_254_available):
+        yield SampleInput(
+            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
+            kwargs=dict(hermitian=True),
+        )  # single matrix
+        yield SampleInput(
+            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
+            kwargs=dict(hermitian=True),
+        )  # batch of matrices
+
+
+def sample_inputs_linalg_ldl_solve(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    # Generate LDL factors of symmetric (and Hermitian on CPU) matrices
+    from torch.testing._internal.common_utils import (
+        random_hermitian_pd_matrix,
+        random_symmetric_pd_matrix,
+    )
+
+    device = torch.device(device)
+    symmetric_inputs = (
+        random_symmetric_pd_matrix(S, dtype=dtype, device=device),  # single matrix
+        random_symmetric_pd_matrix(
+            S, 2, dtype=dtype, device=device
+        ),  # batch of matrices
+        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
+        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
+    )
+    hermitian_inputs = (
+        (
+            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
+            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
+        )
+        if device.type == "cpu" and dtype.is_complex
+        else ()
+    )
+    test_cases1 = (
+        torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs
+    )
+    test_cases2 = (
+        torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs
+    )
+
+    # Symmetric case
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    for test_case in test_cases1:
+        factors, pivots, _ = test_case
+        factors.requires_grad = requires_grad
+        for B_batch_shape in ((), factors.shape[:-2]):
+            B = make_arg((*B_batch_shape, factors.shape[-1], S))
+            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False))
+            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
+            yield SampleInput(
+                clone_factors, args=(pivots, B), kwargs=dict(hermitian=False)
+            )
+
+    # Hermitian case
+    for test_case in test_cases2:
+        factors, pivots, _ = test_case
+        factors.requires_grad = requires_grad
+        for B_batch_shape in ((), factors.shape[:-2]):
+            B = make_arg((*B_batch_shape, factors.shape[-1], S))
+            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True))
+            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
+            yield SampleInput(
+                clone_factors, args=(pivots, B), kwargs=dict(hermitian=True)
+            )
+
+
+def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
+    from torch.testing._internal.common_utils import random_well_conditioned_matrix
+
+    device = torch.device(device)
+
+    drivers: tuple[str, ...]
+    if device.type == "cuda":
+        drivers = ("gels",)
+    else:
+        drivers = ("gels", "gelsy", "gelss", "gelsd")
+
+    # we generate matrices of shape (..., n + delta, n)
+    deltas: tuple[int, ...]
+    if device.type == "cpu" or has_cusolver():
+        deltas = (-1, 0, +1)
+    # only square systems if Cusolver is not available
+    # because we solve a lstsq problem with a transposed matrix in the backward
+    else:
+        deltas = (0,)
+
+    for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
+        shape = batch + (3 + delta, 3)
+        a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
+        a.requires_grad_(requires_grad)
+        b = make_tensor(
+            shape,
+            dtype=dtype,
+            device=device,
+            low=None,
+            high=None,
+            requires_grad=requires_grad,
+        )
+        yield SampleInput(a, b, driver=driver)
+
+
+def error_inputs_lstsq(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(
+        SampleInput(zero_d, args=(zero_d,)),
+        error_type=RuntimeError,
+        error_regex="at least 2 dimensions",
+    )
+
+
+def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(
+        SampleInput(zero_d, args=(zero_d, None)),
+        error_type=RuntimeError,
+        error_regex="at least 2 dimensions",
+    )
+
+
+def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # Shapes for 2D Tensors
+    shapes_2d = ((S, S), (3, 5), (5, 3))
+
+    # Shapes for 3D Tensors
+    shapes_3d = ((S, S, S),)
+
+    kwargs_2d = ({}, dict(offset=2), dict(offset=2), dict(offset=1))
+    kwargs_3d = (
+        dict(offset=1, dim1=1, dim2=2),
+        dict(offset=2, dim1=0, dim2=1),
+        dict(offset=-2, dim1=0, dim2=1),
+    )
+
+    for shape, kwarg in chain(
+        product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)
+    ):
+        yield SampleInput(make_arg(shape), kwargs=kwarg)
+
+
+def error_inputs_diagonal_diag_embed(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    shapes1d = (0, 1, (0,), (1,))
+    shapes2d = ((M, L),)
+    shapes3d = ((M, S, L),)
+
+    kwargs1d = {}
+
+    kwargs2d = (
+        # dim1 == dim2 is not allowed
+        dict(dim1=1, dim2=1),
+        # out of bounds dims are not allowed
+        dict(dim1=10000),
+        dict(dim2=10000),
+    )
+
+    kwargs3d = kwargs2d
+
+    samples1d = product(shapes1d, kwargs1d)
+    samples2d = product(shapes2d, kwargs2d)
+    samples3d = product(shapes3d, kwargs3d)
+
+    for shape, kwargs in chain(samples1d, samples2d, samples3d):
+        arg = make_arg(shape)
+        sample = SampleInput(input=arg, kwargs=kwargs)
+
+        dim1 = kwargs.get("dim1")
+        dim2 = kwargs.get("dim2")
+
+        if "diagonal" in op_info.name:
+            num_dim = arg.dim()
+        elif op_info.name in ("diag_embed", "_refs.diag_embed"):
+            # these are valid inputs for diag_embed
+            if shape in ((0,), (1,)):
+                continue
+            num_dim = arg.dim() + 1
+        else:
+            raise RuntimeError("should be unreachable")
+
+        bound1 = -num_dim
+        bound2 = num_dim - 1
+        dim_range = range(bound1, bound2 + 1)
+        dim1_cond = dim1 and dim1 not in dim_range
+        dim2_cond = dim2 and dim2 not in dim_range
+
+        if dim1 == dim2:
+            err = f"diagonal dimensions cannot be identical {dim1}, {dim2}"
+            yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+        elif dim1_cond or dim2_cond:
+            err_dim = dim1 if dim1_cond else dim2
+            err = (
+                r"Dimension out of range \(expected to be in range of "
+                rf"\[{bound1}, {bound2}\], but got {err_dim}\)"
+            )
+            yield ErrorInput(sample, error_regex=err, error_type=IndexError)
+        else:
+            raise RuntimeError("should be unreachable")
+
+
+def sample_inputs_linalg_cholesky(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates always positive-definite input for torch.linalg.cholesky using
+    random_hermitian_pd_matrix.
+    The input is generated as the itertools.product of 'batches' and 'ns'.
+    In total this function generates 8 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (1, 1) - 1x1 batch of matrices
+    'ns' gives 0x0 and 5x5 matrices.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    """
+    from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 0]
+    for batch, n, upper in product(batches, ns, [True, False]):
+        a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
+        a.requires_grad = requires_grad
+        yield SampleInput(a, upper=upper)
+
+
+def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.eig
+    """
+
+    def out_fn(output):
+        return output[0], abs(output[1])
+
+    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+    for sample in samples:
+        sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument.
+    """
+
+    def out_fn(output):
+        if isinstance(output, tuple):
+            # eigh function
+            return output[0], abs(output[1])
+        else:
+            # eigvalsh function
+            return output
+
+    # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input
+    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+    for sample in samples:
+        # Note: we cannot use np.random.choice here as TorchDynamo
+        # does not support tensors of strings.
+        sample.kwargs = {"UPLO": random.choice(["L", "U"])}
+        sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.pinv with hermitian=False keyword argument.
+    """
+    for o in sample_inputs_linalg_invertible(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        real_dtype = o.input.real.dtype if dtype.is_complex else dtype
+        # requires_grad path for rtol tensor is not implemented
+        for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)):
+            o = clone_sample(o)
+            o.kwargs = {"rtol": rtol}
+            yield o
+
+
+def sample_inputs_linalg_pinv_hermitian(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates input for torch.linalg.pinv with hermitian=True keyword argument.
+    """
+    for o in sample_inputs_linalg_invertible(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        o.kwargs = {"hermitian": True}
+        yield o
+
+
+def sample_inputs_linalg_solve(
+    op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs
+):
+    """
+    This function generates always solvable input for torch.linalg.solve
+    We sample a fullrank square matrix (i.e. invertible) A
+    The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'.
+    The second input is generated as the product of 'batches', 'ns' and 'nrhs'.
+    In total this function generates 18 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices.
+    'ns' gives 0x0 and 5x5 matrices.
+    and 'nrhs' controls the number of vectors to solve for:
+        () - using 1 as the number of vectors implicitly
+        (1,) - same as () but explicit
+        (3,) - solve for 3 vectors.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    'vector_rhs_allowed' controls whether to include nrhs = () to the list of SampleInputs.
+    torch.solve / triangular_solve / cholesky_solve (opposed to torch.linalg.solve) do not allow
+    1D tensors (vectors) as the right-hand-side.
+    Once torch.solve / triangular_solve / cholesky_solve and its testing are removed,
+    'vector_rhs_allowed' may be removed here as well.
+    """
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    make_b = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (2, 2)]
+    ns = [5, 0]
+    if vector_rhs_allowed:
+        nrhs = [(), (1,), (3,)]
+    else:
+        nrhs = [(1,), (3,)]
+
+    for n, batch, rhs in product(ns, batches, nrhs):
+        yield SampleInput(make_a(*batch, n, n), args=(make_b(batch + (n,) + rhs),))
+
+
+def sample_inputs_linalg_solve_triangular(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    make_arg = partial(make_tensor, dtype=dtype, device=device)
+    bs = (1, 2, 0)
+    ns = (3, 0)
+    ks = (1, 3, 0)
+
+    for b, n, k, (left, upper, uni) in product(
+        bs, ns, ks, product((True, False), repeat=3)
+    ):
+        if b == 1:
+            A = make_arg((n, n)) if left else make_arg((k, k))
+            B = make_arg((n, k))
+        else:
+            A = make_arg((b, n, n)) if left else make_arg((b, k, k))
+            B = make_arg((b, n, k))
+        if uni:
+            # Not really necessary, but writing it for consistency
+            A.diagonal(0, -2, -1).fill_(1.0)
+        else:
+            d = A.diagonal(0, -2, -1)
+            d[d.abs() < 1e-6] = 1.0
+        if upper:
+            A.triu_()
+        else:
+            A.tril_()
+        kwargs = {"upper": upper, "left": left, "unitriangular": uni}
+        if requires_grad:
+            for grad_A, grad_B in product((True, False), repeat=2):
+                # Either A or B needs to have a gradient
+                if not grad_A and not grad_B:
+                    continue
+                yield SampleInput(
+                    A.clone().requires_grad_(grad_A),
+                    args=(B.clone().requires_grad_(grad_B),),
+                    kwargs=kwargs,
+                )
+        else:
+            yield SampleInput(A, args=(B,), kwargs=kwargs)
+
+
+def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates always solvable input for legacy solve functions
+    (the ones that are not in torch.linalg module).
+    The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation
+    should have b.ndim >= 2, vectors are not allowed.
+    Also the arguments order is swapped.
+    """
+    out = sample_inputs_linalg_solve(
+        op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
+    )
+
+    def out_fn(output):
+        return output[0]
+
+    # Reverses tensor order
+    for sample in out:
+        sample.input, sample.args = sample.args[0], (sample.input,)
+        if op_info.name == "solve":
+            sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs):
+    full_rank = op_info.name == "linalg.lu_factor"
+    make_fn = (
+        make_tensor
+        if not full_rank
+        else make_fullrank_matrices_with_distinct_singular_values
+    )
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    def out_fn(output):
+        if op_info.name == "linalg.lu":
+            return output[1], output[2]
+        else:
+            return output
+
+    batch_shapes = ((), (3,), (3, 3), (0,))
+    # pivot=False only supported in CUDA
+    pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
+    deltas = (-2, -1, 0, +1, +2)
+    for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
+        shape = batch_shape + (S + delta, S)
+        # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple!
+        A = make_arg(shape) if not full_rank else make_arg(*shape)
+        yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn)
+
+
+def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 2, 0]
+
+    for batch, m, n in product(batches, ns, ns):
+        yield SampleInput(make_arg(batch + (m, n)))
+
+
+def sample_inputs_linalg_qr_geqrf(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    # QR is just well defined when the matrix is full rank
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 2, 0]
+
+    for batch, (m, n) in product(batches, product(ns, ns)):
+        shape = batch + (m, n)
+        yield SampleInput(make_arg(*shape))
+
+
+def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs):
+    a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+    # Zero-dim tensors are not supported in NumPy, so we skip them for now.
+    # NumPy is used in reference check tests.
+    # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix.
+    # a_shapes += [(0, 0, 1, 2, 3, 0)]
+    dimss = [None, (0, 2)]
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    for a_shape, dims in itertools.product(a_shapes, dimss):
+        a = make_arg(a_shape)
+        b = make_arg(a_shape[:2])
+        yield SampleInput(a, b, dims=dims)
+
+
+def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = make_fullrank_matrices_with_distinct_singular_values
+
+    def make_input():
+        return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # lhs / rhs shape can have any number of dimensions as long as their product equals 12
+    shapes = [
+        ((2, 2, 3), (12, 1)),
+        ((4, 3), (6, 1, 2)),
+    ]
+
+    for shape_lhs, shape_rhs in shapes:
+        inp = make_input().reshape(*shape_lhs, *shape_rhs).detach()
+        inp.requires_grad_(requires_grad)
+        yield SampleInput(inp, ind=len(shape_lhs))
+
+
+op_db: list[OpInfo] = [
+    OpInfo(
+        "linalg.cross",
+        ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
+        op=torch.linalg.cross,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        aten_name="linalg_cross",
+        sample_inputs_func=sample_inputs_cross,
+        error_inputs_func=error_inputs_cross,
+        supports_out=True,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.det",
+        aten_name="linalg_det",
+        op=torch.linalg.det,
+        aliases=("det",),
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+        check_batched_gradgrad=False,
+    ),
+    OpInfo(
+        "linalg.diagonal",
+        aten_name="linalg_diagonal",
+        aten_backward_name="diagonal_backward",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.float16, torch.chalf
+        ),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_diagonal_diag_embed,
+        error_inputs_func=error_inputs_diagonal_diag_embed,
+    ),
+    OpInfo(
+        "linalg.cholesky",
+        aten_name="linalg_cholesky",
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_cholesky,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.cholesky_ex",
+        aten_name="linalg_cholesky_ex",
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_cholesky,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.vecdot",
+        aten_name="linalg_vecdot",
+        ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_linalg_vecdot,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.cond",
+        aten_name="linalg_cond",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_cond,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eig",
+        aten_name="linalg_eig",
+        op=torch.linalg.eig,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eig,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # AssertionError: Scalars are not equal!
+            DecorateInfo(
+                unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
+    ),
+    OpInfo(
+        "linalg.eigvals",
+        aten_name="linalg_eigvals",
+        op=torch.linalg.eigvals,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eigh",
+        aten_name="linalg_eigh",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eigh,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eigvalsh",
+        aten_name="linalg_eigvalsh",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eigh,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # Pre-existing condition; Needs to be fixed
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.householder_product",
+        aten_name="linalg_householder_product",
+        op=torch.linalg.householder_product,
+        aliases=("orgqr",),
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        # TODO: backward uses in-place operations that vmap doesn't like
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_householder_product,
+        decorators=[
+            skipCUDAIfNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)})
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped! Flaky"),
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cpu",
+                dtypes=(torch.complex128,),
+            ),
+            skipCUDAIfRocm,  # regression in ROCm 6.4
+        ],
+    ),
+    OpInfo(
+        "linalg.ldl_factor",
+        aten_name="linalg_ldl_factor",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_factor,
+        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.ldl_factor_ex",
+        aten_name="linalg_ldl_factor_ex",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_factor,
+        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.ldl_solve",
+        aten_name="linalg_ldl_solve",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_solve,
+        decorators=[
+            skipCUDAIfNoCusolver,
+            skipCUDAIfRocm,
+            skipCPUIfNoLapack,
+        ],
+    ),
+    OpInfo(
+        "linalg.lstsq",
+        aten_name="linalg_lstsq",
+        dtypes=floating_and_complex_types(),
+        supports_out=True,
+        sample_inputs_func=sample_inputs_linalg_lstsq,
+        error_inputs_func=error_inputs_lstsq,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # we skip gradient checks for this suite as they are tested in
+            # variant_test_name='grad_oriented'
+            DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"),
+            # The values for attribute 'shape' do not match
+            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lstsq",
+        aten_name="linalg_lstsq",
+        variant_test_name="grad_oriented",
+        # gradchecks for forward AD fails with full output tuple
+        # works when taking [:2], which is (solution, residuals)
+        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[:2],
+        supports_out=False,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_lstsq,
+        error_inputs_func=error_inputs_lstsq_grad_oriented,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_autograd=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # tests do not work with passing lambda for op
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestOperatorSignatures",
+                "test_get_torch_func_signature_exhaustive",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_power",
+        aliases=("matrix_power",),
+        aten_name="linalg_matrix_power",
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_inplace_autograd=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-5, rtol=2e-6)}),
+                "TestConsistency",
+                "test_output_grad_match",
+                device_type="mps",
+            ),
+        ),
+        sample_inputs_func=sample_inputs_linalg_matrix_power,
+    ),
+    OpInfo(
+        "linalg.multi_dot",
+        # Need this lambda because gradcheck does not work with TensorList inputs
+        aten_name="linalg_multi_dot",
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        supports_inplace_autograd=False,
+        # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # https://github.com/pytorch/pytorch/issues/66357
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_multi_dot,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/67470
+            DecorateInfo(
+                unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples"
+            ),
+            # Fails on XLA.
+            # AssertionError: False is not true : Tensors failed to compare as equal!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestOpInfo",
+                device_type="xla",
+                dtypes=(torch.long,),
+            ),
+            # https://github.com/pytorch/pytorch/issues/71774
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNNCOpInfo",
+                "test_nnc_correctness",
+                device_type="cpu",
+                dtypes=(torch.long,),
+            ),
+        ),
+    ),
+    # NB: linalg.norm has two variants so that different skips can be used for different sample inputs
+    OpInfo(
+        "linalg.norm",
+        aten_name="linalg_norm",
+        op=torch.linalg.norm,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=sample_inputs_linalg_norm,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.norm",
+        op=torch.linalg.norm,
+        variant_test_name="subgradients_at_zero",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=partial(
+            sample_inputs_linalg_norm, variant="subgradient_at_zero"
+        ),
+        aten_name="linalg_norm",
+        supports_forward_ad=True,
+        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
+        # Could not allocate memory to change Tensor SizesAndStrides!
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # [NEW] Skips specifically for sample inputs at zero
+            # norm's vjp/jvp are not well-conditioned near zero
+            DecorateInfo(
+                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_norm",
+        aten_name="linalg_matrix_norm",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        check_batched_gradgrad=False,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=sample_inputs_linalg_matrix_norm,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.qr",
+        aten_name="linalg_qr",
+        op=torch.linalg.qr,
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # In-place ops
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_linalg_qr_geqrf,
+        decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.slogdet",
+        aten_name="linalg_slogdet",
+        op=torch.linalg.slogdet,
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.vander",
+        aten_name="linalg_vander",
+        ref=np_vander_batched,
+        op=torch.linalg.vander,
+        dtypes=all_types_and_complex(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_linalg_vander,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    ReductionOpInfo(
+        "linalg.vector_norm",
+        op=torch.linalg.vector_norm,
+        identity=0,
+        nan_policy="propagate",
+        supports_multiple_dims=True,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
+        # got: Could not allocate memory to change Tensor SizesAndStrides!
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        generate_args_kwargs=sample_kwargs_vector_norm,
+        aten_name="linalg_vector_norm",
+    ),
+    OpInfo(
+        "linalg.lu_factor",
+        aten_name="linalg_lu_factor",
+        op=torch.linalg.lu_factor,
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu_factor_ex",
+        aten_name="linalg_lu_factor_ex",
+        op=torch.linalg.lu_factor_ex,
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu",
+        aten_name="linalg_lu",
+        op=torch.linalg.lu,
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu_solve",
+        op=torch.linalg.lu_solve,
+        aten_name="linalg_lu_solve",
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_lu_solve,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Tests different backward paths"),
+                "TestCommon",
+                "test_floating_inputs_are_differentiable",
+            ),
+        ),
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+    ),
+    OpInfo(
+        "linalg.inv",
+        aten_name="linalg_inv",
+        op=torch.linalg.inv,
+        aliases=("inverse",),
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.inv_ex",
+        aten_name="linalg_inv_ex",
+        op=torch.linalg.inv_ex,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve",
+        aten_name="linalg_solve",
+        op=torch.linalg.solve,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve_ex",
+        aten_name="linalg_solve_ex",
+        op=torch.linalg.solve_ex,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve_triangular",
+        aten_name="linalg_solve_triangular",
+        op=torch.linalg.solve_triangular,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve_triangular,
+        supports_fwgrad_bwgrad=True,
+        skips=(skipCPUIfNoLapack,),
+        # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
+        supports_forward_ad=True,
+    ),
+    OpInfo(
+        "linalg.matrix_rank",
+        aten_name="linalg_matrix_rank",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_matrix_rank,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            # jit doesn't accept tensor inputs for matrix rank
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=[torch.complex64, torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_rank",
+        aten_name="linalg_matrix_rank",
+        variant_test_name="hermitian",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        op=torch.linalg.pinv,
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_pinv,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # errors with "leaked XXXX bytes CUDA memory on device 0"
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        variant_test_name="singular",
+        # pinv is Frechet-differentiable in a rank-preserving neighborhood,
+        # so we feed inputs that are the products of two full-rank factors,
+        # to avoid any rank changes caused by the perturbations in the gradcheck
+        op=lambda a, b: torch.linalg.pinv(a @ b.mT),
+        dtypes=floating_and_complex_types(),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_pinv_singular,
+        # Only large tensors show issues with implicit backward used prior to
+        # explicit backward implementation.
+        decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # CUDA runs out of memory
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cuda",
+                dtypes=[torch.cdouble],
+            ),
+            # This test takes almost 2 hours to run!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestBwdGradients",
+                "test_fn_gradgrad",
+                device_type="cuda",
+                dtypes=[torch.cdouble],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        variant_test_name="hermitian",
+        dtypes=floating_and_complex_types(),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cuda",
+            ),
+            # This test is flaky under slow gradcheck, likely due to rounding issues
+            DecorateInfo(
+                skipIfSlowGradcheckEnv,
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.svd",
+        op=torch.linalg.svd,
+        aten_name="linalg_svd",
+        decomp_aten_name="_linalg_svd",
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        # We're using at::allclose, which does not have a batching rule
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_svd,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.svdvals",
+        op=torch.linalg.svdvals,
+        aten_name="linalg_svdvals",
+        decomp_aten_name="_linalg_svd",
+        dtypes=floating_and_complex_types(),
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        # We're using at::allclose, which does not have a batching rule
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_linalg_svdvals,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.tensorinv",
+        ref=np.linalg.tensorinv,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_tensorinv,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.tensorsolve",
+        ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims),
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_tensorsolve,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=2e-04, rtol=3e-06)}),
+                "TestConsistency",
+                "test_output_match",
+                device_type="mps",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    #
+    # torch.linalg
+    #
+    PythonRefInfo(
+        "_refs.linalg.cross",
+        torch_opinfo_name="linalg.cross",
+        supports_out=True,
+        op_db=op_db,
+        skips=(
+            # TODO: is this really needed?
+            DecorateInfo(
+                unittest.expectedFailure, "TestCommon", "test_python_ref_errors"
+            ),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.linalg.diagonal",
+        torch_opinfo_name="linalg.diagonal",
+        supports_out=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.vecdot",
+        torch_opinfo_name="linalg.vecdot",
+        op_db=op_db,
+    ),
+    ReductionPythonRefInfo(
+        "_refs.linalg.vector_norm",
+        torch_opinfo_name="linalg.vector_norm",
+        supports_out=True,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.matrix_norm",
+        torch_opinfo_name="linalg.matrix_norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.norm",
+        torch_opinfo_name="linalg.norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.svd",
+        torch_opinfo_name="linalg.svd",
+        supports_out=True,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.svdvals",
+        torch_opinfo_name="linalg.svdvals",
+        supports_out=True,
+        op_db=op_db,
+    ),
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f58ad2d7fb890346622a68f7b743f06f4c0f894
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py
@@ -0,0 +1,1594 @@
+# mypy: ignore-errors
+
+import math
+from copy import copy
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import torch
+from torch.fx.experimental.symbolic_shapes import is_nested_int
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    ReductionOpInfo,
+    SampleInput,
+    UnaryUfuncInfo,
+)
+from torch.utils._pytree import tree_flatten, tree_map
+
+
+@dataclass
+class ExtraOpData:
+    """
+    Contains info on top of the typical OpInfo data that is useful for NJT test generation.
+
+    The process that converts the standard op_db -> an NJT-compatible op_db will attach this
+    data onto each associated OpInfo entry.
+    """
+
+    # Indicates whether the associated op is a view op
+    is_view: bool = False
+
+    # Specifies the names of any dim-related args that the op takes in. This is useful
+    # for NJT tests because there is often asymmetry across the supported set of dims for
+    # an op; it may make sense to operate over the batch dim but not the ragged dim, for
+    # example. The length of this list should match the number of relevant overloads.
+    # Each list item of the outer list should specify dim argnames. Ellipses should be used
+    # to indicate multi-dim support for a given overload.
+    #
+    # For example, squeeze() has both a dim and multi-dim overload, where the argname for
+    # each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
+    #
+    # If no overload of the op accepts dim-related args, this should be None.
+    dim_args: list[list[str]] = None
+
+    # Helper function to extract names of dim-related args.
+    # Returns: tuple of (single dim argname if available, dim list argname if available)
+    # If the op doesn't support dim-related args at all OR this op only has overloads
+    # with multiple dim args (e.g. transpose()), then this returns (None, None).
+    def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]:
+        if self.dim_args is None:
+            return (None, None)
+
+        # name for the dim arg that supports a single dim
+        single_dim_argname = None
+        # name for the dim arg that supports a list of dims
+        dimlist_argname = None
+        for overload in self.dim_args:
+            # only consider overloads with a single dim-related arg
+            if len(overload) != 1:
+                continue
+            if overload[0].endswith("..."):
+                dimlist_argname = overload[0].replace("...", "")
+                if single_dim_argname is None:
+                    single_dim_argname = dimlist_argname
+            else:
+                single_dim_argname = overload[0]
+        return (single_dim_argname, dimlist_argname)
+
+
+# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use
+# in test generation.
+extra_op_data = {
+    "_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]),
+    "_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]),
+    "all": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "argmax": ExtraOpData(dim_args=[["dim"]]),
+    "argmin": ExtraOpData(dim_args=[["dim"]]),
+    "amax": ExtraOpData(dim_args=[["dim..."]]),
+    "amin": ExtraOpData(dim_args=[["dim..."]]),
+    "any": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "argsort": ExtraOpData(dim_args=[["dim"]]),
+    "broadcast_to": ExtraOpData(is_view=True),
+    "cat": ExtraOpData(dim_args=[["dim"]]),
+    "chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "conj": ExtraOpData(is_view=True),
+    "contiguous": ExtraOpData(is_view=True),
+    "count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "cummax": ExtraOpData(dim_args=[["dim"]]),
+    "cummin": ExtraOpData(dim_args=[["dim"]]),
+    "cumprod": ExtraOpData(dim_args=[["dim"]]),
+    "cumsum": ExtraOpData(dim_args=[["dim"]]),
+    "cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]),
+    "diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
+    "diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diff": ExtraOpData(dim_args=[["dim"]]),
+    "expand": ExtraOpData(is_view=True),
+    "expand_as": ExtraOpData(is_view=True),
+    "fft.fft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.hfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.ifft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.ihfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.irfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.rfft": ExtraOpData(dim_args=[["dim"]]),
+    "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
+    "flip": ExtraOpData(dim_args=[["dims..."]]),
+    "gather": ExtraOpData(dim_args=[["dim"]]),
+    "hash_tensor": ExtraOpData(dim_args=[["dim..."]]),
+    "imag": ExtraOpData(is_view=True),
+    "index_add": ExtraOpData(dim_args=[["dim"]]),
+    "index_copy": ExtraOpData(dim_args=[["dim"]]),
+    "index_fill": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
+    "index_select": ExtraOpData(dim_args=[["dim"]]),
+    "kthvalue": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.cross": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
+    "linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]),
+    "linalg.vecdot": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]),
+    "log_softmax": ExtraOpData(dim_args=[["dim"]]),
+    "logcumsumexp": ExtraOpData(dim_args=[["dim"]]),
+    "masked.amax": ExtraOpData(dim_args=[["dim"]]),
+    "masked.amin": ExtraOpData(dim_args=[["dim"]]),
+    "masked.argmax": ExtraOpData(dim_args=[["dim"]]),
+    "masked.argmin": ExtraOpData(dim_args=[["dim"]]),
+    "masked.logsumexp": ExtraOpData(dim_args=[["dim"]]),
+    "masked.mean": ExtraOpData(dim_args=[["dim"]]),
+    "masked.norm": ExtraOpData(dim_args=[["dim"]]),
+    "masked.prod": ExtraOpData(dim_args=[["dim"]]),
+    "masked.std": ExtraOpData(dim_args=[["dim"]]),
+    "masked.sum": ExtraOpData(dim_args=[["dim"]]),
+    "masked.var": ExtraOpData(dim_args=[["dim"]]),
+    "max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
+    "median": ExtraOpData(dim_args=[["dim"]]),
+    "mean": ExtraOpData(dim_args=[["dim..."]]),
+    "min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
+    "mode": ExtraOpData(dim_args=[["dim"]]),
+    "movedim": ExtraOpData(
+        dim_args=[["source", "destination"], ["source...", "destination..."]]
+    ),
+    "nanmean": ExtraOpData(dim_args=[["dim..."]]),
+    "nanmedian": ExtraOpData(dim_args=[["dim"]]),
+    "nansum": ExtraOpData(dim_args=[["dim..."]]),
+    "narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "narrow_copy": ExtraOpData(dim_args=[["dim"]]),
+    "nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]),
+    "nn.functional.glu": ExtraOpData(dim_args=[["dim"]]),
+    "permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]),
+    "positive": ExtraOpData(is_view=True),
+    "prod": ExtraOpData(dim_args=[["dim"]]),
+    "ravel": ExtraOpData(is_view=True),
+    "real": ExtraOpData(is_view=True),
+    "renorm": ExtraOpData(dim_args=[["dim"]]),
+    "reshape": ExtraOpData(is_view=True),
+    "reshape_as": ExtraOpData(is_view=True),
+    "roll": ExtraOpData(dim_args=[["dims..."]]),
+    "rot90": ExtraOpData(dim_args=[["dims..."]]),
+    "scatter": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_add": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]),
+    "select": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "select_scatter": ExtraOpData(dim_args=[["dim"]]),
+    "slice": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "slice_scatter": ExtraOpData(dim_args=[["dim"]]),
+    "softmax": ExtraOpData(dim_args=[["dim"]]),
+    "sort": ExtraOpData(dim_args=[["dim"]]),
+    "split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]),
+    "squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]),
+    "squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "stack": ExtraOpData(dim_args=[["dim"]]),
+    "std": ExtraOpData(dim_args=[["dim..."]]),
+    "std.unbiased": ExtraOpData(dim_args=[["dim..."]]),
+    "sum": ExtraOpData(dim_args=[["dim..."]]),
+    "t": ExtraOpData(is_view=True),
+    "tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "tensordot": ExtraOpData(dim_args=[["dims..."]]),
+    "tile": ExtraOpData(dim_args=[["dims..."]]),
+    "topk": ExtraOpData(dim_args=[["dim"]]),
+    "transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]),
+    "transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]),
+    "trapezoid": ExtraOpData(dim_args=[["dim"]]),
+    "trapz": ExtraOpData(dim_args=[["dim"]]),
+    "unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]),
+    "unfold_copy": ExtraOpData(dim_args=[["dimension"]]),
+    "unsafe_chunk": ExtraOpData(dim_args=[["dim"]]),
+    "unsafe_split": ExtraOpData(dim_args=[["dim"]]),
+    "unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]),
+    "var": ExtraOpData(dim_args=[["dim..."]]),
+    "var.unbiased": ExtraOpData(dim_args=[["dim..."]]),
+    "view": ExtraOpData(is_view=True),
+    "view_as": ExtraOpData(is_view=True),
+    "view_as_complex": ExtraOpData(is_view=True),
+    "view_as_real": ExtraOpData(is_view=True),
+}
+
+
+# random integer used for sizes
+def _rnd():
+    return torch.randint(3, 8, ()).item()
+
+
+def _raggedness_matches(nt1, nt2):
+    return (
+        nt1.is_nested
+        and nt2.is_nested
+        and nt1._ragged_idx == nt2._ragged_idx
+        and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx]
+    )
+
+
+# Helper function to avoid reusing the exact same tensor / NJT across SampleInputs,
+# as this causes autograd problems.
+def _clone(t):
+    requires_grad = t.requires_grad
+    return t.detach().clone().requires_grad_(requires_grad)
+
+
+# Helper function to update a sample with new kwargs / name
+def _update_sample(sample, new_kwargs):
+    all_kwargs = dict(sample.kwargs)
+    all_kwargs.update(new_kwargs)
+    full_name = ", ".join([sample.name, *(f"{k}={v}" for (k, v) in new_kwargs.items())])
+    return SampleInput(
+        _clone(sample.input),
+        args=sample.args,
+        kwargs=all_kwargs,
+        name=full_name,
+    )
+
+
+# Generates a random NT.
+# dims should be something like [5, None, 10], with None indicating that a
+# random ragged structure should be used
+def random_nt_from_dims(
+    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
+):
+    sizes = [[d if d is not None else _rnd() for d in dims[1:]] for d in range(dims[0])]
+    return torch.nested.nested_tensor(
+        [torch.randn(*size) for size in sizes],
+        device=device,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+    )
+
+
+# Helper function to get a reasonable string representation of an NJT for use in
+# SampleInput names.
+def _describe_njt(njt) -> str:
+    contig_type = "_contig" if njt.is_contiguous() else "_noncontig"
+    if njt._lengths is not None and njt._offsets is not None:
+        contig_type += "_holes"
+    elif njt._ragged_idx != 1:
+        contig_type += "_transposed"
+
+    cached_data = "_without_seqlen_cache"
+    if njt._max_seqlen_tensor is not None:
+        cached_data = "_with_seqlen_cache"
+
+    return f"{njt.dim()}D{contig_type}{cached_data}"
+
+
+# Helper function to get a reasonable string representation of a given dim wrt an NJT.
+def _describe_dim(njt, dim):
+    if dim == 0:
+        return "batch_dim"
+    elif dim == njt._ragged_idx:
+        return "ragged_dim"
+    return "normal_dim"
+
+
+# Helper function for generating a comprehensive set of NJT sample inputs.
+def _sample_njts(device, dtype, requires_grad=False, dims=None):
+    if dims is None:
+        dims = [2, 3, 4]
+    if not isinstance(dims, (list, tuple)):
+        dims = [dims]
+
+    # contiguous NJTs
+    for dim in dims:
+        # with min / max seqlen cached
+        shape = (_rnd(), None, *[_rnd() for _ in range(dim - 2)])
+        nt = random_nt_from_dims(
+            shape,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            layout=torch.jagged,
+        )
+        yield nt
+
+        # without min / max seqlen cached
+        values = _clone(nt.values())
+        offsets = _clone(nt.offsets())
+        yield torch.nested.nested_tensor_from_jagged(values, offsets).requires_grad_(
+            requires_grad
+        )
+
+        # non-contiguous transposed NJT (not possible for 2D)
+        if dim > 2:
+            yield nt.transpose(-1, nt._ragged_idx)
+
+        # non-contiguous with holes NJT
+        values = _clone(nt.values())
+        offsets = _clone(nt.offsets())
+        # subtract 1 to cause holes
+        lengths = _clone(offsets.diff() - 1)
+        yield torch.nested.nested_tensor_from_jagged(
+            values=values,
+            offsets=offsets,
+            lengths=lengths,
+        ).requires_grad_(requires_grad)
+
+
+# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
+# This reference unbinds the input NJT and invokes the op on each of the components,
+# optionally wrapping the result in an NJT.
+def unbind_reference(op, sample, wrap_output_as_njt=True):
+    # first NJT in the arglist determines expected ragged structure
+    nt_inp = (
+        sample.input
+        if sample.input.is_nested
+        # TODO: look in kwargs too?
+        else next(a for a in sample.args if a.is_nested)
+    )
+
+    out_ref_components = []
+    for i in range(nt_inp.shape[0]):
+
+        def _slice_input(t, i=i, inp=nt_inp):
+            # any NJT with the same ragged structure as the input should
+            # be sliced to pass to the reference
+            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
+                return t[i]
+            # allow the SampleInput to tell us how to slice it for ref calculation
+            elif isinstance(t, torch.Tensor) and hasattr(t, "_batch_dim"):
+                bdim = t._batch_dim  # type: ignore[attr]
+                if t.shape[bdim] == 1:
+                    return t[0]
+                else:
+                    return t.select(bdim, i)
+            else:
+                return t
+
+        inp = _slice_input(sample.input)
+        args = tree_map(_slice_input, sample.args)
+        kwargs = tree_map(_slice_input, sample.kwargs)
+
+        # Handle indices in index_put
+        if "index_put" in op.full_name and "indices" in kwargs:
+            if len(kwargs["indices"]) > 1:
+                # If after unrolling we still have indices left, use them
+                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
+            else:
+                # If no indices are left, create them so they match the NJT implementation
+                sequence_put = kwargs["indices"][0].tolist()
+                if i in sequence_put:
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            list(range(inp.shape[0])),
+                            dtype=torch.int32,
+                            device=kwargs["indices"][0].device,
+                        )
+                    ]
+                else:
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            [], dtype=torch.int32, device=kwargs["indices"][0].device
+                        )
+                    ]
+
+        from torch.nested._internal.ops import _outer_to_inner_dim
+
+        # Need to adjust dims to apply on NJT component
+        if op._extra_op_data.dim_args is not None:
+            # get all possible dim-related argnames that could be encountered for this op
+            argnames = tree_map(
+                lambda a: a.replace("...", ""),
+                tree_flatten(op._extra_op_data.dim_args)[0],
+            )
+            # for all dim-related args present, convert from outer -> inner dim space
+            for argname in {a for a in argnames if a in kwargs}:
+                # allow the SampleInput to tell us how to canonicalize the dim kwargs
+                ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim()
+                kwargs[argname] = _outer_to_inner_dim(
+                    ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True
+                )
+
+        out_ref_component = op.op(inp, *args, **kwargs)
+        out_ref_components.append(out_ref_component)
+
+    if wrap_output_as_njt:
+        # handle list / tuple of outputs
+        if len(out_ref_components) > 0 and isinstance(
+            out_ref_components[0], (list, tuple)
+        ):
+            num_returns = len(out_ref_components[0])
+            # ensure we get the same number of returns for each invocation
+            assert all(len(o) == num_returns for o in out_ref_components)
+            # construct NJTs from same index returns from each invocation
+            njt_returns = [
+                torch.nested.as_nested_tensor(
+                    [o[r] for o in out_ref_components], layout=torch.jagged
+                )
+                for r in range(num_returns)
+            ]
+            return type(out_ref_components[0])(njt_returns)
+        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)
+
+    return out_ref_components
+
+
+# Computes the reference value for a non-reduction unary op with dim-wise application.
+def unary_dimwise_reference(op, sample, batchwise_reference=None):
+    # extract info about the dim args this op supports
+    assert op._extra_op_data.dim_args is not None
+    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
+    # only support a single non-list dim arg for now
+    assert dimlist_argname is None
+    assert single_dim_argname is not None
+    if sample.kwargs[single_dim_argname] == 0:
+        # unbind reference won't work for batch-wise operation; handle this case here
+        assert batchwise_reference is not None
+        return batchwise_reference(op, sample)
+    return unbind_reference(op, sample)
+
+
+# Computes the reference value for a reduction op.
+def reduction_reference(op, sample):
+    assert sample.input.is_nested
+
+    # extract info about the dim args this op supports
+    assert op._extra_op_data.dim_args is not None
+    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+
+    dim = sample.kwargs.get(
+        dimlist_argname, sample.kwargs.get(single_dim_argname, None)
+    )
+    keepdim = sample.kwargs.get("keepdim", False)
+    assert dim != 0, "reductions over just the batch dim are not supported"
+    if isinstance(dim, (tuple, list)):
+        reduce_on_ragged = sample.input._ragged_idx in dim
+        reduce_on_batch = 0 in dim
+    else:
+        reduce_on_ragged = sample.input._ragged_idx == dim
+        reduce_on_batch = dim == 0
+
+    if dim is None:
+        # calculate reference value by running reduction on values buffer
+        return op.op(sample.input.values(), *sample.args, **sample.kwargs)
+
+    if reduce_on_ragged and reduce_on_batch:
+        # run reference directly on buffer with dims converted to inner space
+        from torch.nested._internal.ops import _outer_to_inner_dim
+
+        ref_kwargs = dict(sample.kwargs)
+        assert dimlist_argname is not None
+        ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
+            sample.input.dim(), dim, sample.input._ragged_idx, canonicalize=True
+        )
+        out = op.op(sample.input.values(), *sample.args, **ref_kwargs)
+        if keepdim:
+            if isinstance(out, (tuple, list)):
+                # some ops return multiple things; unsqueeze all of them
+                out = type(out)(o.unsqueeze(0) for o in out)
+            else:
+                out = out.unsqueeze(0)
+        return out
+
+    if reduce_on_ragged and not reduce_on_batch:
+        # calculate reference value by running an unbind reference and stacking
+        out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False)
+        if len(out_ref_components) > 0 and isinstance(
+            out_ref_components[0], (tuple, list)
+        ):
+            # some ops return multiple things; stack all of them
+            num_returns = len(out_ref_components[0])
+            # ensure we get the same number of returns for each invocation
+            assert all(len(o) == num_returns for o in out_ref_components)
+            # stack same index returns from each invocation
+            stacked_returns = [
+                torch.stack([o[r] for o in out_ref_components], dim=0)
+                for r in range(num_returns)
+            ]
+            return type(out_ref_components[0])(stacked_returns)
+        return torch.stack(out_ref_components, dim=0)
+
+    # unbind reference works for other reductions
+    return unbind_reference(op, sample)
+
+
+def sample_inputs_elementwise_njt_unary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        yield SampleInput(njt, kwargs=dict(op_kwargs), name=_describe_njt(njt))
+
+
+def sample_inputs_elementwise_njt_binary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    for njt1 in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        njt_desc = _describe_njt(njt1)
+        njt2 = torch.randn_like(njt1)
+        yield SampleInput(
+            _clone(njt1),
+            args=(njt2,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, NT)",
+        )
+
+        # broadcasting case: (B, j0, ...) with (B, 1, ...)
+        dense_shape = list(njt1.shape)
+        dense_shape[njt1._ragged_idx] = 1
+        t = torch.randn(
+            dense_shape,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        t2 = _clone(t)
+        # used for slicing in unbind_reference()
+        t._batch_dim = 0
+        t2._batch_dim = 0
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(t,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting 1 over ragged",
+        )
+        # (T, NT)
+        yield SampleInput(
+            t2,
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting 1 over ragged",
+        )
+
+        # broadcasting case: (B, j0, ...) with (1, 1...)
+        t = torch.randn(
+            [1 for _ in range(njt1.dim())],
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        t2 = _clone(t)
+        # used for slicing in unbind_reference()
+        t._batch_dim = 0
+        t2._batch_dim = 0
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(t,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting all 1s",
+        )
+        # (T, NT)
+        yield SampleInput(
+            t2,
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting all 1s",
+        )
+
+        # broadcasting case: (B, j0, ...) with (...)
+        if njt1.dim() > njt1._ragged_idx + 1:
+            t = torch.randn(
+                njt1.shape[njt1._ragged_idx + 1 :],
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+            # (NT, T)
+            yield SampleInput(
+                _clone(njt1),
+                args=(_clone(t),),
+                kwargs=dict(op_kwargs),
+                name=f"{njt_desc}: (NT, T) broadcasting normal dims",
+            )
+            # (T, NT)
+            yield SampleInput(
+                _clone(t),
+                args=(_clone(njt1),),
+                kwargs=dict(op_kwargs),
+                name=f"{njt_desc}: (T, NT) broadcasting normal dims",
+            )
+
+        # broadcasting case: (B, j0, ...) with scalar
+        t = torch.randn((), device=device, dtype=dtype, requires_grad=requires_grad)
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(_clone(t),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting with scalar",
+        )
+        # (T, NT)
+        yield SampleInput(
+            _clone(t),
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting with scalar",
+        )
+
+    # mixed broadcasting case: (B, j0, 1) with (B, 1, D)
+    B = 4
+    D = 16
+    njt = random_nt_from_dims(
+        (B, None, 1),
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        layout=torch.jagged,
+    )
+    njt_desc = _describe_njt(njt)
+    t = torch.randn(B, 1, D, device=device, dtype=dtype, requires_grad=requires_grad)
+    t2 = _clone(t)
+    # used for slicing in unbind_reference()
+    t._batch_dim = 0
+    t2._batch_dim = 0
+
+    # (NT, T)
+    yield SampleInput(
+        _clone(njt),
+        args=(t,),
+        kwargs=dict(op_kwargs),
+        name=f"{njt_desc}: (NT, T) mixed broadcasting",
+    )
+    # (T, NT)
+    yield SampleInput(
+        t2,
+        args=(_clone(njt),),
+        kwargs=dict(op_kwargs),
+        name=f"{njt_desc}: (T, NT) mixed broadcasting",
+    )
+
+
+def sample_inputs_njt_reduction(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    supports_keepdim=True,
+    op_kwargs=None,
+    **kwargs,
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    # extract info about the dim args this op supports
+    assert op_info._extra_op_data.dim_args is not None
+    (
+        single_dim_argname,
+        dimlist_argname,
+    ) = op_info._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+    supports_dimlist = dimlist_argname is not None
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        njt_desc = _describe_njt(njt)
+        keepdim_values = [False, True] if supports_keepdim else [None]
+        for keepdim in keepdim_values:
+            keepdim_suffix = f" with keepdim={keepdim}" if supports_keepdim else ""
+            # single dim-wise reduction; includes reduction over the ragged dim
+            # NB: reduction over the batch dim is not supported!
+            # TODO: Cover this in the set of error inputs
+            for dim in range(1, njt.dim()):
+                dim_desc = "normal" if dim != njt._ragged_idx else "ragged"
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        single_dim_argname: dim,
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: {dim_desc} dim reduction{keepdim_suffix}",
+                )
+
+            if supports_dimlist:
+                # reduce on both batch and ragged dims
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        dimlist_argname: [0, njt._ragged_idx],
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: batch+ragged reduction{keepdim_suffix}",
+                )
+
+                # reduce on batch, ragged, and other dims
+                for other_dim in range(njt._ragged_idx + 1, njt.dim()):
+                    yield SampleInput(
+                        _clone(njt),
+                        kwargs={
+                            **op_kwargs,
+                            dimlist_argname: [0, njt._ragged_idx, other_dim],
+                            **({"keepdim": keepdim} if supports_keepdim else {}),
+                        },
+                        name=(
+                            f"{njt_desc}: batch+ragged+dim={other_dim} "
+                            f"reduction{keepdim_suffix}"
+                        ),
+                    )
+
+                # reduce on two non-ragged, non-batch dims
+                if njt.dim() > 3 and njt._ragged_idx == 1:
+                    yield SampleInput(
+                        _clone(njt),
+                        kwargs={
+                            **op_kwargs,
+                            dimlist_argname: [njt.dim() - 2, njt.dim() - 1],
+                            **({"keepdim": keepdim} if supports_keepdim else {}),
+                        },
+                        name=f"{njt_desc}: two normal dim reduction{keepdim_suffix}",
+                    )
+
+                # full reduction by specifying all dims
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        dimlist_argname: list(range(njt.dim())),
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: all dim reduction{keepdim_suffix}",
+                )
+
+                # TODO: Reducing on ragged dim and non-batch dim is not supported;
+                # cover this in the set of error inputs.
+
+        # full reduction
+        yield SampleInput(
+            _clone(njt),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: full reduction with keepdim={keepdim}",
+        )
+
+
+def unsupported_sample_inputs_func(op_name):
+    def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs):
+        raise RuntimeError(
+            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
+            "torch/testing/_internal/opinfo/definitions/nested.py."
+        )
+
+    return _f
+
+
+def unsupported_reference(op_name):
+    def _f(op, sample):
+        raise RuntimeError(
+            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
+            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
+        )
+
+    return _f
+
+
+# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
+def sample_inputs_unary_dimwise(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if op_kwargs is None:
+        op_kwargs = {}
+
+    # only support a single non-list dim arg for now
+    assert op_info._extra_op_data is not None
+    single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+    assert dimlist_argname is None
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            kwargs = {single_dim_argname: dim}
+            kwargs.update(op_kwargs)
+            yield SampleInput(
+                _clone(njt),
+                kwargs=kwargs,
+                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
+            )
+
+
+def batchwise_reference_chunk(op, sample):
+    # reference for chunk() over dim=0
+    B = sample.input.size(0)
+    num_chunks = sample.kwargs["chunks"]
+    chunk_size = math.ceil(B / num_chunks)
+    num_full_chunks = B // chunk_size
+    chunk_sizes = [chunk_size for _ in range(num_full_chunks)]
+    if B % chunk_size != 0:
+        # final chunk contains the leftovers
+        chunk_sizes.append(B % chunk_size)
+
+    # split unbound components into chunks according to calculated sizes
+    components = list(sample.input.unbind())
+    start = 0
+    chunks = []
+    for chunk_size in chunk_sizes:
+        chunks.append(components[start : start + chunk_size])
+        start += chunk_size
+
+    # rejoin into NJT outputs
+    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]
+
+
+def batchwise_reference_narrow(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_select(op, sample):
+    # reference for select() over dim=0
+    return sample.input.unbind()[sample.kwargs["index"]]
+
+
+def batchwise_reference_split(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_split_with_sizes(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_unflatten(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_unsqueeze(op, sample):
+    raise ValueError("unsqueeze() is not intended to operate on the batch dim")
+
+
+def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
+    # non-contiguous NJTs
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        yield SampleInput(njt, name=_describe_njt(njt))
+
+    for memory_format in (torch.contiguous_format, torch.preserve_format):
+        # construct a "non-contiguous with holes" NJT
+        values = torch.randn(
+            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
+        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
+        njt = torch.nested.nested_tensor_from_jagged(
+            values, offsets=offsets, lengths=lengths
+        )
+
+        njt_desc = _describe_njt(njt)
+        yield SampleInput(
+            njt,
+            kwargs={"memory_format": memory_format},
+            name=f"{njt_desc}: {memory_format})",
+        )
+
+
+def sample_inputs_fill(op_info, device, dtype, requires_grad, **kwargs):
+    # scalar case
+    unary_func = partial(sample_inputs_elementwise_njt_unary, op_kwargs={"value": 42.0})
+    yield from unary_func(op_info, device, dtype, requires_grad)
+
+    # TODO: add Tensor case
+
+
+def sample_inputs_mvl_gamma(p):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p})
+
+
+def sample_inputs_polygamma_n(n):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})
+
+
+def sample_inputs_special_polygamma_n(n):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})
+
+
+def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
+    for njt in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[2, 3, 4],
+    ):
+        other_dtypes = (
+            d for d in (torch.float32, torch.half, torch.double) if d is not dtype
+        )
+        for other_dtype in other_dtypes:
+            sample_name = f"{njt.dim()}D: {dtype} -> {other_dtype}"
+            yield SampleInput(_clone(njt), kwargs={"dtype": dtype}, name=sample_name)
+
+        # only include device transfer for CUDA inputs
+        if "cuda" in device:
+            other_device = "cpu"
+            sample_name = f"{_describe_njt(njt)}: {device} -> {other_device}"
+            yield SampleInput(
+                _clone(njt), kwargs={"device": other_device}, name=sample_name
+            )
+
+
+def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
+    for njt_3d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
+    ):
+        # (B, j1, D) x (B, D, E) => (B, j1, E)
+        if njt_3d._ragged_idx == 1:
+            B, D = njt_3d.shape[0], njt_3d.shape[-1]
+            E = D + 2
+            other = torch.randn(B, D, E, device=device, dtype=dtype)
+            # used for slicing in unbind_reference()
+            other._batch_dim = 0
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"mat2": other},
+                name=f"{njt_desc}: (B, j, D) x (B, D, E)",
+            )
+
+        # TODO (need factory functions):
+        # (B, D, j1) x (B, j1, E) => (B, D, E)
+
+
+def reference_bmm(op, sample):
+    # unbind reduces a dim and bmm requires 3D, so use matmul as the reference
+    matmul_op = copy(op)
+    matmul_op.op = torch.matmul
+    # change arg name from mat2 -> other
+    modified_sample = copy(sample)
+    other = modified_sample.kwargs["mat2"]
+    del modified_sample.kwargs["mat2"]
+    modified_sample.kwargs["other"] = other
+    return unbind_reference(matmul_op, modified_sample)
+
+
+def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single chunks value
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"chunks": 3})
+        # other dim chunking: test different chunks values
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for chunks in [1, D // 2, D - 1, D]:
+                yield _update_sample(sample_input, {"chunks": chunks})
+
+
+def sample_inputs_matmul(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    # also run bmm samples through
+    for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad):
+        # change arg name from mat2 -> other
+        other = sample_input.kwargs["mat2"]
+        del sample_input.kwargs["mat2"]
+        sample_input.kwargs["other"] = other
+        yield sample_input
+
+    # 3D cases not covered by bmm
+    for njt_3d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
+    ):
+        # (B, j1, D) x (D, E) => (B, j1, E)
+        if njt_3d._ragged_idx == 1:
+            D = njt_3d.shape[-1]
+            E = D + 2
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)},
+                name=f"{njt_desc}: (B, j, D) x (D, E)",
+            )
+
+    # 4D cases
+    for njt_4d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[4]
+    ):
+        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
+        if njt_4d._ragged_idx == 1:
+            E = njt_4d.shape[-1]
+            F = E + 2
+            njt_desc = _describe_njt(njt_4d)
+            yield SampleInput(
+                _clone(njt_4d),
+                kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)},
+                name=f"{njt_desc}: (B, j, D, E) x (E, F)",
+            )
+
+    # Dense x NJT cases
+    for njt_3d in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[3],
+    ):
+        # (B, F, E) x (B, E, j1) => (B, F, j1)
+        if njt_3d._ragged_idx == 2:
+            B = njt_3d.shape[0]
+            E = njt_3d.shape[1]
+            F = E + 2
+            njt_desc = _describe_njt(njt_3d)
+            dense_t = torch.randn(
+                B, F, E, device=device, dtype=dtype, requires_grad=requires_grad
+            )
+            dense_t._batch_dim = 0  # for unbind_reference()
+            yield SampleInput(
+                dense_t,
+                args=(_clone(njt_3d),),
+                name=f"{njt_desc}: (B, F, E) x (B, E, j1)",
+            )
+
+    # NJT x NJT => Dense case
+    for njt_3d in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[3],
+    ):
+        # (B, E, j1) x (B, j1, F) => (B, E, F)
+        if njt_3d._ragged_idx == 2 and njt_3d.is_contiguous():
+            B, E, _ = njt_3d.shape
+            sum_j1 = len(njt_3d.values())
+            other_cont = torch.randn(
+                sum_j1, E + 2, device=device, dtype=dtype, requires_grad=requires_grad
+            )
+            other_njt = torch.nested.nested_tensor_from_jagged(
+                other_cont, njt_3d.offsets(), lengths=njt_3d._lengths
+            )
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"other": _clone(other_njt)},
+                name=f"{njt_desc}: (B, E, j1) x (B, j1, F)",
+            )
+
+        # TODO (need factory functions):
+        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F)
+
+
+def sample_inputs_masked_select(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
+    ):
+        yield SampleInput(
+            njt,
+            kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)},
+            name=_describe_njt(njt),
+        )
+
+
+def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim narrowing: test a single start, length value
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"start": 1, "length": 2})
+        # other dim narrowing: test different start, length values
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for start, length in [(0, D), (0, D - 1), (1, D - 1), (D - 1, 1)]:
+                yield _update_sample(sample_input, {"start": start, "length": length})
+
+
+def sample_inputs_nn_functional_embedding(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    indices = torch.nested.nested_tensor(
+        [
+            torch.tensor([0, 2, 1, 3]),
+            torch.tensor([4, 2, 1]),
+            torch.tensor([6, 7, 5, 2, 4]),
+        ],
+        layout=torch.jagged,
+        dtype=torch.int64,
+        device=device,
+    )
+
+    NUM_EMBEDDINGS = 20
+    EMBEDDING_DIM = 32
+    weight = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM, device=device, dtype=dtype)
+
+    # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
+    # can be checked
+    yield SampleInput(
+        _clone(weight).requires_grad_(),
+        args=(indices,),
+    )
+
+    yield SampleInput(
+        _clone(weight).requires_grad_(),
+        args=(indices,),
+        kwargs={"padding_idx": 1},
+    )
+
+
+def sample_inputs_index_put(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            indices = [
+                torch.tensor(list(range(njt.size(0))), device=njt.device),
+                *[
+                    torch.tensor([0] * njt.size(0), device=njt.device)
+                    for _ in range(dim - 1)
+                ],
+            ]
+            njt_desc = _describe_njt(njt)
+            yield SampleInput(
+                _clone(njt),
+                kwargs={
+                    "indices": indices,
+                    "values": torch.tensor(1.0, device=njt.device),
+                },
+                name=f"{njt_desc}: up to dim {dim - 1}",
+            )
+
+    # Non-cont NJT for completeness
+    offsets = torch.tensor([0, 2, 5, 7], device=device)
+    lengths = torch.tensor([2, 2, 2], device=device)
+    indices = [
+        torch.tensor([0, 1, 2], device=device),
+        torch.tensor([0, 1, 1], device=device),
+        torch.tensor([0, 0, 0], device=device),
+    ]
+    a = torch.nested.nested_tensor_from_jagged(
+        torch.zeros(7, 3, device=device), offsets, lengths
+    ).requires_grad_(requires_grad)
+
+    njt_desc = _describe_njt(a)
+    yield SampleInput(
+        _clone(a),
+        kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)},
+        name=f"{njt_desc}: all dims",
+    )
+
+
+def sample_inputs_nn_functional_embedding_bag(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    for generate_per_sample_weight in (True, False):
+        for mode in ("sum", "mean", "max"):
+            # per_sample_weights is only supported for mode='sum'
+            if mode != "sum" and generate_per_sample_weight:
+                continue
+
+            NUM_EMBEDDINGS = 10
+            EMBEDDING_DIM = 32
+            weight = torch.randn(
+                NUM_EMBEDDINGS, EMBEDDING_DIM, dtype=dtype, device=device
+            )
+
+            njt = torch.nested.nested_tensor(
+                [
+                    torch.randint(0, NUM_EMBEDDINGS, size=(2,)),
+                    torch.randint(0, NUM_EMBEDDINGS, size=(3,)),
+                    torch.randint(0, NUM_EMBEDDINGS, size=(4,)),
+                ],
+                layout=torch.jagged,
+                dtype=torch.int64,
+                device=device,
+            )
+
+            per_sample_weights = None
+            if generate_per_sample_weight:
+                per_sample_weights = torch.randn_like(njt, dtype=dtype)
+
+            # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
+            # can be checked
+            yield SampleInput(
+                weight,
+                args=(njt,),
+                kwargs={
+                    "mode": mode,
+                    "per_sample_weights": per_sample_weights,
+                },
+            )
+
+
+def reference_nn_functional_embedding_bag(op, sample):
+    # run reference on a single bag at a time
+    new_kwargs = dict(sample.kwargs)
+    new_kwargs.update(
+        {"offsets": torch.tensor([0], dtype=torch.int64, device=sample.input.device)}
+    )
+    # flip input / weight back to what unbind_reference() expects
+    sample = SampleInput(sample.args[0], args=(sample.input,), kwargs=new_kwargs)
+    old_op = op.op
+    op.op = torch.nn.functional.embedding_bag
+    output = unbind_reference(op, sample, wrap_output_as_njt=False)
+    op.op = old_op
+    # concat bag outputs to get final output
+    return torch.cat(output, dim=0)
+
+
+def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **kwargs):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4, 5]
+    ):
+        # projection over a ragged dim is not currently supported
+        if is_nested_int(njt.size(-1)):
+            continue
+
+        # with bias
+        NUM_OUTPUT = 10
+        weight = torch.randn(
+            NUM_OUTPUT,
+            njt.size(-1),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        bias = torch.randn(
+            NUM_OUTPUT, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+                "bias": _clone(bias),
+            },
+            name=f"{_describe_njt(njt)}: with bias",
+        )
+
+        # without bias
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+            },
+            name=f"{_describe_njt(njt)}: without bias",
+        )
+
+
+def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
+    ):
+        # Second dim is interpreted as number of channels; this should be non-ragged for now
+        num_channels = njt.size(1)
+        if is_nested_int(num_channels):
+            continue
+
+        # 1D weight
+        weight = torch.randn(
+            num_channels,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+            },
+            name=f"{_describe_njt(njt)}: 1D weight",
+        )
+
+        # scalar tensor weight
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": torch.tensor(4.2, device=device, dtype=dtype),
+            },
+            name=f"{_describe_njt(njt)}: scalar tensor weight",
+        )
+
+
+def sample_inputs_nn_functional_rms_norm(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
+    ):
+        # normalize over non-ragged dims
+        for start_dim in range(njt.dim()):
+            if start_dim <= njt._ragged_idx:
+                continue
+
+            normalized_shape = njt.shape[start_dim:]
+            weight = torch.randn(
+                normalized_shape,
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+
+            yield SampleInput(
+                _clone(njt),
+                kwargs={
+                    "normalized_shape": normalized_shape,
+                    "weight": weight,
+                },
+                name=f"{_describe_njt(njt)}",
+            )
+
+
+sample_inputs_nn_functional_threshold = partial(
+    sample_inputs_elementwise_njt_unary,
+    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
+)
+
+
+def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single index
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"index": 0})
+        # other dim chunking: test different indices
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for index in [0, D // 2, D - 1]:
+                yield _update_sample(sample_input, {"index": index})
+
+
+def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single split size
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"split_size_or_sections": 3})
+        # other dim chunking: test different split sizes
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for split_size in [1, D // 2, D - 1, D]:
+                yield _update_sample(
+                    sample_input, {"split_size_or_sections": split_size}
+                )
+
+
+def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # It will never make sense to operate on the ragged dim.
+        # TODO: Handle this with error_inputs
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            continue
+
+        D = sample_input.input.size(sample_input.kwargs["dim"])
+        # splits should add up to D
+        split1 = torch.randint(0, D - 1, size=()).item()
+        split2 = D - split1
+        yield _update_sample(sample_input, {"split_sizes": [split1, split2]})
+
+
+def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
+    # squeeze-specific NJT generator (need to ensure there are some 1s in the shape)
+    def _get_njts():
+        njt = random_nt_from_dims(
+            (4, None, 1, 3, 1),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            layout=torch.jagged,
+        )
+        yield njt
+        # without min / max seqlen cached
+        values = njt.values().detach().clone()
+        offsets = njt.offsets().detach().clone()
+        yield torch.nested.nested_tensor_from_jagged(values, offsets)
+        # non-contiguous transposed
+        yield njt.transpose(1, 3)
+        # non-contiguous with holes
+        values = njt.values().detach().clone()
+        offsets = njt.offsets().detach().clone()
+        # subtract 1 to cause holes
+        lengths = (offsets.diff() - 1).detach().clone()
+        yield torch.nested.nested_tensor_from_jagged(
+            values=values,
+            offsets=offsets,
+            lengths=lengths,
+        )
+
+    for njt in _get_njts():
+        # single dim operation
+        for dim in range(njt.dim()):
+            # Operation on batch / ragged dim is never expected to work.
+            # TODO: Handle these via error_inputs.
+            if dim == 0 or dim == njt._ragged_idx:
+                continue
+
+            yield SampleInput(
+                _clone(njt),
+                kwargs={"dim": dim},
+                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
+            )
+
+        # multiple dim operation (pass no args)
+        yield SampleInput(
+            _clone(njt),
+            kwargs={"dim": dim},
+            name=f"{_describe_njt(njt)}: multiple dims",
+        )
+
+
+def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # It will never make sense to operate on the ragged dim.
+        # TODO: Handle this with error_inputs
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            continue
+
+        D = sample_input.input.size(sample_input.kwargs["dim"])
+        # sizes should multiply to be D
+        yield _update_sample(sample_input, {"sizes": [D, 1]})
+        yield _update_sample(sample_input, {"sizes": [1, D]})
+        if D % 2 == 0:
+            yield _update_sample(sample_input, {"sizes": [D // 2, 2]})
+            yield _update_sample(sample_input, {"sizes": [2, D // 2]})
+
+
+def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        yield sample_input
+
+        last_dim_sample = _update_sample(sample_input, {"dim": -1})
+        last_dim_sample.name = (
+            f"{_describe_njt(last_dim_sample.input)}: add dim to the end"
+        )
+        # Tell the unbind reference how to canonicalize the dim kwargs
+        # This is necessary because unsqueeze() allows for a dim after
+        # the last dim to indicate an unsqueeze at the end.
+        last_dim_sample.input._ndim = last_dim_sample.input.dim() + 1
+        yield last_dim_sample
+
+
+def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in sample_inputs_elementwise_njt_binary(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        other = sample.args[0]
+        sample.args = ()
+        sample.kwargs["other"] = other
+        sample.kwargs["condition"] = sample.input > 0.0
+        sample.name = sample.name.replace("(", "(NT, ")
+        yield sample
+
+
+# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
+
+
+# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
+# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name
+# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These are necessary
+# to specify if they cannot be auto-generated for some reason. Try to keep these sorted
+# in alphabetical order!
+njt_sample_inputs = {
+    "bmm": sample_inputs_bmm,
+    "chunk": sample_inputs_chunk,
+    "clone": sample_inputs_clone,
+    "count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False),
+    "fill": sample_inputs_fill,
+    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)},
+    "nn.functional.embedding": sample_inputs_nn_functional_embedding,
+    "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag,
+    "nn.functional.linear": sample_inputs_nn_functional_linear,
+    "nn.functional.prelu": sample_inputs_nn_functional_prelu,
+    "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm,
+    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
+    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
+    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
+    "to": sample_inputs_to,
+    "matmul": sample_inputs_matmul,
+    "masked_select": sample_inputs_masked_select,
+    "narrow": sample_inputs_narrow,
+    "index_put": sample_inputs_index_put,
+    # these two don't have ReductionOpInfo entries
+    "max.reduction_with_dim": sample_inputs_njt_reduction,
+    "min.reduction_with_dim": sample_inputs_njt_reduction,
+    "select": sample_inputs_select,
+    "split": sample_inputs_split,
+    "split_with_sizes": sample_inputs_split_with_sizes,
+    "squeeze": sample_inputs_squeeze,
+    "unflatten": sample_inputs_unflatten,
+    "unsqueeze": sample_inputs_unsqueeze,
+    "where": sample_inputs_where,
+}
+
+njt_references = {
+    "bmm": reference_bmm,
+    "chunk": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk
+    ),
+    "count_nonzero": reduction_reference,
+    # these two don't have ReductionOpInfo entries
+    "max.reduction_with_dim": reduction_reference,
+    "min.reduction_with_dim": reduction_reference,
+    "narrow": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow
+    ),
+    "select": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_select
+    ),
+    "split": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_split
+    ),
+    "split_with_sizes": partial(
+        unary_dimwise_reference,
+        batchwise_reference=batchwise_reference_split_with_sizes,
+    ),
+    "squeeze": unbind_reference,
+    "nn.functional.embedding_bag": reference_nn_functional_embedding_bag,
+    "unflatten": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten
+    ),
+    "unsqueeze": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze
+    ),
+}
+
+
+# Translates an OpInfo entry to one that operates on NJTs.
+def translate_opinfo(op):
+    new_op = copy(op)
+    new_op.supports_njt = True
+    # add some extra info for use in generating tests on the right subset of ops
+    new_op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData())
+
+    if op.full_name in njt_sample_inputs:
+        new_op.sample_inputs_func = njt_sample_inputs[op.full_name]
+        new_op.ref = njt_references.get(op.full_name, unbind_reference)
+    elif isinstance(op, UnaryUfuncInfo):
+        new_op.sample_inputs_func = partial(
+            sample_inputs_elementwise_njt_unary, op_kwargs=None
+        )
+        new_op.ref = unbind_reference
+    elif isinstance(op, BinaryUfuncInfo):
+        new_op.sample_inputs_func = partial(
+            sample_inputs_elementwise_njt_binary, op_kwargs=None
+        )
+        new_op.ref = unbind_reference
+    elif isinstance(op, ReductionOpInfo):
+        new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None)
+        new_op.ref = reduction_reference
+    # TODO: Translate the rest of the OpInfos
+    else:
+        new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name)
+        new_op.ref = unsupported_reference(op.full_name)
+        new_op.supports_njt = False
+
+    return new_op
+
+
+njt_op_db = [translate_opinfo(op) for op in op_db]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81efd19dbc6c804f066fd89a7068dce8ecf515f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py
@@ -0,0 +1,459 @@
+# mypy: ignore-errors
+
+import unittest
+from collections.abc import Callable
+from functools import partial
+from itertools import product
+
+import numpy
+
+import torch
+from torch.testing._internal.common_dtype import floating_types
+from torch.testing._internal.common_utils import TEST_SCIPY
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    ErrorInput,
+    OpInfo,
+    SampleInput,
+)
+
+
+if TEST_SCIPY:
+    import scipy.signal
+
+
+def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
+    r"""Base function used to create sample inputs for windows.
+
+    For additional required args you should use *args, as well as **kwargs for
+    additional keyword arguments.
+    """
+
+    # Remove include_conjugated_inputs from kwargs
+    kwargs.pop("include_conjugated_inputs", None)
+    # Tests window sizes up to 5 samples.
+    for size, sym in product(range(6), (True, False)):
+        yield SampleInput(
+            size,
+            *args,
+            sym=sym,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            **kwargs,
+        )
+
+
+def reference_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
+    r"""Reference inputs function to use for windows which have a common signature, i.e.,
+    window size and sym only.
+
+    Implement other special functions for windows that have a specific signature.
+    See exponential and gaussian windows for instance.
+    """
+    yield from sample_inputs_window(
+        op_info, device, dtype, requires_grad, *args, **kwargs
+    )
+
+    cases = (8, 16, 32, 64, 128, 256)
+
+    for size in cases:
+        yield SampleInput(size, sym=False)
+        yield SampleInput(size, sym=True)
+
+
+def reference_inputs_exponential_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"center": 4, "tau": 0.5}),
+        (16, {"center": 8, "tau": 2.5}),
+        (32, {"center": 16, "tau": 43.5}),
+        (64, {"center": 20, "tau": 3.7}),
+        (128, {"center": 62, "tau": 99}),
+        (256, {"tau": 10}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        kw["center"] = None
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"std": 0.1}),
+        (16, {"std": 1.2}),
+        (32, {"std": 2.1}),
+        (64, {"std": 3.9}),
+        (128, {"std": 4.5}),
+        (256, {"std": 10}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_kaiser_window(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"beta": 2}),
+        (16, {"beta": 12}),
+        (32, {"beta": 30}),
+        (64, {"beta": 35}),
+        (128, {"beta": 41.2}),
+        (256, {"beta": 100}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_general_cosine_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"a": [0.5, 0.5]}),
+        (16, {"a": [0.46, 0.54]}),
+        (32, {"a": [0.46, 0.23, 0.31]}),
+        (64, {"a": [0.5]}),
+        (128, {"a": [0.1, 0.8, 0.05, 0.05]}),
+        (256, {"a": [0.2, 0.2, 0.2, 0.2, 0.2]}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_general_hamming_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"alpha": 0.54}),
+        (16, {"alpha": 0.5}),
+        (32, {"alpha": 0.23}),
+        (64, {"alpha": 0.8}),
+        (128, {"alpha": 0.9}),
+        (256, {"alpha": 0.05}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def error_inputs_window(op_info, device, *args, **kwargs):
+    # Tests for windows that have a negative size
+    yield ErrorInput(
+        SampleInput(-1, *args, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="requires non-negative window length, got M=-1",
+    )
+
+    # Tests for window tensors that are not torch.strided, for instance, torch.sparse_coo.
+    yield ErrorInput(
+        SampleInput(
+            3,
+            *args,
+            layout=torch.sparse_coo,
+            device=device,
+            dtype=torch.float32,
+            **kwargs,
+        ),
+        error_type=ValueError,
+        error_regex="is implemented for strided tensors only, got: torch.sparse_coo",
+    )
+
+    # Tests for window tensors that are not floating point dtypes, for instance, torch.long.
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.long, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.int64",
+    )
+
+    # Tests for window tensors that are bfloat16
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.bfloat16, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.bfloat16",
+    )
+
+    # Tests for window tensors that are float16
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.float16, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.float16",
+    )
+
+
+def error_inputs_exponential_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, **kwargs)
+
+    # Tests for negative decay values.
+    yield ErrorInput(
+        SampleInput(3, tau=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Tau must be positive, got: -1 instead.",
+    )
+
+    # Tests for symmetric windows and a given center value.
+    yield ErrorInput(
+        SampleInput(3, center=1, sym=True, dtype=torch.float32, device=device),
+        error_type=ValueError,
+        error_regex="Center must be None for symmetric windows",
+    )
+
+
+def error_inputs_gaussian_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, std=0.5, **kwargs)
+
+    # Tests for negative standard deviations
+    yield ErrorInput(
+        SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Standard deviation must be positive, got: -1 instead.",
+    )
+
+
+def error_inputs_kaiser_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, beta=12, **kwargs)
+
+    # Tests for negative beta
+    yield ErrorInput(
+        SampleInput(3, beta=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="beta must be non-negative, got: -1 instead.",
+    )
+
+
+def error_inputs_general_cosine_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, a=[0.54, 0.46], **kwargs)
+
+    # Tests for negative beta
+    yield ErrorInput(
+        SampleInput(3, a=None, dtype=torch.float32, device=device, **kwargs),
+        error_type=TypeError,
+        error_regex="Coefficients must be a list/tuple",
+    )
+
+    yield ErrorInput(
+        SampleInput(3, a=[], dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Coefficients cannot be empty",
+    )
+
+
+def reference_signal_window(fn: Callable):
+    r"""Wrapper for scipy signal window references.
+
+    Discards keyword arguments for window reference functions that don't have a matching signature with
+    torch, e.g., gaussian window.
+    """
+
+    def _fn(
+        *args,
+        dtype=numpy.float64,
+        device=None,
+        layout=torch.strided,
+        requires_grad=False,
+        **kwargs,
+    ):
+        r"""The unused arguments are defined to disregard those values"""
+        return fn(*args, **kwargs).astype(dtype)
+
+    return _fn
+
+
+def make_signal_windows_opinfo(
+    name: str,
+    ref: Callable,
+    sample_inputs_func: Callable,
+    reference_inputs_func: Callable,
+    error_inputs_func: Callable,
+    *,
+    skips: tuple[DecorateInfo, ...] = (),
+):
+    r"""Helper function to create OpInfo objects related to different windows."""
+    return OpInfo(
+        name=name,
+        ref=ref if TEST_SCIPY else None,
+        dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_func,
+        reference_inputs_func=reference_inputs_func,
+        error_inputs_func=error_inputs_func,
+        supports_out=False,
+        supports_autograd=False,
+        skips=(
+            # TODO: same as this?
+            # https://github.com/pytorch/pytorch/issues/81774
+            # also see: arange, new_full
+            # fails to match any schemas despite working in the interpreter
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestOperatorSignatures",
+                "test_get_torch_func_signature_exhaustive",
+            ),
+            # fails to match any schemas despite working in the interpreter
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # skip these tests since we have non tensor input
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestVmapOperatorsOpInfo",
+                "test_vmap_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestVmapOperatorsOpInfo",
+                "test_op_has_batch_rule",
+            ),
+            DecorateInfo(
+                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+            *skips,
+        ),
+    )
+
+
+op_db: list[OpInfo] = [
+    make_signal_windows_opinfo(
+        name="signal.windows.hamming",
+        ref=reference_signal_window(scipy.signal.windows.hamming)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.hann",
+        ref=reference_signal_window(scipy.signal.windows.hann) if TEST_SCIPY else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.bartlett",
+        ref=reference_signal_window(scipy.signal.windows.bartlett)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.blackman",
+        ref=reference_signal_window(scipy.signal.windows.blackman)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.cosine",
+        ref=reference_signal_window(scipy.signal.windows.cosine)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.exponential",
+        ref=reference_signal_window(scipy.signal.windows.exponential)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, tau=2.78),
+        reference_inputs_func=partial(reference_inputs_exponential_window, tau=2.78),
+        error_inputs_func=error_inputs_exponential_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.gaussian",
+        ref=reference_signal_window(scipy.signal.windows.gaussian)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, std=1.92),
+        reference_inputs_func=partial(reference_inputs_gaussian_window, std=1.92),
+        error_inputs_func=error_inputs_gaussian_window,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.kaiser",
+        ref=reference_signal_window(scipy.signal.windows.kaiser)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, beta=12.0),
+        reference_inputs_func=partial(reference_inputs_kaiser_window, beta=12.0),
+        error_inputs_func=error_inputs_kaiser_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.general_cosine",
+        ref=reference_signal_window(scipy.signal.windows.general_cosine)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, a=[0.54, 0.46]),
+        reference_inputs_func=partial(
+            reference_inputs_general_cosine_window, a=[0.54, 0.46]
+        ),
+        error_inputs_func=error_inputs_general_cosine_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.general_hamming",
+        ref=reference_signal_window(scipy.signal.windows.general_hamming)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, alpha=0.54),
+        reference_inputs_func=partial(
+            reference_inputs_general_hamming_window, alpha=0.54
+        ),
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.nuttall",
+        ref=reference_signal_window(scipy.signal.windows.nuttall)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..200a3ad9ed902962edcc2da0153117e83d64131a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py
@@ -0,0 +1,928 @@
+# mypy: ignore-errors
+
+import os
+
+import torch
+from torch.testing import make_tensor  # noqa: F401
+from torch.testing._internal.opinfo.core import (  # noqa: F401
+    BinaryUfuncInfo,
+    ErrorInput,
+    generate_elementwise_binary_tensors,
+    ReductionOpInfo,
+    sample_inputs_reduction,
+    SampleInput,
+)
+
+
+def _check_validate(op_info, sample):
+    def _check_fail(sample):
+        try:
+            op_info(
+                sample.sample_input.input,
+                *sample.sample_input.args,
+                **sample.sample_input.kwargs,
+            )
+        except sample.error_type:
+            pass
+        except Exception as msg:
+            raise AssertionError(  # noqa: B904
+                f"{op_info.name} on {sample.sample_input=} expected exception "
+                f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}"
+            )
+        else:
+            raise AssertionError(
+                f"{op_info.name} on {sample.sample_input=} expected exception "
+                f"{sample.error_type}: {sample.error_regex}, got none."
+            )
+
+    def _check_success(sample):
+        try:
+            op_info(sample.input, *sample.args, **sample.kwargs)
+        except Exception as msg:
+            raise AssertionError(  # noqa: B904
+                f"{op_info.name} on {sample=} expected to succeed "
+                f", got {type(msg).__name__}: {msg}"
+            )
+
+    if isinstance(sample, ErrorInput):
+        _check_fail(sample)
+    else:
+        _check_success(sample)
+
+
+def _sample_inputs_sparse(
+    sample_inputs,
+    maybe_failing_sample_inputs,
+    validate_sample_input,
+    op_info,
+    *args,
+    **kwargs,
+):
+    check_validate = (
+        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
+    )
+    for sample in sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, SampleInput):
+            yield sample
+        # Error inputs are handled in error_inputs_sparse
+
+    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, SampleInput):
+            yield sample
+
+
+def _error_inputs_sparse(
+    maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs
+):
+    check_validate = (
+        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
+    )
+    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, ErrorInput):
+            yield sample
+        # Sample inputs are handled in sample_inputs_sparse
+
+
+def _apply_requires_grad_to_samples(sample_inputs):
+    """Decorator to _maybe_failing_sample_inputs_... generator functions
+    that clones and sets requires_grad argument to tensors in sample
+    input arguments. This is needed when the generated samples share
+    tensor instances.
+    """
+
+    def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs):
+        def apply_requires_grad(x):
+            if (
+                not isinstance(x, torch.Tensor)
+                or x.requires_grad
+                or not requires_grad
+                or not (x.is_floating_point() or x.is_complex())
+            ):
+                return x
+            return x.detach().clone().requires_grad_(requires_grad)
+
+        if requires_grad:
+            for sample_input in sample_inputs(
+                op_info, device, dtype, requires_grad, layout, **kwargs
+            ):
+                yield sample_input.transform(apply_requires_grad)
+        else:
+            yield from sample_inputs(
+                op_info, device, dtype, requires_grad, layout, **kwargs
+            )
+
+    return wrapper
+
+
+def sample_inputs_sparse_reduction(
+    op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs
+):
+    """Sample inputs for reduction operations on sparse tensors."""
+    layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0]
+    op_supports_layout = getattr(op_info, "supports_" + layout_name)
+    if not op_supports_layout:
+        return
+
+    for sample_input in sample_inputs_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        if sample_input.input.ndim == 0:
+            # scalar sparse tensors are not supported
+            continue
+
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if sample_input.input.ndim < 2:
+                # conversion to sparse compressed tensors requires at
+                # least 2 dimensional tensors
+                continue
+            if sample_input.input.ndim > 2 and (sample_input.input == 0).any():
+                # Skip batched sparse compressed samples that contain
+                # explicit zeros because to_sparse(layout=..) will
+                # fail, see gh-98495.
+                # TODO: remove this if-block after gh-98495 is fixed.
+                continue
+
+        if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None:
+            blocksize = (1, 1)
+
+        yield SampleInput(
+            sample_input.input.detach()
+            .to_sparse(layout=layout, blocksize=blocksize)
+            .requires_grad_(requires_grad),
+            args=sample_input.args,
+            kwargs=sample_input.kwargs,
+        )
+
+        if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex):
+            # uncoalesced samples
+            inp = sample_input.input.detach().to_sparse(layout=layout)
+            inp = torch.sparse_coo_tensor(
+                inp.indices().repeat(1, 2),
+                inp.values().repeat(2),
+                inp.shape,
+                dtype=inp.dtype,
+                device=inp.device,
+            )
+            assert not inp.is_coalesced()
+            yield SampleInput(
+                inp.requires_grad_(requires_grad),
+                args=sample_input.args,
+                kwargs=sample_input.kwargs,
+            )
+
+        if sample_input.input.ndim > 2:
+            # hybrid samples
+            yield SampleInput(
+                sample_input.input.detach()
+                .to_sparse(
+                    layout=layout,
+                    blocksize=blocksize,
+                    dense_dim=sample_input.input.ndim - 2,
+                )
+                .requires_grad_(requires_grad),
+                args=sample_input.args,
+                kwargs=sample_input.kwargs,
+            )
+
+
+def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False):
+    """Return the specified sample when it is valid and supported by the
+    operation. Otherwise, return the sample as ErrorInput instance.
+
+    When check_validate is True, the result is validated against
+    calling the op on the sample.
+    """
+    UNSPECIFIED = object()
+    if op_info.name == "sum":
+        sample = _validate_sample_input_sparse_reduction_sum(sample)
+
+    if op_info.name == "masked.sum":
+        mask = sample.kwargs.get("mask", UNSPECIFIED)
+        if (
+            mask not in {None, UNSPECIFIED}
+            and mask.ndim > 2
+            and mask.layout is torch.strided
+            and (mask == 0).any()
+        ):
+            # TODO: remove this if-block after gh-98495 is fixed.
+            sample = ErrorInput(
+                sample,
+                error_regex="Expect the same number of specified elements per batch.",
+            )
+        elif not sample.kwargs.get("keepdim"):
+            sample = ErrorInput(
+                sample,
+                error_type=(AssertionError, RuntimeError),
+                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
+            )
+        elif mask is UNSPECIFIED:
+            sample = ErrorInput(
+                sample,
+                error_type=ValueError,
+                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
+            )
+        elif sample.input.ndim > 2:
+            sample = ErrorInput(
+                sample,
+                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
+            )
+
+    if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}:
+        t_inp = sample.input
+        mask = sample.kwargs.get("mask")
+        if (
+            mask is not None
+            and mask.ndim > 2
+            and mask.layout is torch.strided
+            and (mask == 0).any()
+        ):
+            # TODO: remove this if-block after gh-98495 is fixed.
+            sample = ErrorInput(
+                sample,
+                error_regex="Expect the same number of specified elements per batch.",
+            )
+        elif mask is None:
+            sample = ErrorInput(
+                sample,
+                error_type=ValueError,
+                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
+            )
+        elif (
+            mask.layout is sample.input.layout
+            and mask.ndim > 2
+            and op_info.name == "masked.mean"
+        ):
+            sample = ErrorInput(
+                sample,
+                error_type=TypeError,
+                error_regex=(
+                    "where[(][)] received an invalid combination of arguments"
+                    " - got [(]Tensor, Tensor, NoneType[)]"
+                ),
+            )
+        elif not sample.kwargs.get("keepdim"):
+            sample = ErrorInput(
+                sample,
+                error_type=(AssertionError, RuntimeError),
+                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
+            )
+        elif (
+            sample.input.ndim > 2
+            and (sample.kwargs.get("dim") not in {0, 1})
+            and mask.ndim > 2
+            and mask.layout is not torch.strided
+        ):
+            if sample.kwargs.get("dim") == (0, -1):
+                sample = ErrorInput(
+                    sample,
+                    error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities",
+                )
+            elif op_info.name == "masked.prod":
+                sample = ErrorInput(
+                    sample,
+                    error_regex="input_dim == 2 INTERNAL ASSERT FAILED at",
+                )
+            else:
+                sample = ErrorInput(
+                    sample,
+                    error_type=AssertionError,
+                    error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.",
+                )
+        elif sample.input.ndim > 2:
+            sample = ErrorInput(
+                sample,
+                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
+            )
+        elif (
+            mask.layout is t_inp.layout
+            and mask._nnz() != t_inp._nnz()
+            and t_inp.dense_dim() > 0
+        ):
+            sample = ErrorInput(
+                sample,
+                error_regex="Index tensor must have the same number of dimensions as src tensor",
+            )
+
+    if check_validate:
+        _check_validate(op_info, sample)
+
+    return sample
+
+
+def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
+    # NOTE: When fixing a failing sample case, remove the
+    #       corresponding if-block
+    t_inp, t_kwargs = sample.input, sample.kwargs
+    dim = t_kwargs.get("dim")
+    keepdim = t_kwargs.get("keepdim")
+    layout = t_inp.layout
+    if isinstance(dim, (int, list, tuple)):
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
+                return ErrorInput(
+                    sample,
+                    error_regex=(
+                        "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout"
+                    ),
+                )
+            if layout in {torch.sparse_csr, torch.sparse_csc} and not keepdim:
+                return ErrorInput(
+                    sample,
+                    error_regex=(
+                        "reduction operations on CSR tensors with keepdim=False is unsupported"
+                    ),
+                )
+            if t_inp.dim() != 2:
+                return ErrorInput(
+                    sample,
+                    error_regex=("input_dim == 2 INTERNAL ASSERT"),
+                )
+            if layout == torch.sparse_csr:
+                if t_inp.dtype == torch.bool:
+                    return ErrorInput(
+                        sample,
+                        error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"),
+                    )
+                if t_inp.dtype == torch.complex32:
+                    return ErrorInput(
+                        sample,
+                        error_regex=(
+                            "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'"
+                        ),
+                    )
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_reduction_sum(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Generator of samples that are known to fail or that were failing in past."""
+    # NOTE: When fixing a failing case, remove the Exception comment
+    #       but keep the `yield sample` statement.
+    if layout in [
+        torch.sparse_csr,
+        torch.sparse_csc,
+    ]:
+        # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend.
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0, keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,), keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+
+        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+
+    if layout in [
+        torch.sparse_bsr,
+        torch.sparse_bsc,
+    ]:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(2, 2))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0, keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,), keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1), dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+
+        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+
+
+def sample_inputs_sparse_reduction_sum(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for sum on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        sample_inputs_sparse_reduction,
+        _maybe_failing_sample_inputs_sparse_reduction_sum,
+        _validate_sample_input_sparse_reduction,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs):
+    """Error inputs for sum on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_reduction_sum,
+        _validate_sample_input_sparse_reduction,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def sample_inputs_sparse_elementwise_binary_operation(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for elementwise binary operations on sparse tensors.
+
+    The samples include regular, zero-sized, batched, and hybrid
+    sparse tensors as well as rhs scalars. All tensors are full tensors.
+    """
+
+    def _to_sparse(tensor, **kwargs):
+        return tensor.detach().to_sparse(**kwargs).requires_grad_(requires_grad)
+
+    for sample_input in generate_elementwise_binary_tensors(
+        op_info,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=True,
+        **kwargs,
+    ):
+        lhs, rhs = sample_input.input, sample_input.args[0]
+        min_dense_dim = 0
+        max_dense_dim = lhs.ndim - 1
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if lhs.ndim < 2:
+                # sparse compressed tensors sparse_dim must be 2
+                continue
+            max_dense_dim = lhs.ndim - 2
+
+        for dense_dim in range(min_dense_dim, max_dense_dim + 1):
+            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
+                blocksizes = [(1, 1)]
+                if lhs.numel() > 0:
+                    blocksizes.append(
+                        (
+                            lhs.shape[lhs.ndim - 2 - dense_dim],
+                            lhs.shape[lhs.ndim - 1 - dense_dim],
+                        )
+                    )
+            else:
+                blocksizes = [None]
+            for blocksize in blocksizes:
+                to_sparse_kwargs = dict(
+                    layout=layout, dense_dim=dense_dim, blocksize=blocksize
+                )
+                lhs_sparse = _to_sparse(lhs, **to_sparse_kwargs)
+                rhs_sparse = _to_sparse(rhs, **to_sparse_kwargs)
+                # op(sparse, sparse)
+                yield SampleInput(
+                    lhs_sparse,
+                    args=(rhs_sparse, *sample_input.args[1:]),
+                    kwargs=sample_input.kwargs,
+                )
+                # op(sparse, scalar)
+                yield SampleInput(
+                    lhs_sparse,
+                    args=(
+                        make_tensor(
+                            (), dtype=dtype, device=device, requires_grad=requires_grad
+                        ),
+                        *sample_input.args[1:],
+                    ),
+                    kwargs=sample_input.kwargs,
+                )
+
+
+def _validate_sample_input_elementwise_binary_sparse_mul(sample):
+    # NOTE: When fixing a failing sample case, remove the
+    #       corresponding if-block
+    t_inp, t_args = sample.input, sample.args
+    batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
+    layout = t_inp.layout
+    dtype = t_inp.dtype
+    if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
+                " tensors with sparse_dim[(][)]!=2 is not supported"
+            ),
+        )
+    elif layout is torch.sparse_csc and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample, error_regex="Expected result Tensor to be of format CSR"
+        )
+    elif layout is torch.sparse_bsr and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr",
+        )
+    elif layout is torch.sparse_bsc and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc",
+        )
+    elif (
+        layout is torch.sparse_coo
+        and dtype is torch.bool
+        and t_args[0].ndim > 0
+        and t_inp.is_cpu
+        and t_inp.numel() > 0
+        and t_inp.dense_dim() > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'"
+        )
+    elif (
+        layout in {torch.sparse_coo, torch.sparse_csr}
+        and dtype is torch.bool
+        and t_inp._nnz() > 0
+        and t_args[0].ndim > 0
+        and t_inp.is_cpu
+        and t_inp.numel() > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'"
+        )
+    elif (
+        layout is torch.sparse_csr
+        and t_args[0].layout is torch.strided
+        and 0 < t_args[0].ndim
+        and t_args[0].ndim < t_inp.ndim
+    ):
+        return ErrorInput(
+            sample, error_regex="sparse_mask_sparse_csr expects self to be 2D"
+        )
+    elif layout is torch.sparse_csr and (
+        (t_args[0].layout is torch.strided and 0 < t_args[0].ndim)
+        or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape)
+    ):
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "expects sparse inputs with equal dimensionality, number of sparse dimensions,"
+                " and shape of sparse dimensions"
+            ),
+        )
+    elif (
+        layout is torch.sparse_csr
+        and t_inp.dense_dim() > 0
+        and t_inp._nnz() > 0
+        and t_inp.is_cpu
+        and dtype is torch.float16
+        and t_args[0].ndim > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'"
+        )
+    return sample
+
+
+@_apply_requires_grad_to_samples
+def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Generator of samples that are known to fail or that were failing in past."""
+    # NOTE: When fixing a failing case, remove the Exception comment
+    #       but keep the `yield sample` statement.
+
+    blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+    regular = torch.tensor([[1, 2], [3, 4]], device=device, dtype=dtype).to_sparse(
+        layout=layout, dense_dim=0, blocksize=blocksize
+    )
+    batch = torch.tensor(
+        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], device=device, dtype=dtype
+    ).to_sparse(layout=layout, dense_dim=0, blocksize=blocksize)
+    hybrid = torch.tensor(
+        [[[1], [2]], [[3], [4]]], device=device, dtype=dtype
+    ).to_sparse(layout=layout, dense_dim=1, blocksize=blocksize)
+
+    if layout is torch.sparse_csr:
+        # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor
+        yield SampleInput(batch, args=(batch,))
+        # RuntimeError: Only tensors with two sparse dimensions can be
+        # converted to the SparseCsr layout, got self with 3 sparse
+        # dimensions.
+        yield SampleInput(
+            torch.zeros_like(hybrid).requires_grad_(requires_grad),
+            args=(torch.zeros_like(hybrid).requires_grad_(requires_grad),),
+        )
+        if dtype is torch.complex32:
+            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
+            yield SampleInput(regular, args=(regular,))
+        if dtype is torch.bool and regular.is_cpu:
+            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
+            yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_csc:
+        # RuntimeError: Expected result Tensor to be of format CSR
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_bsr:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_bsc:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsc
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_coo:
+        if dtype is torch.complex32:
+            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
+            yield SampleInput(regular, args=(regular,))
+        if dtype is torch.bool and regular.is_cpu:
+            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
+            yield SampleInput(regular, args=(regular,))
+        if dtype in {torch.bool, torch.float16} and regular.is_cpu:
+            # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)'
+            yield SampleInput(hybrid, args=(hybrid,))
+
+
+def _validate_sample_input_sparse_elementwise_binary_operation(
+    op_info, sample, check_validate=False
+):
+    if op_info.name == "mul":
+        sample = _validate_sample_input_elementwise_binary_sparse_mul(sample)
+
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs):
+    """Sample inputs for mul operation on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        sample_inputs_sparse_elementwise_binary_operation,
+        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
+        _validate_sample_input_sparse_elementwise_binary_operation,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
+    """Error inputs for mul operation on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
+        _validate_sample_input_sparse_elementwise_binary_operation,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def _sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    from torch.testing._internal.common_utils import TestCase
+
+    for tensor in TestCase().generate_simple_inputs(
+        layout,
+        device=device,
+        dtype=dtype,
+        enable_batch=True,
+        enable_hybrid=True,
+        enable_zero_sized=True,
+        enable_non_contiguous_indices=False,
+        enable_non_contiguous_values=False,
+    ):
+        yield SampleInput(tensor, args=(), kwargs={})
+        yield SampleInput(
+            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
+        )
+
+        if dtype is not torch.float64:
+            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
+
+        if torch.cuda.is_available():
+            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
+            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
+
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+        yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
+
+        if layout is not torch.sparse_coo:
+            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
+
+
+def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
+    if (
+        sample.input.layout
+        in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }
+        and op_info.name != "zeros_like"
+    ):
+        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
+            return ErrorInput(
+                sample,
+                error_regex=(
+                    "empty_like with different sparse layout is not supported"
+                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
+                ),
+            )
+    if sample.input.layout is torch.sparse_coo:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
+            ),
+        )
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    if torch.cuda.is_available() and layout is not torch.sparse_coo:
+        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+
+        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(device=other_device),
+        )
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(layout=other_layout),
+        )
+
+
+def sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for like-functions on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        _sample_inputs_sparse_like_fns,
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
+    """Error inputs for like-functions on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
+    if op_info.name == "to_sparse":
+        if (
+            sample.input.layout
+            in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+            and len(sample.args) == 1
+            and isinstance(sample.args[0], int)
+            and sample.args[0] != 2
+        ):
+            sample = ErrorInput(
+                sample,
+                error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse",
+            )
+
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def validate_sample_input_sparse(op_info, sample, check_validate=False):
+    """Return the specified sample when it is valid and supported by the
+    operation. Otherwise, return the sample as ErrorInput instance.
+
+    When check_validate is True, the result is validated against
+    calling the op on the sample.
+    """
+    if isinstance(op_info, ReductionOpInfo):
+        return _validate_sample_input_sparse_reduction(
+            op_info, sample, check_validate=check_validate
+        )
+    elif isinstance(op_info, BinaryUfuncInfo):
+        return _validate_sample_input_sparse_elementwise_binary_operation(
+            op_info, sample, check_validate=check_validate
+        )
+    else:
+        return _validate_sample_input_sparse_default(
+            op_info, sample, check_validate=check_validate
+        )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py
new file mode 100644
index 0000000000000000000000000000000000000000..47cbcb1fb4268aa8261e38cd6b197a15c39a4428
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py
@@ -0,0 +1,805 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+from itertools import product
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import (
+    precisionOverride,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import all_types_and, floating_types
+from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    DecorateInfo,
+    L,
+    NumericsFilter,
+    OpInfo,
+    S,
+    SampleInput,
+    UnaryUfuncInfo,
+)
+from torch.testing._internal.opinfo.refs import (
+    ElementwiseBinaryPythonRefInfo,
+    ElementwiseUnaryPythonRefInfo,
+)
+from torch.testing._internal.opinfo.utils import (
+    np_unary_ufunc_integer_promotion_wrapper,
+)
+
+
+if TEST_SCIPY:
+    import scipy.special
+
+
+# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
+#       supports `exclude` argument.
+#       For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
+def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
+    exclude_zero = requires_grad and op_info.op is torch.special.i0e
+    make_arg = partial(
+        make_tensor,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    yield SampleInput(make_arg((S,)))
+    yield SampleInput(make_arg(()))
+
+    if requires_grad and not exclude_zero:
+        # Special Case for gradient
+        # Sample with `0` in the input
+        t = make_arg((S,))
+        t[0] = 0
+
+        yield SampleInput(t)
+
+
+def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        # TODO: eliminate low after gh-106692 is fixed:
+        low=(1 if dtype in {torch.int32, torch.int64} else None),
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    tensor_shapes = ((S, S), ())
+    ns = (1, 2, 3, 4, 5)
+
+    for shape, n in product(tensor_shapes, ns):
+        yield SampleInput(make_arg(shape), args=(n,))
+
+
+def reference_polygamma(x, n):
+    # WEIRD `scipy.special.polygamma` behavior
+    # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
+    # dtype('float64')
+    # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
+    # dtype('float32')
+    #
+    # Thus we cast output to the default torch dtype or preserve double
+    result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
+    if x.dtype == np.double:
+        result_dtype = np.double
+    return scipy.special.polygamma(n, x).astype(result_dtype)
+
+
+def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
+    low, _ = op_info.domain
+
+    if requires_grad:
+        low = 0 + op_info._domain_eps
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg((L,)))
+    yield SampleInput(make_arg(()))
+
+
+def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
+    for shape in ((L,), (1, 0, 3), ()):
+        yield SampleInput(
+            make_tensor(
+                shape,
+                device=device,
+                dtype=dtype,
+                low=-5,
+                requires_grad=requires_grad,
+            ),
+        )
+
+
+op_db: list[OpInfo] = [
+    UnaryUfuncInfo(
+        "special.i0e",
+        aten_name="special_i0e",
+        ref=scipy.special.i0e if TEST_SCIPY else None,
+        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_i0_i1,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.i1",
+        aten_name="special_i1",
+        ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
+        if TEST_SCIPY
+        else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        backward_dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_i0_i1,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1e-4, rtol=0),
+                        torch.bool: tol(atol=1e-4, rtol=0),
+                    }
+                )
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Incorrect result!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=(torch.int8,),
+            ),
+        ),
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.i1e",
+        aten_name="special_i1e",
+        ref=scipy.special.i1e if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        backward_dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_i0_i1,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.ndtr",
+        aten_name="special_ndtr",
+        decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
+        ref=scipy.special.ndtr if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Dispatch stub: unsupported device typemeta
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="meta",
+            ),
+        ),
+    ),
+    # A separate OpInfo entry for special.polygamma is needed to reorder the arguments
+    # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
+    UnaryUfuncInfo(
+        "special.polygamma",
+        op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
+        variant_test_name="special_polygamma_n_0",
+        ref=reference_polygamma if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_polygamma,
+        skips=(
+            # lambda impl
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+        ),
+        sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
+        # polygamma functions have multiple singularities at x having non-positive integer value
+        reference_numerics_filter=NumericsFilter(
+            condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
+        ),
+    ),
+    BinaryUfuncInfo(
+        "special.xlog1py",
+        aten_name="special_xlog1py",
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        promotes_int_to_float=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_one_python_scalar=True,
+        # We don't test -1 as the gradient will be NaN and it'll break
+        rhs_make_tensor_kwargs=dict(low=-0.99),
+    ),
+    BinaryUfuncInfo(
+        "special.zeta",
+        aten_name="special_zeta",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        supports_autograd=False,
+        supports_one_python_scalar=True,
+        skips=(
+            # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+    # TODO: FIXME
+    # OpInfo entry to verify the gradient formula of `other`/`q`
+    # BinaryUfuncInfo('special.zeta',
+    #                 op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
+    #                 aten_name='special_zeta',
+    #                 variant_test_name='grad',
+    #                 dtypes=all_types_and(torch.bool),
+    #                 promotes_int_to_float=True,
+    #                 supports_autograd=True,
+    #                 supports_rhs_python_scalar=False,
+    #                 decorators=[
+    #                     # Derivative wrt first tensor not implemented
+    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
+    #                                  "test_floating_inputs_are_differentiable")
+    #                 ],
+    #                 skips=(
+    #                     # Lambda doesn't work in JIT test
+    #                     # AssertionError: JIT Test does not execute any logic
+    #                     DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
+    #                 )),
+    UnaryUfuncInfo(
+        "special.entr",
+        ref=scipy.special.entr if TEST_SCIPY else None,
+        aten_name="special_entr",
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=[torch.bfloat16, torch.float16],
+            ),
+        ),
+        supports_inplace_autograd=False,
+        sample_inputs_func=sample_inputs_entr,
+    ),
+    UnaryUfuncInfo(
+        "special.ndtri",
+        ref=scipy.special.ndtri if TEST_SCIPY else None,
+        domain=(0, 1),
+        aten_name="special_ndtri",
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.log_ndtr",
+        aten_name="special_log_ndtr",
+        ref=scipy.special.log_ndtr if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.erfcx",
+        ref=scipy.special.erfcx if TEST_SCIPY else None,
+        aten_name="special_erfcx",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=0, rtol=4e-6),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_erfcx,
+    ),
+    UnaryUfuncInfo(
+        "special.airy_ai",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else None,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+            ),
+        ),
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_j0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.j0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_j1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.j1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_y0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.y0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_y1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.y1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_t",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_u",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_v",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_w",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.hermite_polynomial_h",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: inf
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.hermite_polynomial_he",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: inf
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.laguerre_polynomial_l",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.legendre_polynomial_p",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_i0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.i0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_i1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.i1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_k0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_k1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.scaled_modified_bessel_k0",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k0e if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.scaled_modified_bessel_k1",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k1e if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_t",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_u",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_v",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_w",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: nan
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.spherical_bessel_j0",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else None,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Scipy doesn't support bool inputs to spherical_bessel_j0"
+                ),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_normal",
+                dtypes=(torch.bool,),
+            ),
+        ),
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    #
+    # Elementwise Unary Special OpInfos
+    #
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.bessel_j0",
+        torch_opinfo_name="special.bessel_j0",
+        op_db=op_db,
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.bessel_j1",
+        torch_opinfo_name="special.bessel_j1",
+        op_db=op_db,
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.entr",
+        torch_opinfo_name="special.entr",
+        op_db=op_db,
+        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=[torch.bfloat16, torch.float16],
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.erfcx",
+        torch_opinfo_name="special.erfcx",
+        op_db=op_db,
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=0, rtol=4e-6),
+                }
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i0e",
+        torch_opinfo_name="special.i0e",
+        op_db=op_db,
+        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i1",
+        torch_opinfo_name="special.i1",
+        op_db=op_db,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1e-4, rtol=0),
+                        torch.bool: tol(atol=1e-4, rtol=0),
+                    }
+                )
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Incorrect result!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=(torch.int8,),
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i1e",
+        torch_opinfo_name="special.i1e",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.log_ndtr",
+        torch_opinfo_name="special.log_ndtr",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.ndtr",
+        torch_opinfo_name="special.ndtr",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.ndtri",
+        torch_opinfo_name="special.ndtri",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.spherical_bessel_j0",
+        torch_opinfo_name="special.spherical_bessel_j0",
+        op_db=op_db,
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Scipy doesn't support bool inputs to spherical_bessel_j0"
+                ),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_normal",
+                dtypes=(torch.bool,),
+            ),
+        ),
+    ),
+    #
+    # Elementwise Binary Special OpInfos
+    #
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.special.zeta",
+        torch_opinfo_name="special.zeta",
+        supports_one_python_scalar=True,
+        op_db=op_db,
+        skips=(
+            # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..975cd0d852aa6d42bafa166d2da2f1fd1df353cf
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2edf1f62d3b1967b0ea7d5268c26ee086bcdc11
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3b1cd47440cd69a8d56d63f970cca670657c7c8
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88a25cc5e97f6a7e414d43dabdb8e37c89c36206
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c05fe851547a7b53b80ac98b5c72433ded4b797
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/generate_tests.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c4902fc539d1980a7828a36e2a805b9d8209522
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/optests/__pycache__/make_fx.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/static_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/static_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a031b0d8f6e685517b7ac51c236e23835501cd9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/static_module.py
@@ -0,0 +1,27 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["module: unknown"]
+
+import torch
+
+
+class StaticModule:
+    def __init__(self, scripted):
+        # this is an nn.Module
+        if hasattr(scripted, "_c"):
+            self.static_module = torch._C._jit_to_static_module(scripted._c)
+        else:
+            self.static_module = torch._C._jit_to_static_module(scripted.graph)
+
+    def __call__(self, *args, **kwargs):
+        return self.static_module(*args, **kwargs)
+
+    def benchmark(self, args, kwargs, warmup_runs, main_runs):
+        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
+
+    def runAsync(self, args, kwargs):
+        return self.static_module.runAsync(args, kwargs)
+
+    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
+        return self.static_module.benchmark_individual_ops(
+            args, kwargs, warmup_runs, main_runs
+        )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py
new file mode 100644
index 0000000000000000000000000000000000000000..228f98139fea5adc1078cdcf7ede2a4adc4d6ede
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py
@@ -0,0 +1,78 @@
+# mypy: ignore-errors
+from typing import Any, Optional
+
+import torch
+import torch.utils._pytree as pytree
+from torch._subclasses.fake_tensor import is_fake
+from torch.testing._internal.two_tensor import TwoTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+class WrapperSubclass(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, outer_size=None, outer_stride=None):
+        if outer_size is None:
+            outer_size = a.size()
+        if outer_stride is None:
+            outer_stride = a.stride()
+
+        kwargs = {}
+        kwargs["strides"] = outer_stride
+        kwargs["storage_offset"] = a.storage_offset()
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = a.requires_grad
+        kwargs["dtype"] = a.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, outer_size, **kwargs)
+
+        return out
+
+    def __init__(self, a, outer_size=None, outer_stride=None):
+        self.a = a
+
+    def __repr__(self):
+        return f"WrapperSubclass({repr(self.a)})"
+
+    def __tensor_flatten__(self):
+        return ["a"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None
+        a = inner_tensors["a"]
+        if is_fake(a):
+            assert outer_size is not None
+            assert outer_stride is not None
+        return WrapperSubclass(a, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args)
+
+        kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs)
+
+        out_a = func(*args_a, **kwargs_a)
+        out_a_flat, spec = pytree.tree_flatten(out_a)
+        out_flat = [
+            WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a
+            for o_a in out_a_flat
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)
+        from torch._higher_order_ops.cond import cond_op
+
+        if func is cond_op:
+            return out
+        else:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+
+    def __coerce_same_metadata_as_tangent__(
+        self, expected_metadata: Any, expected_type: Optional[type] = None
+    ):
+        if expected_type is type(self.a):
+            return self.a
+        elif expected_type is TwoTensor:
+            return TwoTensor(self.a, self.a.clone())
+
+        return None
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d6394b25a6bbd81d2220b2bf965fe8e4dd4cb4b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da4e524e67d048a177a4c57811dc433b3fafd1a4
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b053e1cf27c0e33dba85f4d22e660b3fda867ec7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0fcbaee30f52a9a0d0f7e72aeaf99582d49f1e0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py
@@ -0,0 +1,1043 @@
+# mypy: ignore-errors
+
+import unittest
+
+from torch.testing._internal.inductor_utils import (
+    HAS_CUDA_AND_TRITON,
+    HAS_GPU,
+    HAS_XPU_AND_TRITON,
+)
+from torch.utils._triton import has_triton
+
+
+requires_cuda_and_triton = unittest.skipUnless(
+    HAS_CUDA_AND_TRITON, "requires cuda and triton"
+)
+requires_gpu_and_triton = unittest.skipUnless(
+    HAS_XPU_AND_TRITON or HAS_CUDA_AND_TRITON, "requires gpu and triton"
+)
+requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")
+
+if has_triton():
+    import triton
+    from triton import language as tl
+
+    import torch
+
+    def _get_strange_configs() -> list[triton.Config]:
+        if torch.version.hip:
+            configs = [
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 16,
+                        "BLOCK_SIZE_N": 16,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 4,
+                        "matrix_instr_nonkdim": 16,
+                        "waves_per_eu": 3,
+                        "kpack": 2,
+                    },
+                    num_stages=4,
+                    num_warps=4,
+                ),
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 128,
+                        "BLOCK_SIZE_N": 64,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 4,
+                        "matrix_instr_nonkdim": 16,
+                        "waves_per_eu": 3,
+                        "kpack": 2,
+                    },
+                    num_stages=4,
+                    num_warps=4,
+                ),
+            ]
+        else:
+            configs = [
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 16,
+                        "BLOCK_SIZE_N": 16,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 4,
+                    },
+                    num_stages=4,
+                    num_warps=4,
+                ),
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 128,
+                        "BLOCK_SIZE_N": 64,
+                        "BLOCK_SIZE_K": 32,
+                        "GROUP_SIZE_M": 8,
+                    },
+                    num_stages=4,
+                    num_warps=4,
+                ),
+            ]
+        return configs
+
+    # Define here so that multiple tests can take advantage of it
+    @triton.jit
+    def add_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def sub_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_optional_param(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        ARGS_PASSED: "tl.constexpr",
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        if ARGS_PASSED == "two":
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_none_param_and_equal_to_1_arg(
+        in_ptr0,
+        in_ptr1,  # in_ptr1 could be None
+        out_ptr,
+        n_elements,
+        stride,
+        ARGS_PASSED: "tl.constexpr",
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets * stride, mask=mask)
+        if ARGS_PASSED == "two":
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x
+        tl.store(out_ptr + offsets * stride, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def sub_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_autotuned_weird_param_order(
+        in_ptr0,
+        in_ptr1,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+        out_ptr,
+    ):
+        # out_ptr is after an autotuned param that's declared as tl.constexpr.
+        # This param ordering can create bugs if not handled correctly.
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
+            ),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_2d_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        x_elements,
+        y_elements,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        xoffset = tl.program_id(0) * BLOCK_SIZE_X
+        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
+        xmask = xindex < x_elements
+        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
+        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
+        ymask = yindex < y_elements
+        x1 = xindex
+        y0 = yindex
+        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
+        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
+        tmp2 = tmp0 + tmp1
+        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
+
+    def _dummy_early_config_prune(configs, *_, **__):
+        return configs
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+        warmup=10,
+        rep=20,
+        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
+    )
+    @triton.jit
+    def add_kernel_autotuned_with_unsupported_args(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_scaling(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        scaling_factor,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = (x + y) * scaling_factor
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_tma_1d_old_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE
+
+        a = tl._experimental_descriptor_load(
+            in_desc_ptr0,
+            [offset],
+            [BLOCK_SIZE],
+            tl.float32,
+        )
+        b = tl._experimental_descriptor_load(
+            in_desc_ptr1,
+            [offset],
+            [BLOCK_SIZE],
+            tl.float32,
+        )
+
+        output = a + b
+
+        tl._experimental_descriptor_store(
+            out_desc_ptr,
+            output,
+            [offset],
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_2d_old_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE_X
+        offset_y = pid_y * BLOCK_SIZE_Y
+
+        x = tl._experimental_descriptor_load(
+            in_desc_ptr0,
+            [offset_x, offset_y],
+            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
+            tl.float32,
+        )
+        y = tl._experimental_descriptor_load(
+            in_desc_ptr1,
+            [offset_x, offset_y],
+            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
+            tl.float32,
+        )
+
+        output = x + y
+
+        tl._experimental_descriptor_store(
+            out_desc_ptr,
+            output,
+            [offset_x, offset_y],
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_1d_new_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE
+
+        a = tl.load_tensor_descriptor(
+            in_desc_ptr0,
+            [offset],
+        )
+        b = tl.load_tensor_descriptor(
+            in_desc_ptr1,
+            [offset],
+        )
+
+        output = a + b
+
+        tl.store_tensor_descriptor(
+            out_desc_ptr,
+            [offset],
+            output,
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_2d_new_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE_X
+        offset_y = pid_y * BLOCK_SIZE_Y
+
+        x = tl.load_tensor_descriptor(
+            in_desc_ptr0,
+            [offset_x, offset_y],
+        )
+        y = tl.load_tensor_descriptor(
+            in_desc_ptr1,
+            [offset_x, offset_y],
+        )
+
+        output = x + y
+
+        tl.store_tensor_descriptor(
+            out_desc_ptr,
+            [offset_x, offset_y],
+            output,
+        )
+
+    @triton.jit
+    def add_kernel_on_device_tma_old_api(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        m,
+        n,
+        workspace,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        a_desc_ptr = workspace
+        b_desc_ptr = workspace + 128
+        c_desc_ptr = workspace + 256
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=a_desc_ptr,
+            global_address=a_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=a_ptr.dtype.element_ty,
+        )
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=b_desc_ptr,
+            global_address=b_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=b_ptr.dtype.element_ty,
+        )
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=c_desc_ptr,
+            global_address=c_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=c_ptr.dtype.element_ty,
+        )
+
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE
+        offset_y = pid_y * BLOCK_SIZE
+
+        # Load data using the tensor descriptors
+        a = tl._experimental_descriptor_load(
+            a_desc_ptr,
+            [offset_x, offset_y],
+            [BLOCK_SIZE, BLOCK_SIZE],
+            tl.float32,
+        )
+        b = tl._experimental_descriptor_load(
+            b_desc_ptr,
+            [offset_x, offset_y],
+            [BLOCK_SIZE, BLOCK_SIZE],
+            tl.float32,
+        )
+
+        # Perform addition
+        output = a + b
+
+        # Store the result
+        tl._experimental_descriptor_store(
+            c_desc_ptr,
+            output,
+            [offset_x, offset_y],
+        )
+
+    @triton.jit
+    def add_kernel_on_device_tma_new_api(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        m,
+        n,
+        workspace,  # unused but left here to match the old API kernel
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        # Create tensor descriptors using the new API
+        a_desc = tl.make_tensor_descriptor(
+            base=a_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+        b_desc = tl.make_tensor_descriptor(
+            base=b_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+        c_desc = tl.make_tensor_descriptor(
+            base=c_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE
+        offset_y = pid_y * BLOCK_SIZE
+
+        # Load data using the tensor descriptors with the new API
+        a = tl.load_tensor_descriptor(
+            a_desc,
+            [offset_x, offset_y],
+        )
+        b = tl.load_tensor_descriptor(
+            b_desc,
+            [offset_x, offset_y],
+        )
+
+        # Perform addition
+        output = a + b
+
+        # Store the result with the new API
+        tl.store_tensor_descriptor(
+            c_desc,
+            [offset_x, offset_y],
+            output,
+        )
+
+    @triton.jit
+    def mul2_kernel(
+        in_ptr0,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        output = 2 * x
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def mul2_inplace_kernel(
+        ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(ptr + offsets, mask=mask)
+        output = 2 * x
+        tl.store(ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def zero_negs(x):
+        return tl.where(x >= 0, x, 0)
+
+    @triton.jit
+    def indirection_kernel(
+        in_ptr0,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+        ACTIVATION: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        if ACTIVATION == "mul2_inplace_kernel":
+            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+        elif ACTIVATION == "add_kernel":
+            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        tl.store(out_ptr + offsets, x, mask=mask)
+
+    @triton.jit
+    def double_strided_kernel(
+        in_ptr,
+        out_ptr,
+        in_y_stride,
+        out_y_stride,
+        X_BLOCK_SIZE: "tl.constexpr",
+        Y_BLOCK_SIZE: "tl.constexpr",
+    ):
+        xid = tl.program_id(axis=0)
+        yid = tl.program_id(axis=1)
+        x_start = xid * X_BLOCK_SIZE
+        y_start = yid * Y_BLOCK_SIZE
+        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
+        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
+        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
+        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
+        src = tl.load(in_ptr + src_offsets)
+        tl.store(out_ptr + dst_offsets, src * 2.0)
+
+    @triton.jit
+    def inline_asm_kernel_is_pure_true(
+        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
+    ):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = tl.load(Y + tl.arange(0, BLOCK))
+        s = tl.full([BLOCK], n, tl.int32)
+        z = tl.inline_asm_elementwise(
+            "shf.l.wrap.b32 $0, $1, $2, $3;",
+            "=r,r, r, r",
+            [x, y, s],
+            dtype=tl.int32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    @triton.jit
+    def inline_asm_kernel_is_pure_false(
+        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
+    ):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = tl.load(Y + tl.arange(0, BLOCK))
+        s = tl.full([BLOCK], n, tl.int32)
+        z = tl.inline_asm_elementwise(
+            "shf.l.wrap.b32 $0, $1, $2, $3;",
+            "=r,r, r, r",
+            [x, y, s],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    @triton.jit
+    def add_kernel_with_block_ptr(
+        x_ptr,
+        y_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        x = tl.load(
+            tl.make_block_ptr(
+                base=x_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            boundary_check=[0],
+        )
+        y = tl.load(
+            tl.make_block_ptr(
+                base=y_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            boundary_check=[0],
+        )
+        output = x + y
+        tl.store(
+            tl.make_block_ptr(
+                base=output_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            output,
+            boundary_check=[0],
+        )
+
+    @triton.jit
+    def kernel_with_block_ptr_2d(
+        x_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        x = tl.load(
+            tl.make_block_ptr(
+                base=x_ptr,
+                shape=[n_elements, 1],
+                strides=[1, 1],
+                offsets=[block_start, 0],
+                block_shape=[BLOCK_SIZE, 1],
+                order=[1, 0],
+            ),
+            boundary_check=[0],
+        )
+        output = x
+        tl.store(
+            tl.make_block_ptr(
+                base=output_ptr,
+                shape=[n_elements, 1],
+                strides=[1, 1],
+                offsets=[block_start, 0],
+                block_shape=[BLOCK_SIZE, 1],
+                order=[1, 0],
+            ),
+            output,
+            boundary_check=[0],
+        )
+
+    from triton.language import load, store
+
+    @triton.jit
+    def add_kernel_with_import(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = load(in_ptr0 + offsets, mask=mask)
+        y = load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def cond_op_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        if tl.program_id(0) == 0:
+            output = x + y
+        else:
+            output = x * y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def atomic_add_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_4_times_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        for _ in range(2):
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+        i = 2
+        while i > 0:
+            i -= 1
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_out_of_order_fn2(
+        in_ptr0,
+        in_ptr1,
+        n_elements,
+        out_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 16,
+                    "BLOCK_SIZE_K": 16,
+                    "GROUP_SIZE_M": 4,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+        ],
+        key=["M_ptr", "N", "K"],
+    )
+    @triton.jit
+    def strange_config_matmul_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        M_ptr,
+        N,
+        K,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        GROUP_SIZE_M: tl.constexpr,
+    ):
+        # This is a simplified matmul from Triton tutorial.
+        pid = tl.program_id(axis=0)
+        M = tl.load(M_ptr)
+        if M == 0 and BLOCK_SIZE_M > 32:
+            # This will run the full matmul if BLOCK_SIZE_M > 32
+            M = 4096
+        elif M == 0:
+            # This directly returns, which will cut short the bad config of 16-block size.
+            return
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] + offs_k[None, :])
+        b_ptrs = b_ptr + (offs_k[:, None] + offs_bn[None, :])
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(tl.cdiv(K, BLOCK_SIZE_K)):
+            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+            accumulator = tl.dot(a, b, accumulator)
+            a_ptrs += BLOCK_SIZE_K
+            b_ptrs += BLOCK_SIZE_K
+        c = accumulator.to(tl.float16)
+
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :]
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, c, mask=c_mask)
+
+    @triton.jit
+    def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
+        """
+        This kernel contains a triple-quote docstring w/ double quotes.
+        Make sure that codegen sanitizes the docstring.
+        """
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
+        tl.store(out_ptr + offsets, ones, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
+        '''
+        This kernel contains a triple-quote docstring w/ single quotes
+        Make sure that codegen sanitizes the docstring.
+        To prevent it from being linted to double quotes: """!!!"""
+        '''
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
+        tl.store(out_ptr + offsets, ones, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_inline_asm_double_quotes(
+        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        data = tl.load(in_ptr + offsets, mask=offsets < numel)
+        cos_pow = tl.inline_asm_elementwise(
+            asm="""
+            {
+                cos.approx.f32 $0, $1;
+                ex2.approx.f32 $0, $0;
+            }
+                """,
+            constraints=("=r, r"),
+            args=[data],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_inline_asm_single_quotes(
+        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        data = tl.load(in_ptr + offsets, mask=offsets < numel)
+        cos_pow = tl.inline_asm_elementwise(
+            asm='''
+            {
+                // double quotes to pacify the linter """!!!"""
+                cos.approx.f32 $0, $1;
+                ex2.approx.f32 $0, $0;
+            }
+                ''',
+            constraints=("=r, r"),
+            args=[data],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)
+
+    @triton.jit
+    def add_kernel_with_boolean_param(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        add_xy,  # boolean param
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        if add_xy:
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    # support the old (experimental) and new (tensor_descriptor) APIs
+    def create_tensor_descriptor_shim(
+        tensor, block_sizes: list[int], new_api: bool = True
+    ):
+        if new_api:
+            return triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+                tensor, block_sizes
+            )
+        else:
+            if len(block_sizes) == 1:
+                return triton.tools.experimental_descriptor.create_1d_tma_descriptor(
+                    tensor.data_ptr(),
+                    tensor.size(0),
+                    block_sizes[0],
+                    tensor.element_size(),
+                )
+            else:
+                assert len(block_sizes) == 2
+                return triton.tools.experimental_descriptor.create_2d_tma_descriptor(
+                    tensor.data_ptr(),
+                    tensor.size(0),
+                    tensor.size(1),
+                    block_sizes[0],
+                    block_sizes[1],
+                    tensor.element_size(),
+                )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8197829ac7f44f38d295995dd921ddf58b30adfd
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py
@@ -0,0 +1,100 @@
+# mypy: ignore-errors
+
+import torch
+import torch.utils._pytree as pytree
+from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
+class TwoTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
+        if outer_size is None:
+            outer_size = a.size()
+        if outer_stride is None:
+            outer_stride = a.stride()
+
+        assert (
+            a.device == b.device
+            and a.layout == b.layout
+            and a.requires_grad == b.requires_grad
+            and a.dtype == b.dtype
+        )
+        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
+        shape = outer_size
+        kwargs = {}
+        kwargs["strides"] = outer_stride
+        kwargs["storage_offset"] = a.storage_offset()
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = requires_grad or a.requires_grad
+        kwargs["dtype"] = a.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+        assert a.shape == b.shape
+        assert a.stride() == b.stride()
+        assert a.storage_offset() == b.storage_offset()
+        return out
+
+    @torch._disable_dynamo
+    @mark_subclass_constructor_exportable_experimental
+    def __init__(self, a, b, outer_size=None, outer_stride=None, *, requires_grad=None):
+        self.a = a
+        self.b = b
+
+    def __repr__(self):
+        a_repr = repr(self.a)
+        b_repr = repr(self.b)
+        return f"TwoTensor({a_repr}, {b_repr})"
+
+    def __tensor_flatten__(self):
+        return ["a", "b"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None
+        a, b = inner_tensors["a"], inner_tensors["b"]
+        if type(a) is torch.Tensor:
+            assert outer_size is not None
+            assert outer_stride is not None
+        return TwoTensor(a, b, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
+        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)
+
+        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
+        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)
+
+        out_a = func(*args_a, **kwargs_a)
+        out_b = func(*args_b, **kwargs_b)
+        out_a_flat, spec = pytree.tree_flatten(out_a)
+        out_b_flat = pytree.tree_leaves(out_b)
+        # for aten ops that return non-tensors, just assume that
+        # our two inner tensors return the same value
+        out_flat = [
+            cls(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
+            for o_a, o_b in zip(out_a_flat, out_b_flat, strict=True)
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)
+        from torch._higher_order_ops.cond import cond_op
+
+        if func is cond_op:
+            return out
+        else:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+
+    def get_elem_a(self):
+        return self.a
+
+
+class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        out = func(*args, **kwargs)
+        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
+            out = TwoTensor(out, out.clone())
+        return out